view mercurial/wireproto.py @ 20836:a8b4541bb961

grep: reuse the first "util.binary()" result for efficiency Before this patch, to check whether the file in the specified revision is binary or not, "util.binary()" is invoked via internal function "binary()" of "hg grep" once per a line of "hg grep" output, even though binary-ness is not changed in the same file. This patch reuses the first "util.binary()" invocation result by annotating internal function "binary()" with "@util.cachefunc". Performance improvement measured by "hgperf grep -r 88d8e568add1 vfs mercurial/scmutil.py": before this patch: ! wall 0.024000 comb 0.015600 user 0.015600 sys 0.000000 (best of 118) after this patch: ! wall 0.023000 comb 0.015600 user 0.015600 sys 0.000000 (best of 123) Status of recent(88d8e568add1) "mercurial/scmutil.py": # of lines: 919 (may affect cost of search) # of bytes: 29633 (may affect cost of "util.binary()") # of matches: 22 (may affect frequency of "util.binary()")
author FUJIWARA Katsunori <foozy@lares.dti.ne.jp>
date Sat, 15 Feb 2014 19:52:36 +0900
parents 982f13bef503
children 1e4fda2f5cf1
line wrap: on
line source

# wireproto.py - generic wire protocol support functions
#
# Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

import urllib, tempfile, os, sys
from i18n import _
from node import bin, hex
import changegroup as changegroupmod
import peer, error, encoding, util, store

# abstract batching support

class future(object):
    '''write-once holder for a result that becomes available later

    The result lives in the ``value`` attribute, which does not exist until
    set() has been called.
    '''
    def set(self, value):
        '''store value; raise error.RepoError if one was stored already'''
        alreadyset = util.safehasattr(self, 'value')
        if alreadyset:
            raise error.RepoError("future is already set")
        self.value = value

class batcher(object):
    '''base class for batches of commands submittable in a single request

    Any attribute that is not found through normal lookup is treated as a
    command name: calling it only queues the command and returns a future
    for the result. Once you call submit(), all the queued calls are
    performed and the results set in their respective futures.
    '''
    def __init__(self):
        # queued (name, args, opts, future) tuples, in call order
        self.calls = []
    def __getattr__(self, name):
        def enqueue(*args, **opts):
            pending = future()
            entry = (name, args, opts, pending)
            self.calls.append(entry)
            return pending
        return enqueue
    def submit(self):
        '''perform the queued calls; a no-op in this base class'''
        pass

class localbatch(batcher):
    '''batch whose queued calls are executed directly on a local object'''
    def __init__(self, local):
        batcher.__init__(self)
        self.local = local
    def submit(self):
        '''call each queued method on self.local and fill in its future'''
        for name, args, opts, resref in self.calls:
            method = getattr(self.local, name)
            resref.set(method(*args, **opts))

class remotebatch(batcher):
    '''batches the queued calls; uses as few roundtrips as possible'''
    def __init__(self, remote):
        '''remote must support _submitbatch(encbatch) and
        _submitone(op, encargs)'''
        batcher.__init__(self)
        self.remote = remote
    def submit(self):
        '''perform the queued calls in order

        Consecutive batchable calls are collected and sent together; a call
        to a method without a "batchable" attribute flushes what has been
        collected so far and is then performed on its own.
        '''
        pendingreq, pendingrsp = [], []
        for name, args, opts, resref in self.calls:
            method = getattr(self.remote, name)
            coroutinefn = getattr(method, 'batchable', None)
            if coroutinefn is None:
                # plain method: flush first so the call order is kept
                if pendingreq:
                    self._submitreq(pendingreq, pendingrsp)
                    pendingreq, pendingrsp = [], []
                resref.set(method(*args, **opts))
                continue
            coroutine = coroutinefn(method.im_self, *args, **opts)
            encargsorres, encresref = coroutine.next()
            if not encresref:
                # no future was handed out: this already is the result
                resref.set(encargsorres)
                continue
            pendingreq.append((name, encargsorres))
            pendingrsp.append((coroutine, encresref, resref))
        if pendingreq:
            self._submitreq(pendingreq, pendingrsp)
    def _submitreq(self, req, rsp):
        '''send req in one roundtrip and resume the coroutines in rsp'''
        encresults = self.remote._submitbatch(req)
        for encres, (coroutine, encresref, resref) in zip(encresults, rsp):
            # inject the encoded result, then let the coroutine decode it
            encresref.set(encres)
            resref.set(coroutine.next())

