view hgext/states.py @ 51:d98e06ab8320

move extensions in a hgext directory
author Pierre-Yves David <pierre-yves.david@logilab.fr>
date Thu, 08 Sep 2011 17:15:20 +0200
parents states.py@dca86448d736
children ad1a4fb0fc49
line wrap: on
line source

# states.py - introduce the state concept for mercurial changeset
#
# Copyright 2011 Pierre-Yves David <pierre-yves.david@ens-lyon.org>
#                Logilab SA        <contact@logilab.fr>
#                Augie Fackler     <durin42@gmail.com>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

'''introduce the state concept for mercurial changesets

A changeset can be in one of the following states:

0 immutable
1 mutable
2 private

The state names are not final yet.
'''
import os
from functools import partial

from mercurial.i18n import _
from mercurial import cmdutil
from mercurial import scmutil
from mercurial import context
from mercurial import revset
from mercurial import templatekw
from mercurial import util
from mercurial import node
from mercurial.node import nullid, hex, short
from mercurial import discovery
from mercurial import extensions
from mercurial import wireproto
from mercurial import pushkey
from mercurial.lock import release


# Bit flags combined into the ``properties`` field of a state.
# _NOSHARE is the flag tested (through laststatewithout) to pick the heads
# exchanged with other repositories; of the states declared below only
# 'draft' carries it.
_NOSHARE=2
# _MUTABLE is carried by the 'ready' and 'draft' states declared below.
_MUTABLE=1

class state(object):
    """A changeset state: a name, a bit field of properties and a successor.

    States are chained through ``next``.  ``properties`` is a combination
    of the ``_MUTABLE`` / ``_NOSHARE`` flags and also defines the ordering
    between states (see ``__cmp__``).
    """

    def __init__(self, name, properties=0, next=None):
        """Create a state.

        name:       name of the state
        properties: bit field describing the state
        next:       following state in the chain, or None for the last one;
                    it must compare greater than the state being created
        """
        self.name = name
        self.properties = properties
        # properties must be set before this check: the comparison uses it
        assert next is None or self < next
        self.next = next

    def __repr__(self):
        return 'state(%s)' % self.name

    def __str__(self):
        return self.name

    @util.propertycache
    def trackheads(self):
        """Do we need to track the heads of the changesets in this state?

        Heads are not tracked for the last state of the chain: callers fall
        back on the repository heads for it."""
        return self.next is not None

    def __cmp__(self, other):
        """Order states according to their ``properties`` value."""
        return cmp(self.properties, other.properties)

    @util.propertycache
    def _revsetheads(self):
        """function to be used by revset to find the heads of this state"""
        assert self.trackheads
        # Report the symbol registered for *this* state instead of a
        # hard-coded name shared by every state.
        errmsg = '%s takes no arguments' % self.headssymbol
        def revsetheads(repo, subset, x):
            revset.getargs(x, 0, 0, errmsg)
            heads = set(map(repo.changelog.rev, repo._statesheads[self]))
            # A revset predicate must only return members of ``subset``,
            # otherwise expressions such as "x and <name>heads()" yield
            # revisions that are not part of x.
            return [r for r in subset if r in heads]
        return revsetheads

    @util.propertycache
    def headssymbol(self):
        """name of the revset symbol giving the heads of this state"""
        if self.trackheads:
            return "%sheads" % self.name
        else:
            return 'heads'

# The three states are declared from last to first because each one needs
# its successor (``next``) to exist already.
ST2 = state('draft', _NOSHARE | _MUTABLE)
ST1 = state('ready', _MUTABLE, next=ST2)
ST0 = state('published', next=ST1)

# All states in chain order: ST0 -> ST1 -> ST2.
STATES = (ST0, ST1, ST2)

