view hgext/remotefilelog/fileserverclient.py @ 42095:864f9f63d3ed

remotefilelog: correctly reject wdir filenodes This fixes `hg grep -r 'wdir()'` when remotefilelog is enabled and the working directory contains uncommitted modifications. Differential Revision: https://phab.mercurial-scm.org/D6217
author Augie Fackler <augie@google.com>
date Mon, 08 Apr 2019 10:52:04 -0400
parents 0129bf02fa04
children 65f3a77223e0
line wrap: on
line source

# fileserverclient.py - client for communicating with the cache process
#
# Copyright 2013 Facebook, Inc.
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

from __future__ import absolute_import

import hashlib
import io
import os
import threading
import time
import zlib

from mercurial.i18n import _
from mercurial.node import bin, hex, nullid
from mercurial import (
    error,
    node,
    pycompat,
    revlog,
    sshpeer,
    util,
    wireprotov1peer,
)
from mercurial.utils import procutil

from . import (
    constants,
    contentstore,
    metadatastore,
)

# Alias of the ssh v1 peer class; fileserverclient.request() uses it to
# verify that a peer advertising the legacy getfiles capability is an ssh peer.
_sshv1peer = sshpeer.sshv1peer

# Statistics for debugging
# Accumulated wall-clock seconds spent in fileserverclient.request() calls
# issued by prefetch().
fetchcost = 0
# Number of prefetch() calls that had at least one revision to download.
fetches = 0
# Number of file revisions requested across all of those prefetch() calls.
fetched = 0
# Number of requested revisions the cache process did not have.
fetchmisses = 0

# Consulted by fileserverclient._lfsprefetch(), which is a no-op while this
# is None.
# NOTE(review): presumably assigned by other code when the lfs extension is
# available - confirm where.
_lfsmod = None

def getcachekey(reponame, file, id):
    """Return the shared-cache key for revision ``id`` of path ``file``.

    The key is ``reponame/<xx>/<rest>/id`` where ``xx`` is the first two
    characters of the hex SHA-1 of the path and ``rest`` is the remainder
    of that hex digest.
    """
    digest = hashlib.sha1(file).digest()
    hexdigest = node.hex(digest)
    prefix, remainder = hexdigest[:2], hexdigest[2:]
    return os.path.join(reponame, prefix, remainder, id)

def getlocalkey(file, id):
    """Return the key ``<hex sha1 of path>/id`` for revision ``id`` of
    path ``file``.
    """
    digest = hashlib.sha1(file).digest()
    return os.path.join(node.hex(digest), id)