def batchable(f):
    '''annotation for batchable methods

    Such methods must implement a coroutine as follows:

    @batchable
    def sample(self, one, two=None):
        # Handle locally computable results first:
        if not one:
            yield "a local result", None
        # Build list of encoded arguments suitable for your wire protocol:
        encargs = [('one', encode(one),), ('two', encode(two),)]
        # Create future for injection of encoded result:
        encresref = future()
        # Return encoded arguments and future:
        yield encargs, encresref
        # Assuming the future to be filled with the result from the batched
        # request now. Decode it:
        yield decode(encresref.value)

    The decorator returns a plain method that drives this coroutine through
    a single _submitone() roundtrip. The undecorated coroutine function is
    kept on the returned function as the attribute "batchable", which
    remotebatch uses to split the call into separate encoding and decoding
    phases.
    '''
    def plain(*args, **opts):
        coroutine = f(*args, **opts)
        firstvalue, encresref = coroutine.next()
        if encresref:
            # a roundtrip is needed: args[0] is the object the method was
            # called on
            target = args[0]
            encresref.set(target._submitone(f.func_name, firstvalue))
            return coroutine.next()
        # no future was handed out, so this is a local result
        return firstvalue
    plain.batchable = f
    return plain

# list of nodes encoding / decoding

def decodelist(l, sep=' '):
    '''split l on sep and return the list of bin()-decoded items

    An empty (or otherwise false) l gives an empty list.
    '''
    if not l:
        return []
    return [bin(item) for item in l.split(sep)]

def encodelist(l, sep=' '):
    '''hex()-encode every item of l and join the results with sep'''
    return sep.join([hex(n) for n in l])

# batched call argument encoding

def escapearg(plain):
    '''prefix every ":", ",", ";" and "=" in plain with ":"

    The colon is handled first so that the prefixes added for the other
    characters are not escaped a second time.
    '''
    escaped = plain
    for char in (':', ',', ';', '='):
        escaped = escaped.replace(char, ':' + char)
    return escaped

def unescapearg(escaped):
    '''undo escapearg(): turn ":=", ":;", ":," and "::" back into the
    second character

    The replacements are applied one after the other in that order, with
    "::" last.
    '''
    plain = escaped
    for char in ('=', ';', ',', ':'):
        plain = plain.replace(':' + char, char)
    return plain

# client side