@util.cachefunc
def laststatewithout(prop):
    """Return the last state of STATES found before one having ``prop``.

    STATES is walked in order and the walk stops at the first state whose
    ``properties`` include ``prop``.

    prop: bit flag (or combination of flags) to look for

    Returns the state preceding the stop point.  When no state has
    ``prop`` this is the last state of STATES; when the very first state
    already has it there is no such state and None is returned.
    """
    candidate = None
    for st in STATES:
        if st.properties & prop:
            break
        candidate = st
    # Returning here (rather than only from inside the loop) covers the
    # case where the loop runs to completion.
    return candidate

# util function
#############################
def noderange(repo, revsets):
    """Resolve ``revsets`` through scmutil.revrange and return the list of
    changelog nodes of the resulting revisions."""
    tonode = repo.changelog.node
    return [tonode(rev) for rev in scmutil.revrange(repo, revsets)]

# Patch changectx
#############################

def state(ctx):
    """Return the state of the changeset behind ``ctx``.

    A context whose node() is None gets the last entry of STATES; any
    other context is resolved through the ``nodestate`` method of its
    repository."""
    ctxnode = ctx.node()
    if ctxnode is not None:
        return ctx._repo.nodestate(ctxnode)
    return STATES[-1]
# make the function available as a method of every changectx
context.changectx.state = state

# improve template
#############################

def showstate(ctx, **args):
    """Template keyword function: return ``ctx.state()``.

    The extra keyword arguments are accepted and ignored."""
    current = ctx.state()
    return current


# New commands
#############################


def cmdstates(ui, repo, *states, **opt):
    """view and modify activated states.

    With no argument, list activated state.

    With argument, activate the state in argument.

    With argument plus the --off switch, deactivate the state in argument.

    note: published state are always activated."""

    # Listing mode: one enabled state per line, in sorted order.
    if not states:
        for st in sorted(repo._enabledstates):
            ui.write('%s\n' % st)
        return 0

    off = opt.get('off', False)
    for state_name in states:
        # resolve the name; the for/else reports names matching no state
        for st in STATES:
            if st.name == state_name:
                break
        else:
            ui.write_err(_('no state named %s\n') % state_name)
            return 1
        if off:
            # discard() leaves an already disabled state disabled; testing
            # membership together with ``off`` used to send that case to
            # the "enable" branch.
            repo._enabledstates.discard(st)
        else:
            repo._enabledstates.add(st)
    # persist the new set only once every name has been processed
    repo._writeenabledstates()
    return 0

# Maps a command name to a (function, option list, synopsis) tuple; the
# per-state commands are added to it below.  Presumably this is the table
# picked up by Mercurial's extension loader -- confirm against the hg docs.
cmdtable = {'states': (cmdstates, [ ('', 'off', False, _('desactivate the state') )], '<state>')}
# previous definition, without the --off switch:
#cmdtable = {'states': (cmdstates, [], '<state>')}

def makecmd(state):
    """Build the command function that moves changesets into ``state``.

    state: the target state; it is formatted into the help text of the
           command and handed to ``repo.setstate``

    Returns a function taking (ui, repo, *changesets), where
    ``changesets`` are revision specifications resolved with
    scmutil.revrange.
    """
    def cmdmoveheads(ui, repo, *changesets):
        revs = scmutil.revrange(repo, changesets)
        repo.setstate(state, [repo.changelog.node(rev) for rev in revs])
        return 0
    # A '%'-formatted string is an ordinary expression, not a docstring
    # literal, so the help text has to be attached explicitly.
    cmdmoveheads.__doc__ = """set a revision in %s state""" % state
    return cmdmoveheads

# Register one command per head-tracking state; the command is named after
# the state and moves the given revisions into it.
# NOTE: this loop rebinds the module-level names ``state`` and
# ``cmdmoveheads``.
for state in STATES:
    if state.trackheads:
        cmdmoveheads = makecmd(state)
        cmdtable[state.name] = (cmdmoveheads, [], '<revset>')

# Pushkey mechanism for mutable
#########################################

def pushimmutableheads(repo, key, old, new):
    """pushkey handler: move the changeset named by ``key`` to state ST0.

    repo:     repository receiving the key
    key:      hex node of the changeset
    old, new: required by the pushkey handler signature, unused here

    Returns True once the state has been recorded; a pushkey handler
    reports success through its return value, so returning nothing made
    every push look like a failure.
    """
    w = repo.wlock()
    try:
        repo.setstate(ST0, [node.bin(key)])
    finally:
        w.release()
    return True