def peersetup(ui, peer):
    """Swap ``peer``'s class for a subclass that knows the remotefilelog
    wire commands.

    The subclass adds the batchable ``x_rfl_getfile`` and
    ``x_rfl_getflogheads`` commands and rewrites the ``bundlecaps`` option of
    ``getbundle`` requests. ``ui`` is accepted but not used here.
    """

    class remotefilepeer(peer.__class__):
        @wireprotov1peer.batchable
        def x_rfl_getfile(self, file, node):
            # Batchable protocol: first yield the arguments plus a future,
            # then (once the future is filled in) yield the decoded result.
            supported = self.capable('x_rfl_getfile')
            if not supported:
                raise error.Abort(
                    'configured remotefile server does not support getfile')
            reply = wireprotov1peer.future()
            yield {'file': file, 'node': node}, reply
            # The reply is "<code>\0<data>"; a non-zero code is an error and
            # the data then carries the message.
            status, payload = reply.value.split('\0', 1)
            if int(status):
                raise error.LookupError(file, node, payload)
            yield payload

        @wireprotov1peer.batchable
        def x_rfl_getflogheads(self, path):
            supported = self.capable('x_rfl_getflogheads')
            if not supported:
                raise error.Abort('configured remotefile server does not '
                                  'support getflogheads')
            reply = wireprotov1peer.future()
            yield {'path': path}, reply
            # An empty reply means "no heads"; otherwise one head per line.
            if reply.value:
                heads = reply.value.split('\n')
            else:
                heads = []
            yield heads

        def _updatecallstreamopts(self, command, opts):
            # Only getbundle requests sent from a shallow local repo to a
            # server with the legacy ssh getfiles capability are rewritten.
            applicable = (
                command == 'getbundle'
                and (constants.NETWORK_CAP_LEGACY_SSH_GETFILES
                     in self.capabilities())
                and util.safehasattr(self, '_localrepo')
                and (constants.SHALLOWREPO_REQUIREMENT
                     in self._localrepo.requirements))
            if not applicable:
                return

            existing = opts.get('bundlecaps')
            caps = [existing] if existing else []

            # shallow, includepattern, and excludepattern are a hacky way of
            # carrying over data from the local repo to this getbundle
            # command. We need to do it this way because bundle1 getbundle
            # doesn't provide any other place we can hook in to manipulate
            # getbundle args before it goes across the wire. Once we get rid
            # of bundle1, we can use bundle2's _pullbundle2extraprepare to
            # do this more cleanly.
            caps.append(constants.BUNDLE2_CAPABLITY)
            localrepo = self._localrepo
            include = localrepo.includepattern
            if include:
                caps.append("includepattern=" + '\0'.join(include))
            exclude = localrepo.excludepattern
            if exclude:
                caps.append("excludepattern=" + '\0'.join(exclude))
            opts['bundlecaps'] = ','.join(caps)

        def _sendrequest(self, command, args, **opts):
            self._updatecallstreamopts(command, args)
            parent = super(remotefilepeer, self)
            return parent._sendrequest(command, args, **opts)

        def _callstream(self, command, **opts):
            parent = super(remotefilepeer, self)
            # When the parent class has _sendrequest, the override above
            # already rewrites the options, so only do it here otherwise.
            if not util.safehasattr(parent, '_sendrequest'):
                self._updatecallstreamopts(command, pycompat.byteskwargs(opts))
            return parent._callstream(command, **opts)

    peer.__class__ = remotefilepeer

class cacheconnection(object):
    """The connection for communicating with the remote cache. Performs
    gets and sets by communicating with an external process that has the
    cache-specific implementation.

    Requests are written to ``pipei`` and replies are read line by line
    from ``pipeo``. ``connected`` is True between a successful connect()
    and the next close().
    """
    def __init__(self):
        self.pipeo = self.pipei = self.pipee = None
        self.subprocess = None
        self.connected = False

    def connect(self, cachecommand):
        """Spawn ``cachecommand`` and attach to its pipes.

        Raises error.Abort if a connection is already open.
        """
        if self.pipeo:
            raise error.Abort(_("cache connection already open"))
        self.pipei, self.pipeo, self.pipee, self.subprocess = (
            procutil.popen4(cachecommand))
        self.connected = True

    def close(self):
        """Ask the cache process to exit and release all pipes.

        Every step is best-effort: failures while shutting down are
        deliberately ignored so close() itself never raises.
        """
        def tryclose(pipe):
            try:
                pipe.close()
            except Exception:
                pass
        if self.connected:
            try:
                self.pipei.write("exit\n")
            except Exception:
                pass
            tryclose(self.pipei)
            self.pipei = None
            tryclose(self.pipeo)
            self.pipeo = None
            tryclose(self.pipee)
            self.pipee = None
            try:
                # Wait for process to terminate, making sure to avoid deadlock.
                # See https://docs.python.org/2/library/subprocess.html for
                # warnings about wait() and deadlocking.
                self.subprocess.communicate()
            except Exception:
                pass
            self.subprocess = None
        self.connected = False

    def request(self, request, flush=True):
        """Write ``request`` to the cache process, flushing unless
        ``flush`` is false.

        Does nothing when not connected. An IOError while writing closes
        the connection instead of propagating.
        """
        if self.connected:
            try:
                self.pipei.write(request)
                if flush:
                    self.pipei.flush()
            except IOError:
                self.close()

    def receiveline(self):
        """Read one reply line, without its trailing character.

        Returns None when not connected or when reading fails with an
        IOError. An empty line (including end of file) closes the
        connection and is returned as the empty string, so callers can treat
        any falsy result as "connection gone".
        """
        if not self.connected:
            return None
        try:
            result = self.pipeo.readline()[:-1]
        except IOError:
            # Previously this path fell through to ``return result`` with
            # ``result`` unbound, raising UnboundLocalError.
            self.close()
            return None
        if not result:
            self.close()
        return result