class wirepeer(peer.peerrepository):
    '''client side of the wire protocol commands

    The methods here only encode requests and decode responses. They rely
    on _call, _callstream, _callpush, _decompress and _abort, none of which
    is defined in this class.
    '''

    def batch(self):
        '''return a remotebatch that queues commands for this peer'''
        return remotebatch(self)
    def _submitbatch(self, req):
        '''send the (op, argsdict) pairs of req as a single "batch" command

        Returns the list of results in request order.
        '''
        cmds = []
        for op, argsdict in req:
            # The server passes every argument value through unescapearg(),
            # so values must be escaped here; otherwise a value containing
            # e.g. "::" would reach the command as ":".
            args = ','.join(['%s=%s' % (k, escapearg(v))
                             for k, v in argsdict.iteritems()])
            cmds.append('%s %s' % (op, args))
        rsp = self._call("batch", cmds=';'.join(cmds))
        # The server passes every result through escapearg() before joining
        # them with ";", so undo that for each one.
        # NOTE: an escaped ";" (":;") still contains the separator, so a
        # result containing ";" cannot be recovered by this split.
        return [unescapearg(r) for r in rsp.split(';')]
    def _submitone(self, op, args):
        '''send the single command op with the encoded keyword args'''
        return self._call(op, **args)

    @batchable
    def lookup(self, key):
        '''return the binary node the remote resolves key to

        The response is "<success> <data>" followed by one trailing
        character that is dropped. If success is 0, data is handed to
        self._abort wrapped in a RepoError.
        '''
        self.requirecap('lookup', _('look up remote revision'))
        f = future()
        yield {'key': encoding.fromlocal(key)}, f
        d = f.value
        success, data = d[:-1].split(" ", 1)
        if int(success):
            yield bin(data)
        self._abort(error.RepoError(data))

    @batchable
    def heads(self):
        '''return the list of binary head nodes reported by the remote'''
        f = future()
        yield {}, f
        d = f.value
        try:
            # the last character of the response is dropped before decoding
            yield decodelist(d[:-1])
        except ValueError:
            self._abort(error.ResponseError(_("unexpected response:"), d))

    @batchable
    def known(self, nodes):
        '''return one boolean per node, decoded from a string of digits'''
        f = future()
        yield {'nodes': encodelist(nodes)}, f
        d = f.value
        try:
            yield [bool(int(flag)) for flag in d]
        except ValueError:
            self._abort(error.ResponseError(_("unexpected response:"), d))

    @batchable
    def branchmap(self):
        '''return a dict mapping local branch names to lists of nodes

        Each response line is "<quoted branch name> <nodes>".
        '''
        f = future()
        yield {}, f
        d = f.value
        try:
            branchmap = {}
            for branchpart in d.splitlines():
                branchname, branchheads = branchpart.split(' ', 1)
                branchname = encoding.tolocal(urllib.unquote(branchname))
                branchheads = decodelist(branchheads)
                branchmap[branchname] = branchheads
            yield branchmap
        except TypeError:
            self._abort(error.ResponseError(_("unexpected response:"), d))

    def branches(self, nodes):
        '''return one tuple of decoded nodes per line of the response'''
        n = encodelist(nodes)
        d = self._call("branches", nodes=n)
        try:
            br = [tuple(decodelist(b)) for b in d.splitlines()]
            return br
        except ValueError:
            self._abort(error.ResponseError(_("unexpected response:"), d))

    def between(self, pairs):
        '''return one list of nodes per response line, over all of pairs

        The pairs are sent in groups of at most 8 per request.
        '''
        batch = 8 # avoid giant requests
        r = []
        for i in xrange(0, len(pairs), batch):
            n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
            d = self._call("between", pairs=n)
            try:
                r.extend(l and decodelist(l) or [] for l in d.splitlines())
            except ValueError:
                self._abort(error.ResponseError(_("unexpected response:"), d))
        return r

    @batchable
    def pushkey(self, namespace, key, old, new):
        '''ask the remote to update key in namespace from old to new

        Returns False without contacting the remote if it lacks the
        "pushkey" capability. Otherwise the first line of the response is
        the result as an integer, returned as a boolean; the remaining
        lines are shown to the user prefixed with "remote: ".
        '''
        if not self.capable('pushkey'):
            yield False, None
        f = future()
        self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
        yield {'namespace': encoding.fromlocal(namespace),
               'key': encoding.fromlocal(key),
               'old': encoding.fromlocal(old),
               'new': encoding.fromlocal(new)}, f
        d = f.value
        d, output = d.split('\n', 1)
        try:
            d = bool(int(d))
        except ValueError:
            raise error.ResponseError(
                _('push failed (unexpected response):'), d)
        for l in output.splitlines(True):
            self.ui.status(_('remote: '), l)
        yield d

    @batchable
    def listkeys(self, namespace):
        '''return the keys of namespace as a dict in local encoding

        Returns an empty dict without contacting the remote if it lacks the
        "pushkey" capability. Each response line is "<key>\\t<value>".
        '''
        if not self.capable('pushkey'):
            yield {}, None
        f = future()
        self.ui.debug('preparing listkeys for "%s"\n' % namespace)
        yield {'namespace': encoding.fromlocal(namespace)}, f
        d = f.value
        r = {}
        for l in d.splitlines():
            k, v = l.split('\t')
            r[encoding.tolocal(k)] = encoding.tolocal(v)
        yield r

    def stream_out(self):
        '''return the stream produced by the "stream_out" command'''
        return self._callstream('stream_out')

    def changegroup(self, nodes, kind):
        '''return an unbundle10 object for the changegroup rooted at nodes

        kind is not sent to the remote.
        '''
        n = encodelist(nodes)
        f = self._callstream("changegroup", roots=n)
        return changegroupmod.unbundle10(self._decompress(f), 'UN')

    def changegroupsubset(self, bases, heads, kind):
        '''return an unbundle10 object for the changes between bases and
        heads

        kind is not sent to the remote.
        '''
        self.requirecap('changegroupsubset', _('look up remote changes'))
        bases = encodelist(bases)
        heads = encodelist(heads)
        f = self._callstream("changegroupsubset",
                             bases=bases, heads=heads)
        return changegroupmod.unbundle10(self._decompress(f), 'UN')

    def getbundle(self, source, heads=None, common=None, bundlecaps=None):
        '''return an unbundle10 object for the requested bundle

        Only the arguments that are not None are sent; source is not sent
        to the remote.
        '''
        self.requirecap('getbundle', _('look up remote changes'))
        opts = {}
        if heads is not None:
            opts['heads'] = encodelist(heads)
        if common is not None:
            opts['common'] = encodelist(common)
        if bundlecaps is not None:
            opts['bundlecaps'] = ','.join(bundlecaps)
        f = self._callstream("getbundle", **opts)
        return changegroupmod.unbundle10(self._decompress(f), 'UN')

    def unbundle(self, cg, heads, source):
        '''Send cg (a readable file-like object representing the
        changegroup to push, typically a chunkbuffer object) to the
        remote server as a bundle. Return an integer indicating the
        result of the push (see localrepository.addchangegroup()).'''

        # if the remote supports it, send a hash of the heads instead of
        # the full list
        if heads != ['force'] and self.capable('unbundlehash'):
            heads = encodelist(['hashed',
                                util.sha1(''.join(sorted(heads))).digest()])
        else:
            heads = encodelist(heads)

        ret, output = self._callpush("unbundle", cg, heads=heads)
        if ret == "":
            raise error.ResponseError(
                _('push failed:'), output)
        try:
            ret = int(ret)
        except ValueError:
            raise error.ResponseError(
                _('push failed (unexpected response):'), ret)

        for l in output.splitlines(True):
            self.ui.status(_('remote: '), l)
        return ret

    def debugwireargs(self, one, two, three=None, four=None, five=None):
        '''call the "debugwireargs" command; five is accepted but never
        sent'''
        # don't pass optional arguments left at their default value
        opts = {}
        if three is not None:
            opts['three'] = three
        if four is not None:
            opts['four'] = four
        return self._call('debugwireargs', one=one, two=two, **opts)