def listimmutableheads(repo):
    """pushkey lister: map the hex node of every head of state ST0 to '1'."""
    keys = {}
    for head in repo.stateheads(ST0):
        keys[node.hex(head)] = '1'
    return keys

# Expose the heads of state ST0 through the 'immutableheads' pushkey
# namespace: pushimmutableheads handles pushes, listimmutableheads listings.
pushkey.register('immutableheads', pushimmutableheads, listimmutableheads)





def uisetup(ui):
    """Install the module-wide hooks of the extension.

    - wrap discovery.findcommonoutgoing / findcommonincoming so that the
      heads they return go through ``_reducehead``
    - replace the 'heads' wire command and give wire repositories a
      pass-through ``_reducehead``
    - register the 'state' template keyword
    """
    def filterprivateout(orig, repo, *args, **kwargs):
        # outgoing: the local repository filters its own heads
        result = orig(repo, *args, **kwargs)
        common, outheads = result
        outheads = repo._reducehead(outheads)
        return common, outheads

    def filterprivatein(orig, repo, remote, *args, **kwargs):
        # incoming: the remote side is asked to filter its heads
        common, anyinc, remoteheads = orig(repo, remote, *args, **kwargs)
        return common, anyinc, remote._reducehead(remoteheads)

    for fname, wrapper in (('findcommonoutgoing', filterprivateout),
                           ('findcommonincoming', filterprivatein)):
        extensions.wrapfunction(discovery, fname, wrapper)

    # Wire protocol
    ####################
    def heads(repo, proto):
        # serve the heads of the last state without the _NOSHARE flag
        shareable = repo.stateheads(laststatewithout(_NOSHARE))
        return wireproto.encodelist(shareable) + "\n"

    def _reducehead(wirerepo, heads):
        """heads filtering is done repo side"""
        return heads

    wireproto.wirerepository._reducehead = _reducehead
    wireproto.commands['heads'] = (heads, '')

    templatekw.keywords['state'] = showstate

def extsetup(ui):
    """Register a revset symbol for every state that tracks its heads."""
    tracked = [st for st in STATES if st.trackheads]
    for st in tracked:
        revset.symbols[st.headssymbol] = st._revsetheads

