mercurial/wireprotoserver.py
changeset 35859 1bf5263fe5cc
parent 35858 1b76a9e0a9de
child 35860 d9e71cce3b2f
--- a/mercurial/wireprotoserver.py	Wed Jan 31 11:13:11 2018 -0800
+++ b/mercurial/wireprotoserver.py	Wed Jan 31 10:48:35 2018 -0800
@@ -8,9 +8,13 @@
 
 import cgi
 import struct
+import sys
 
+from .i18n import _
 from . import (
+    encoding,
     error,
+    hook,
     pycompat,
     util,
     wireproto,
@@ -197,3 +201,114 @@
         req.respond(HTTP_OK, HGERRTYPE, body=rsp)
         return []
     raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
+
+class sshserver(wireproto.abstractserverproto):
+    def __init__(self, ui, repo):
+        self.ui = ui
+        self.repo = repo
+        self.lock = None
+        self.fin = ui.fin
+        self.fout = ui.fout
+        self.name = 'ssh'
+
+        hook.redirect(True)
+        ui.fout = repo.ui.fout = ui.ferr
+
+        # Prevent insertion/deletion of CRs
+        util.setbinary(self.fin)
+        util.setbinary(self.fout)
+
+    def getargs(self, args):
+        data = {}
+        keys = args.split()
+        for n in xrange(len(keys)):
+            argline = self.fin.readline()[:-1]
+            arg, l = argline.split()
+            if arg not in keys:
+                raise error.Abort(_("unexpected parameter %r") % arg)
+            if arg == '*':
+                star = {}
+                for k in xrange(int(l)):
+                    argline = self.fin.readline()[:-1]
+                    arg, l = argline.split()
+                    val = self.fin.read(int(l))
+                    star[arg] = val
+                data['*'] = star
+            else:
+                val = self.fin.read(int(l))
+                data[arg] = val
+        return [data[k] for k in keys]
+
+    def getarg(self, name):
+        return self.getargs(name)[0]
+
+    def getfile(self, fpout):
+        self.sendresponse('')
+        count = int(self.fin.readline())
+        while count:
+            fpout.write(self.fin.read(count))
+            count = int(self.fin.readline())
+
+    def redirect(self):
+        pass
+
+    def sendresponse(self, v):
+        self.fout.write("%d\n" % len(v))
+        self.fout.write(v)
+        self.fout.flush()
+
+    def sendstream(self, source):
+        write = self.fout.write
+        for chunk in source.gen:
+            write(chunk)
+        self.fout.flush()
+
+    def sendpushresponse(self, rsp):
+        self.sendresponse('')
+        self.sendresponse(str(rsp.res))
+
+    def sendpusherror(self, rsp):
+        self.sendresponse(rsp.res)
+
+    def sendooberror(self, rsp):
+        self.ui.ferr.write('%s\n-\n' % rsp.message)
+        self.ui.ferr.flush()
+        self.fout.write('\n')
+        self.fout.flush()
+
+    def serve_forever(self):
+        try:
+            while self.serve_one():
+                pass
+        finally:
+            if self.lock is not None:
+                self.lock.release()
+        sys.exit(0)
+
+    handlers = {
+        str: sendresponse,
+        wireproto.streamres: sendstream,
+        wireproto.streamres_legacy: sendstream,
+        wireproto.pushres: sendpushresponse,
+        wireproto.pusherr: sendpusherror,
+        wireproto.ooberror: sendooberror,
+    }
+
+    def serve_one(self):
+        cmd = self.fin.readline()[:-1]
+        if cmd and cmd in wireproto.commands:
+            rsp = wireproto.dispatch(self.repo, self, cmd)
+            self.handlers[rsp.__class__](self, rsp)
+        elif cmd:
+            impl = getattr(self, 'do_' + cmd, None)
+            if impl:
+                r = impl()
+                if r is not None:
+                    self.sendresponse(r)
+            else:
+                self.sendresponse("")
+        return cmd != ''
+
+    def _client(self):
+        client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
+        return 'remote:ssh:' + client