# server side

class streamres(object):
    '''result type for commands that answer with a stream

    ``gen`` is whatever the command passed in; the commands in this module
    pass a generator or the result of proto.groupchunks().
    '''
    def __init__(self, gen):
        self.gen = gen

class pushres(object):
    '''result type for a push that was processed

    unbundle() stores the return value of repo.addchangegroup() in ``res``
    (0 if that call aborted).
    '''
    def __init__(self, res):
        self.res = res

class pusherr(object):
    '''result type for a push that was refused

    unbundle() stores an error message string in ``res``.
    '''
    def __init__(self, res):
        self.res = res

class ooberror(object):
    '''error result carrying a ``message``

    batch() stops and returns such an object unchanged as soon as one of
    the batched commands produces it.
    '''
    def __init__(self, message):
        self.message = message

def dispatch(repo, proto, command):
    '''run the wire command named command and return its result

    The command is looked up in the commands table, runs against the
    "served" view of repo and receives the arguments proto.getargs()
    returns for the registered argument spec.
    '''
    served = repo.filtered("served")
    handler, argspec = commands[command]
    return handler(served, proto, *proto.getargs(argspec))

def options(cmd, keys, others):
    '''move the entries named in keys out of others into a new dict

    Entries are removed from others as they are taken. If anything is left
    in others afterwards, a message naming cmd and the leftover keys is
    written to stderr; the collected options are returned either way.
    '''
    opts = {}
    for name in keys:
        if name not in others:
            continue
        opts[name] = others.pop(name)
    if others:
        leftover = ",".join(others)
        sys.stderr.write("abort: %s got unexpected arguments %s\n"
                         % (cmd, leftover))
    return opts

def batch(repo, proto, cmds, others):
    '''run several commands sent in a single request

    cmds is a ";"-separated list of entries of the form
    "<op> <name>=<value>,<name>=<value>,...". Each value is passed through
    unescapearg(). The commands are looked up in the commands table and run
    in order against the "served" view of repo; others is not used.

    Returns the results, each passed through escapearg(), joined with ";".
    If a command returns an ooberror, that object is returned immediately
    instead.
    '''
    repo = repo.filtered("served")
    res = []
    for pair in cmds.split(';'):
        op, args = pair.split(' ', 1)
        vals = {}
        for a in args.split(','):
            if a:
                # Split on the first "=" only: a value may itself contain
                # "=" (an escaped one, ":=", still does), and splitting on
                # every "=" made the unpacking fail with ValueError.
                n, v = a.split('=', 1)
                vals[n] = unescapearg(v)
        func, spec = commands[op]
        if spec:
            keys = spec.split()
            data = {}
            for k in keys:
                if k == '*':
                    # "*" collects every argument the spec does not name
                    star = {}
                    for key in vals.keys():
                        if key not in keys:
                            star[key] = vals[key]
                    data['*'] = star
                else:
                    data[k] = vals[k]
            result = func(repo, proto, *[data[k] for k in keys])
        else:
            result = func(repo, proto)
        if isinstance(result, ooberror):
            return result
        res.append(escapearg(result))
    return ';'.join(res)