def reposetup(ui, repo):
    """Teach a local repository about changeset states.

    The class of ``repo`` is replaced by ``statefulrepo``, a subclass of
    its current class that stores the heads of each tracked state under
    ``states/`` and hooks clone, pull, push, tag and rollback.
    Repositories for which ``repo.local()`` is false are left untouched.
    """

    if not repo.local():
        return

    # Original bound methods, captured before the class swap so that the
    # overriding methods below can delegate to them.
    ocancopy =repo.cancopy
    opull = repo.pull
    opush = repo.push
    o_tag = repo._tag
    orollback = repo.rollback
    o_writejournal = repo._writejournal
    class statefulrepo(repo.__class__):

        def nodestate(self, node):
            """Return the state of the changeset ``node``.

            States are tried in STATES order; the first head-tracking state
            whose heads, or their ancestors, contain the changeset wins.
            The last state is returned when none matches."""
            rev = self.changelog.rev(node)

            for state in STATES:
                # XXX avoid for untracked heads
                if state.next is not None:
                    heads = map(self.changelog.rev, self.stateheads(state))
                    # a set keeps the membership test below cheap
                    revs = set(heads)
                    revs.update(self.changelog.ancestors(*heads))
                    if rev in revs:
                        break
            return state

        def stateheads(self, state):
            """Return the heads (nodes) of ``state``.

            While the successor of the state is not enabled, the successor
            is used instead.  The last state has no stored heads: the
            repository heads are returned for it."""
            # look for a relevant state
            while state.trackheads and state.next not in self._enabledstates:
                state = state.next
            # last state have no cached head.
            if state.trackheads:
                return self._statesheads[state]
            return self.heads()

        @util.propertycache
        def _statesheads(self):
            """mapping {state: list of head nodes}, read from disk once"""
            return self._readstatesheads()

        def _readheadsfile(self, filename):
            """Return the sorted nodes listed in ``filename``.

            [nullid] is returned when the file cannot be read."""
            heads = [nullid]
            try:
                f = self.opener(filename)
                try:
                    heads = sorted([node.bin(n) for n in f.read().split() if n])
                finally:
                    f.close()
            except IOError:
                pass
            return heads

        def _readstatesheads(self, undo=False):
            """Read the heads file of every head-tracking state.

            ``undo`` is accepted but not used."""
            statesheads = {}
            for state in STATES:
                if state.trackheads:
                    filename = 'states/%s-heads' % state.name
                    statesheads[state] = self._readheadsfile(filename)
            return statesheads

        def _writeheadsfile(self, filename, heads):
            """Write ``heads`` to ``filename``, one hex node per line,
            through an atomictemp file."""
            f = self.opener(filename, 'w', atomictemp=True)
            try:
                for h in heads:
                    f.write(hex(h) + '\n')
                f.rename()
            finally:
                f.close()

        def _writestateshead(self):
            """Write the heads file of every head-tracking state."""
            # transaction!
            for state in STATES:
                if state.trackheads:
                    filename = 'states/%s-heads' % state.name
                    self._writeheadsfile(filename, self._statesheads[state])

        def setstate(self, state, nodes):
            """change state of targets changeset and it's ancestors.

            Simplify the list of head.

            state: target state; it must track its heads
            nodes: iterable of binary nodes (not a single string)

            The change cascades to the following head-tracking states."""
            assert not isinstance(nodes, basestring)
            # ``heads`` is the very list stored in the cache and read by
            # the revset symbol below, hence the in-place updates.
            heads = self._statesheads[state]
            olds = heads[:]
            heads.extend(nodes)
            heads[:] = set(heads)
            heads.sort()
            if olds != heads:
                # use self rather than the ``repo`` captured by reposetup
                heads[:] = noderange(self, ["heads(::%s())" % state.headssymbol])
                heads.sort()
            if olds != heads:
                self._writestateshead()
            if state.next is not None and state.next.trackheads:
                self.setstate(state.next, nodes) # cascading

        def _reducehead(self, candidates):
            """Restrict ``candidates`` (nodes) to what may be shared.

            Candidates that are heads, or ancestors of heads, of the last
            state without _NOSHARE are kept.  Every other candidate is
            replaced by the heads of that state it descends from.
            Returns a sorted list of nodes."""
            st = laststatewithout(_NOSHARE)
            candidates = set(map(self.changelog.rev, candidates))
            heads = set(map(self.changelog.rev, self.stateheads(st)))
            shareable = set(self.changelog.ancestors(*heads))
            shareable.update(heads)
            selected = candidates & shareable
            unselected = candidates - shareable
            for rev in unselected:
                for revh in heads:
                    if self.changelog.descendant(revh, rev):
                        selected.add(revh)
            return sorted(map(self.changelog.node, selected))

        ### enable // disable logic

        @util.propertycache
        def _enabledstates(self):
            """set of enabled states, read from disk once"""
            return self._readenabledstates()

        def _readenabledstates(self):
            """Return the set of states named in 'states/Enabled'.

            ST0 is always part of the result.  Unknown names are ignored
            and an unreadable file yields {ST0}."""
            states = set()
            states.add(ST0)
            mapping = dict([(st.name, st) for st in STATES])
            # Same shape as _readheadsfile: only IOError is tolerated and
            # the file is always closed.  A ``return`` inside ``finally``
            # would silently discard every exception.
            try:
                f = self.opener('states/Enabled')
                try:
                    for line in f:
                        st = mapping.get(line.strip())
                        if st is not None:
                            states.add(st)
                finally:
                    f.close()
            except IOError:
                pass
            return states

        def _writeenabledstates(self):
            """Write the names of the enabled states to 'states/Enabled',
            one per line, through an atomictemp file."""
            f = self.opener('states/Enabled', 'w', atomictemp=True)
            try:
                for st in self._enabledstates:
                    f.write(st.name + '\n')
                f.rename()
            finally:
                f.close()

        ### local clone support

        def cancopy(self):
            """Allow a copy clone only when the original check passes and
            the shareable heads are the repository heads."""
            st = laststatewithout(_NOSHARE)
            return ocancopy() and (self.stateheads(st) == self.heads())

        ### pull // push support

        def pull(self, remote, *args, **kwargs):
            """Pull, then move the immutable heads of ``remote`` to ST0."""
            result = opull(remote, *args, **kwargs)
            remoteheads = self._pullimmutableheads(remote)
            self.setstate(ST0, remoteheads)
            return result

        def push(self, remote, *args, **opts):
            """Push, then synchronise the ST0 heads in both directions."""
            result = opush(remote, *args, **opts)
            remoteheads = self._pullimmutableheads(remote)
            self.setstate(ST0, remoteheads)
            if remoteheads != self.stateheads(ST0):
                self._pushimmutableheads(remote, remoteheads)
            return result

        def _pushimmutableheads(self, remote, remoteheads):
            """Send ``remote`` the ST0 heads missing from ``remoteheads``."""
            missing = set(self.stateheads(ST0)) - set(remoteheads)
            for h in missing:
                remote.pushkey('immutableheads', node.hex(h), '', '1')

        def _pullimmutableheads(self, remote):
            """Return the nodes ``remote`` considers immutable.

            These are the keys of its 'immutableheads' pushkey namespace,
            or all its heads when it does not advertise that namespace."""
            # ui.debug() does not add a line ending, so each message
            # carries its own.
            self.ui.debug('checking for immutableheads on server\n')
            if 'immutableheads' not in remote.listkeys('namespaces'):
                self.ui.debug('immutableheads not enabled on the remote server, '
                              'marking everything as frozen\n')
                return remote.heads()
            self.ui.debug('server has immutableheads enabled, merging lists\n')
            return map(node.bin, remote.listkeys('immutableheads'))

        ### Tag support

        def _tag(self, names, node, *args, **kwargs):
            """Tag, then move the tagged changeset and the changeset
            returned by the original ``_tag`` to ST0."""
            tagnode = o_tag(names, node, *args, **kwargs)
            if tagnode is not None: # do nothing for local one
                self.setstate(ST0, [node, tagnode])
            return tagnode

        ### rollback support

        def _writejournal(self, desc):
            """Copy each existing heads file to its 'journal.' counterpart
            and add the copies to the entries of the original method."""
            entries = list(o_writejournal(desc))
            for state in STATES:
                if state.trackheads:
                    filename = 'states/%s-heads' % state.name
                    filepath = self.join(filename)
                    if  os.path.exists(filepath):
                        journalname = 'states/journal.%s-heads' % state.name
                        journalpath = self.join(journalname)
                        util.copyfile(filepath, journalpath)
                        entries.append(journalpath)
            return tuple(entries)

        def rollback(self, dryrun=False):
            """Roll back and restore the heads files from their 'undo.'
            copies; a heads file without copy is removed.  The cached
            heads are dropped afterwards."""
            wlock = lock = None
            try:
                wlock = self.wlock()
                lock = self.lock()
                ret = orollback(dryrun)
                if not (ret or dryrun): #rollback did not failed
                    for state in STATES:
                        if state.trackheads:
                            # Format the name before joining, so that a '%'
                            # in the repository path is never interpreted.
                            src = self.join('states/undo.%s-heads' % state.name)
                            dest = self.join('states/%s-heads' % state.name)
                            if os.path.exists(src):
                                util.rename(src, dest)
                            elif os.path.exists(dest): #unlink in any case
                                os.unlink(dest)
                    self.__dict__.pop('_statesheads', None)
                return ret
            finally:
                release(lock, wlock)

    repo.__class__ = statefulrepo