def _getfilesbatch(
        remote, receivemissing, progresstick, missed, idmap, batchsize):
    """Fetch every key in ``missed`` with one ``x_rfl_getfile`` call each.

    ``idmap`` maps a cache key to its file path; the last 40 characters of
    a key are the node. Each result is handed to ``receivemissing`` as a
    "<length>\\n<data>" stream, followed by one ``progresstick()`` call.
    ``batchsize`` is accepted but not used by this implementation.
    """
    # Over http(s), iterbatch is a streamy method and we can start
    # looking at results early. This means we send one (potentially
    # large) request, but then we show nice progress as we process
    # file results, rather than showing chunks of $batchsize in
    # progress.
    #
    # Over ssh, iterbatch isn't streamy because batch() wasn't
    # explicitly designed as a streaming method. In the future we
    # should probably introduce a streambatch() method upstream and
    # use that for this.
    with remote.commandexecutor() as executor:
        # Queue all requests first, then consume the results in order.
        pending = [
            executor.callcommand('x_rfl_getfile', {
                'file': idmap[key],
                'node': key[-40:]
            })
            for key in missed
        ]

        for index, key in enumerate(missed):
            payload = pending[index].result()
            pending[index] = None  # release memory
            path = idmap[key]
            nodeid = key[-40:]
            stream = io.BytesIO('%d\n%s' % (len(payload), payload))
            receivemissing(stream, path, nodeid)
            progresstick()

def _getfiles_optimistic(
    remote, receivemissing, progresstick, missed, idmap, step):
    """Fetch ``missed`` over the ``x_rfl_getfiles`` stream in batches.

    Up to ``step`` requests ("<node><path>\\n") are written and flushed,
    then the same number of replies is read via ``receivemissing`` before
    the next batch is sent. An empty line terminates the command.
    """
    remote._callstream("x_rfl_getfiles")
    outpipe = remote._pipeo
    inpipe = remote._pipei
    total = len(missed)
    position = 0
    while position < total:
        # issue a batch of requests
        stop = min(total, position + step)
        batch = missed[position:stop]
        position = stop
        for key in batch:
            # issue new request
            outpipe.write("%s%s\n" % (key[-40:], idmap[key]))
        outpipe.flush()

        # receive batch results
        for key in batch:
            nodeid = key[-40:]
            path = idmap[key]
            receivemissing(inpipe, path, nodeid)
            progresstick()

    # End the command
    outpipe.write('\n')
    outpipe.flush()

def _getfiles_threaded(
    remote, receivemissing, progresstick, missed, idmap, step):
    """Fetch ``missed`` over the ``x_rfl_getfiles`` stream, writing all
    requests from a background thread while replies are read here.

    ``idmap`` maps a cache key to its file path; the last 40 characters of
    a key are the node. Each request is "<node><path>\\n" and an empty line
    terminates the command. ``step`` is accepted for signature parity with
    _getfiles_optimistic and is not used.
    """
    # Use the same prefixed wire command as _getfiles_optimistic; every
    # other remotefilelog command issued from this file carries the x_rfl_
    # prefix, and this one previously used the unprefixed name "getfiles".
    remote._callstream("x_rfl_getfiles")
    pipeo = remote._pipeo
    pipei = remote._pipei

    def writer():
        for missingid in missed:
            versionid = missingid[-40:]
            file = idmap[missingid]
            sshrequest = "%s%s\n" % (versionid, file)
            pipeo.write(sshrequest)
        pipeo.flush()
    writerthread = threading.Thread(target=writer)
    # Daemon thread: do not keep the process alive if the reader bails out.
    writerthread.daemon = True
    writerthread.start()

    for missingid in missed:
        versionid = missingid[-40:]
        file = idmap[missingid]
        receivemissing(pipei, file, versionid)
        progresstick()

    # All requests must be on the wire before the terminator is written.
    writerthread.join()
    # End the command
    pipeo.write('\n')
    pipeo.flush()