def between(repo, proto, pairs):
    '''answer "between": one line of encoded nodes per requested pair

    pairs is a space separated list of "-"-joined node pairs.
    '''
    decoded = [decodelist(p, '-') for p in pairs.split(" ")]
    lines = [encodelist(b) + "\n" for b in repo.between(decoded)]
    return "".join(lines)

def branchmap(repo, proto):
    '''answer "branchmap": one "<quoted branch name> <nodes>" line per
    branch, joined with newlines'''
    lines = []
    for branch, nodes in repo.branchmap().iteritems():
        quoted = urllib.quote(encoding.fromlocal(branch))
        lines.append('%s %s' % (quoted, encodelist(nodes)))
    return '\n'.join(lines)

def branches(repo, proto, nodes):
    '''answer "branches": one line of encoded nodes per entry returned by
    repo.branches() for the decoded nodes'''
    decoded = decodelist(nodes)
    lines = [encodelist(b) + "\n" for b in repo.branches(decoded)]
    return "".join(lines)


# capabilities every server advertises; _capabilities() starts from a copy
# of this list and appends the ones that depend on the repository
wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
                 'known', 'getbundle', 'unbundlehash', 'batch']

def _capabilities(repo, proto):
    """return the list of capabilities advertised for repo

    This function exists so that extensions can easily wrap the
    computation of capabilities:

    - it returns a list, which is easy to alter
    - a change made here shows up in both the `capabilities` and the
      `hello` command without any other action
    """
    # work on a copy so the global list is never modified
    caps = list(wireprotocaps)
    if _allowstream(repo.ui):
        if repo.ui.configbool('server', 'preferuncompressed', False):
            caps.append('stream-preferred')
        requiredformats = repo.requirements & repo.supportedformats
        if requiredformats - set(['revlogv1']):
            # more than plain revlogv1: spell out our local revlog format
            caps.append('streamreqs=%s' % ','.join(requiredformats))
        else:
            # our local revlogs are just revlogv1
            caps.append('stream')
    caps.extend(['unbundle=%s' % ','.join(changegroupmod.bundlepriority),
                 'httpheader=1024'])
    return caps

# If you are writing an extension and consider wrapping this function, wrap
# `_capabilities` instead.
def capabilities(repo, proto):
    '''answer "capabilities": the capability list as one space separated
    string'''
    caps = _capabilities(repo, proto)
    return ' '.join(caps)

def changegroup(repo, proto, roots):
    '''answer "changegroup": stream the changegroup for the decoded roots'''
    cg = repo.changegroup(decodelist(roots), 'serve')
    chunks = proto.groupchunks(cg)
    return streamres(chunks)

def changegroupsubset(repo, proto, bases, heads):
    '''answer "changegroupsubset": stream the changegroup for the decoded
    bases and heads'''
    basenodes = decodelist(bases)
    headnodes = decodelist(heads)
    cg = repo.changegroupsubset(basenodes, headnodes, 'serve')
    chunks = proto.groupchunks(cg)
    return streamres(chunks)

def debugwireargs(repo, proto, one, two, others):
    '''answer "debugwireargs" by forwarding to repo.debugwireargs()'''
    # "three" and "four" are the only optional arguments accepted
    accepted = ['three', 'four']
    opts = options('debugwireargs', accepted, others)
    return repo.debugwireargs(one, two, **opts)

def getbundle(repo, proto, others):
    '''answer "getbundle": stream the bundle selected by the optional
    "heads", "common" and "bundlecaps" arguments'''
    opts = options('getbundle', ['heads', 'common', 'bundlecaps'], others)
    # node lists arrive encoded
    for name in ('heads', 'common'):
        if name in opts:
            opts[name] = decodelist(opts[name])
    # bundlecaps arrive as a comma separated string
    if 'bundlecaps' in opts:
        opts['bundlecaps'] = set(opts['bundlecaps'].split(','))
    cg = repo.getbundle('serve', **opts)
    return streamres(proto.groupchunks(cg))

def heads(repo, proto):
    '''answer "heads": the encoded heads of repo followed by a newline'''
    return "%s\n" % encodelist(repo.heads())

def hello(repo, proto):
    '''the hello command returns a set of lines describing various
    interesting things about the server, in an RFC822-like format.
    Currently the only one defined is "capabilities", which
    consists of a line in the form:

    capabilities: space separated list of tokens
    '''
    caps = capabilities(repo, proto)
    return "capabilities: %s\n" % caps

def listkeys(repo, proto, namespace):
    '''answer "listkeys": one "<key>\\t<value>" line per key of namespace,
    joined with newlines'''
    keys = repo.listkeys(encoding.tolocal(namespace))
    lines = []
    for k, v in keys.items():
        lines.append('%s\t%s' % (encoding.fromlocal(k),
                                 encoding.fromlocal(v)))
    return '\n'.join(lines)

def lookup(repo, proto, key):
    try:
        k = encoding.tolocal(key)
        c = repo[k]
        r = c.hex()
        success = 1
    except Exception, inst:
        r = str(inst)
        success = 0
    return "%s %s\n" % (success, r)

def known(repo, proto, nodes, others):
    '''answer "known": one "1" or "0" character per entry returned by
    repo.known() for the decoded nodes; others is not used'''
    flags = []
    for present in repo.known(decodelist(nodes)):
        if present:
            flags.append("1")
        else:
            flags.append("0")
    return ''.join(flags)

def pushkey(repo, proto, namespace, key, old, new):
    '''answer "pushkey": the integer result followed by a newline

    If proto has a "restore" attribute, the output captured between
    proto.redirect() and proto.restore() is appended to the answer, and a
    util.Abort raised by repo.pushkey() is reported as result 0.
    '''
    # compatibility with pre-1.8 clients which were accidentally
    # sending raw binary nodes rather than utf-8-encoded hex
    if len(new) == 20 and new.encode('string-escape') != new:
        # 20 bytes with unprintable characters: possibly a binary node
        try:
            new.decode('utf-8')
            # decodes cleanly as UTF-8 after all, so treat it as text
            new = encoding.tolocal(new)
        except UnicodeDecodeError:
            # really binary, leave unmodified
            pass
    else:
        # normal path
        new = encoding.tolocal(new)

    if not util.safehasattr(proto, 'restore'):
        r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
                         encoding.tolocal(old), new)
        return '%s\n' % int(r)

    proto.redirect()

    try:
        r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
                         encoding.tolocal(old), new) or False
    except util.Abort:
        r = False

    output = proto.restore()

    return '%s\n%s' % (int(r), output)

def _allowstream(ui):
    return ui.configbool('server', 'uncompressed', True, untrusted=True)

def _walkstreamfiles(repo):
    # this is it's own function so extensions can override it
    return repo.store.walk()