class fileserverclient(object):
    """A client for requesting files from the remote file server.
    """
    def __init__(self, repo):
        """Read the remotefilelog cache settings from ``repo.ui`` and
        create a cache connection that is not yet connected.
        """
        self.repo = repo
        self.ui = repo.ui

        configured = self.ui.config("remotefilelog", "cacheprocess")
        if configured:
            configured = util.expandpath(configured)
        self.cacheprocess = configured

        # This option causes remotefilelog to pass the full file path to the
        # cacheprocess instead of a hashed key.
        self.cacheprocesspasspath = self.ui.configbool(
            "remotefilelog", "cacheprocess.includepath")

        self.debugoutput = self.ui.configbool("remotefilelog", "debug")

        self.remotecache = cacheconnection()

    def setstore(self, datastore, historystore, writedata, writehistory):
        self.datastore = datastore
        self.historystore = historystore
        self.writedata = writedata
        self.writehistory = writehistory

    def _connect(self):
        return self.repo.connectionpool.get(self.repo.fallbackpath)

    def request(self, fileids):
        """Takes a list of filename/node pairs and fetches them from the
        server. Files are stored in the local cache.

        The cache process is asked first; every key it reports as missing
        is then downloaded from the fallback server and afterwards announced
        to the cache process with a "set" request.

        NOTE(review): despite what callers such as prefetch() expect, this
        method has no return statement and therefore always returns None; it
        does not return a list of nodes the server couldn't find.
        If the connection fails, an exception is raised.
        """
        if not self.remotecache.connected:
            self.connect()
        cache = self.remotecache
        writedata = self.writedata

        repo = self.repo
        total = len(fileids)
        # Cache protocol: "get", the number of keys, then one key per line.
        request = "get\n%d\n" % total
        # Maps each cache key back to the file path it was derived from.
        idmap = {}
        reponame = repo.name
        for file, id in fileids:
            fullid = getcachekey(reponame, file, id)
            if self.cacheprocesspasspath:
                # Prefix the key with the real path, NUL-separated.
                request += file + '\0'
            request += fullid + "\n"
            idmap[fullid] = file

        cache.request(request)

        progress = self.ui.makeprogress(_('downloading'), total=total)
        progress.update(0)

        # Keys the cache process could not serve.
        missed = []
        while True:
            missingid = cache.receiveline()
            if not missingid:
                # The cache went away before sending the "0" terminator:
                # treat every key not already reported as missed.
                missedset = set(missed)
                for missingid in idmap:
                    if not missingid in missedset:
                        missed.append(missingid)
                self.ui.warn(_("warning: cache connection closed early - " +
                    "falling back to server\n"))
                break
            if missingid == "0":
                # End of the list of misses.
                break
            if missingid.startswith("_hits_"):
                # receive progress reports
                # The third "_"-separated field is the number of hits.
                parts = missingid.split("_")
                progress.increment(int(parts[2]))
                continue

            missed.append(missingid)

        global fetchmisses
        fetchmisses += len(missed)

        fromcache = total - len(missed)
        progress.update(fromcache, total=total)
        self.ui.log("remotefilelog", "remote cache hit rate is %r of %r\n",
                    fromcache, total, hit=fromcache, total=total)

        # Temporarily change the process umask to 0o002; the previous value
        # is restored in the finally block below.
        oldumask = os.umask(0o002)
        try:
            # receive cache misses from master
            if missed:
                # When verbose is true, sshpeer prints 'running ssh...'
                # to stdout, which can interfere with some command
                # outputs
                verbose = self.ui.verbose
                self.ui.verbose = False
                try:
                    with self._connect() as conn:
                        remote = conn.peer
                        if remote.capable(
                                constants.NETWORK_CAP_LEGACY_SSH_GETFILES):
                            # Legacy streaming getfiles needs direct access
                            # to the ssh pipes.
                            if not isinstance(remote, _sshv1peer):
                                raise error.Abort('remotefilelog requires ssh '
                                                  'servers')
                            step = self.ui.configint('remotefilelog',
                                                     'getfilesstep')
                            getfilestype = self.ui.config('remotefilelog',
                                                          'getfilestype')
                            if getfilestype == 'threaded':
                                _getfiles = _getfiles_threaded
                            else:
                                _getfiles = _getfiles_optimistic
                            _getfiles(remote, self.receivemissing,
                                      progress.increment, missed, idmap, step)
                        elif remote.capable("x_rfl_getfile"):
                            # One x_rfl_getfile call per missing revision.
                            if remote.capable('batch'):
                                batchdefault = 100
                            else:
                                batchdefault = 10
                            batchsize = self.ui.configint(
                                'remotefilelog', 'batchsize', batchdefault)
                            _getfilesbatch(
                                remote, self.receivemissing, progress.increment,
                                missed, idmap, batchsize)
                        else:
                            raise error.Abort("configured remotefilelog server"
                                             " does not support remotefilelog")

                    self.ui.log("remotefilefetchlog",
                                "Success\n",
                                fetched_files = progress.pos - fromcache,
                                total_to_fetch = total - fromcache)
                except Exception:
                    # Log the failure with the progress made so far, then
                    # let the exception propagate.
                    self.ui.log("remotefilefetchlog",
                                "Fail\n",
                                fetched_files = progress.pos - fromcache,
                                total_to_fetch = total - fromcache)
                    raise
                finally:
                    self.ui.verbose = verbose
                # send to memcache
                # Only the keys are sent: "set", the count, one key per line.
                request = "set\n%d\n%s\n" % (len(missed), "\n".join(missed))
                cache.request(request)

            progress.complete()

            # mark ourselves as a user of this cache
            writedata.markrepo(self.repo.path)
        finally:
            os.umask(oldumask)

    def receivemissing(self, pipe, filename, node):
        """Read one "<size>\\n<data>" reply from ``pipe`` and store it.

        The data is zlib-decompressed and written to ``self.writedata`` for
        ``filename`` at the binary form of ``node``. Raises
        error.ResponseError when the size line is empty or fewer than
        ``size`` bytes could be read.
        """
        header = pipe.readline()[:-1]
        if not header:
            raise error.ResponseError(_("error downloading file contents:"),
                                      _("connection closed early"))

        expected = int(header)
        payload = pipe.read(expected)
        received = len(payload)
        if received != expected:
            raise error.ResponseError(_("error downloading file contents:"),
                                      _("only received %s of %s bytes")
                                      % (received, expected))

        binnode = bin(node)
        text = zlib.decompress(payload)
        self.writedata.addremotefilelognode(filename, binnode, text)

    def connect(self):
        if self.cacheprocess:
            cmd = "%s %s" % (self.cacheprocess, self.writedata._path)
            self.remotecache.connect(cmd)
        else:
            # If no cache process is specified, we fake one that always
            # returns cache misses.  This enables tests to run easily
            # and may eventually allow us to be a drop in replacement
            # for the largefiles extension.
            class simplecache(object):
                def __init__(self):
                    self.missingids = []
                    self.connected = True

                def close(self):
                    pass

                def request(self, value, flush=True):
                    lines = value.split("\n")
                    if lines[0] != "get":
                        return
                    self.missingids = lines[2:-1]
                    self.missingids.append('0')

                def receiveline(self):
                    if len(self.missingids) > 0:
                        return self.missingids.pop(0)
                    return None

            self.remotecache = simplecache()

    def close(self):
        """Report the module-wide fetch statistics (if any fetch happened)
        and close the cache connection when it is connected.
        """
        if fetches:
            hitratio = float(fetched - fetchmisses) / float(fetched) * 100.0
            template = ("%d files fetched over %d fetches - " +
                        "(%d misses, %0.2f%% hit ratio) over %0.2fs\n")
            msg = template % (
                fetched, fetches, fetchmisses, hitratio, fetchcost)
            if self.debugoutput:
                self.ui.warn(msg)
            # Escape "%" because ui.log() treats the message as a format
            # string.
            escaped = msg.replace("%", "%%")
            self.ui.log("remotefilelog.prefetch", escaped,
                remotefilelogfetched=fetched,
                remotefilelogfetches=fetches,
                remotefilelogfetchmisses=fetchmisses,
                remotefilelogfetchtime=fetchcost * 1000)

        cache = self.remotecache
        if cache.connected:
            cache.close()

    def prefetch(self, fileids, force=False, fetchdata=True,
                 fetchhistory=False):
        """downloads the given file versions to the cache

        ``fileids`` is an iterable of (path, hex node) pairs. Only revisions
        reported missing by the data store (``fetchdata``) and/or the history
        store (``fetchhistory``) are requested. With ``force`` the repo's
        shared stores are consulted instead of the stores given to
        setstore().
        """
        global fetches, fetched, fetchcost

        repo = self.repo
        # hack
        # - we don't use .hgtags
        # - workingctx produces ids with length 42,
        #   which we skip since they aren't in any cache
        idstocheck = [
            (file, bin(id)) for file, id in fileids
            if file != '.hgtags' and len(id) != 42
            and repo.shallowmatch(file)]

        datastore = self.datastore
        historystore = self.historystore
        if force:
            datastore = contentstore.unioncontentstore(*repo.shareddatastores)
            historystore = metadatastore.unionmetadatastore(
                *repo.sharedhistorystores)

        wanted = set()
        if fetchdata:
            wanted.update(datastore.getmissing(idstocheck))
        if fetchhistory:
            wanted.update(historystore.getmissing(idstocheck))

        # partition missing nodes into nullid and not-nullid so we can
        # warn about this filtering potentially shadowing bugs.
        nullcount = sum(1 for unused, id in wanted if id == nullid)
        if nullcount:
            wanted = [(f, id) for f, id in wanted if id != nullid]
            repo.ui.develwarn(
                ('remotefilelog not fetching %d null revs'
                 ' - this is likely hiding bugs' % nullcount),
                config='remotefilelog-ext')

        if not wanted:
            return

        fetches += 1

        # We want to be able to detect excess individual file downloads, so
        # let's log that information for debugging.
        if 15 <= fetches < 18:
            if fetches == 15:
                fetchwarning = self.ui.config('remotefilelog',
                                              'fetchwarning')
                if fetchwarning:
                    self.ui.warn(fetchwarning + '\n')
            self.logstacktrace()

        torequest = [(file, hex(id)) for file, id in sorted(wanted)]
        fetched += len(torequest)
        start = time.time()
        unfetched = self.request(torequest)
        if unfetched:
            raise error.Abort(_("unable to download %d files") %
                              len(unfetched))
        fetchcost += time.time() - start
        self._lfsprefetch(fileids)

    def _lfsprefetch(self, fileids):
        """Download the LFS blobs referenced by ``fileids`` that are not
        yet in the local blob store.

        Does nothing when the lfs module is not set, the repo's svfs has no
        ``lfslocalblobstore``, or the lfs wrapper reports that downloading is
        not possible.
        """
        if not _lfsmod:
            return
        if not util.safehasattr(self.repo.svfs, 'lfslocalblobstore'):
            return
        if not _lfsmod.wrapper.candownload(self.repo):
            return

        store = self.repo.svfs.lfslocalblobstore
        wanted = []
        for file, id in fileids:
            fnode = bin(id)
            rlog = self.repo.file(file)
            # Only revisions flagged as externally stored carry a pointer.
            if not rlog.flags(fnode) & revlog.REVIDX_EXTSTORED:
                continue
            rawtext = rlog.revision(fnode, raw=True)
            pointer = _lfsmod.pointer.deserialize(rawtext)
            if not store.has(pointer.oid()):
                wanted.append(pointer)

        if wanted:
            self.repo.svfs.lfsremoteblobstore.readbatch(wanted, store)
            assert all(store.has(p.oid()) for p in wanted)

    def logstacktrace(self):
        import traceback
        self.ui.log('remotefilelog', 'excess remotefilelog fetching:\n%s\n',
                    ''.join(traceback.format_stack()))