def stream(repo, proto):
    '''If the server supports streaming clone, it advertises the "stream"
    capability with a value representing the version and flags of the repo
    it is serving. Client checks to see if it understands the format.

    The format is simple: the server writes out a line with the amount
    of files, then the total amount of bytes to be transferred (separated
    by a space). Then, for each file, the server first writes the filename
    and filesize (separated by the null character), then the file contents.

    All of this is preceded by a status line: "0" on success, "1" if
    streaming is disabled by configuration, "2" if the repository could
    not be locked. In the two failure cases the status line is returned as
    a plain string; on success a streamres wrapping the generator is
    returned.
    '''

    if not _allowstream(repo.ui):
        return '1\n'

    entries = []
    total_bytes = 0
    try:
        # get consistent snapshot of repo, lock during scan
        lock = repo.lock()
        try:
            repo.ui.debug('scanning\n')
            for name, ename, size in _walkstreamfiles(repo):
                # entries whose size is 0 (or otherwise false) are skipped
                if size:
                    entries.append((name, size))
                    total_bytes += size
        finally:
            lock.release()
    except error.LockError:
        return '2\n' # error: 2

    # The lock is only held for the scan above; the generator below reads
    # the file contents later, when it is iterated.
    def streamer(repo, entries, total):
        '''stream out all metadata files in repository.'''
        yield '0\n' # success
        # note: the total is taken from the enclosing total_bytes, not from
        # the "total" argument (the caller passes the same value)
        repo.ui.debug('%d files, %d bytes to transfer\n' %
                      (len(entries), total_bytes))
        yield '%d %d\n' % (len(entries), total_bytes)

        sopener = repo.sopener
        # switch off sopener.mustaudit while streaming; the previous value
        # is put back below, on success as well as on error
        oldaudit = sopener.mustaudit
        debugflag = repo.ui.debugflag
        sopener.mustaudit = False

        try:
            for name, size in entries:
                if debugflag:
                    repo.ui.debug('sending %s (%d bytes)\n' % (name, size))
                # partially encode name over the wire for backwards compat
                yield '%s\0%d\n' % (store.encodedir(name), size)
                if size <= 65536:
                    # small file: read the recorded number of bytes at once
                    fp = sopener(name)
                    try:
                        data = fp.read(size)
                    finally:
                        fp.close()
                    yield data
                else:
                    # large file: send it in chunks, at most size bytes
                    for chunk in util.filechunkiter(sopener(name), limit=size):
                        yield chunk
        # replace with "finally:" when support for python 2.4 has been dropped
        except Exception:
            sopener.mustaudit = oldaudit
            raise
        sopener.mustaudit = oldaudit

    return streamres(streamer(repo, entries, total_bytes))

def unbundle(repo, proto, heads):
    '''answer "unbundle": add the bundle sent by the client to repo

    heads is the encoded list the client believes to be the repository
    heads: the heads themselves, ['force'], or ['hashed', <sha1 digest of
    the sorted heads>].

    Returns a pusherr if the repository heads do not match (checked before
    the upload and again under the lock), otherwise a pushres holding the
    return value of repo.addchangegroup(), or 0 if that raised util.Abort.
    '''
    their_heads = decodelist(heads)

    def check_heads():
        # true if the client forces the push or its view of the heads, sent
        # either as a list or as a hash, matches the current heads
        heads = repo.heads()
        heads_hash = util.sha1(''.join(sorted(heads))).digest()
        return (their_heads == ['force'] or their_heads == heads or
                their_heads == ['hashed', heads_hash])

    proto.redirect()

    # fail early if possible
    if not check_heads():
        return pusherr('repository changed while preparing changes - '
                       'please try again')

    # write bundle data to temporary file because it can be big
    fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
    fp = os.fdopen(fd, 'wb+')
    r = 0
    try:
        # the upload is received before the lock is taken
        proto.getfile(fp)
        lock = repo.lock()
        try:
            if not check_heads():
                # someone else committed/pushed/unbundled while we
                # were transferring data
                return pusherr('repository changed while uploading changes - '
                               'please try again')

            # push can proceed
            fp.seek(0)
            gen = changegroupmod.readbundle(fp, None)

            try:
                r = repo.addchangegroup(gen, 'serve', proto._client())
            except util.Abort, inst:
                # report the abort on stderr; r keeps its initial value 0
                sys.stderr.write("abort: %s\n" % inst)
        finally:
            lock.release()
        return pushres(r)

    finally:
        # the temporary file is removed on every exit path
        fp.close()
        os.unlink(tempname)

# table of wire commands: name -> (function, argument spec)
#
# The spec is a space separated list of argument names; dispatch() hands it
# to proto.getargs() and batch() uses it to pick the arguments to pass. In
# batch(), a "*" entry receives a dict of all arguments not named in the
# spec.
commands = {
    'batch': (batch, 'cmds *'),
    'between': (between, 'pairs'),
    'branchmap': (branchmap, ''),
    'branches': (branches, 'nodes'),
    'capabilities': (capabilities, ''),
    'changegroup': (changegroup, 'roots'),
    'changegroupsubset': (changegroupsubset, 'bases heads'),
    'debugwireargs': (debugwireargs, 'one two *'),
    'getbundle': (getbundle, '*'),
    'heads': (heads, ''),
    'hello': (hello, ''),
    'known': (known, 'nodes *'),
    'listkeys': (listkeys, 'namespace'),
    'lookup': (lookup, 'key'),
    'pushkey': (pushkey, 'namespace key old new'),
    'stream_out': (stream, ''),
    'unbundle': (unbundle, 'heads'),
}