Add an sshrepository class and hg serve --stdio
authorMatt Mackall <mpm@selenic.com>
Tue, 05 Jul 2005 17:55:22 -0800
changeset 624 876333a295ff
parent 623 314867960a4a
child 625 978011cf5279
Add an sshrepository class and hg serve --stdio
mercurial/commands.py
mercurial/hg.py
--- a/mercurial/commands.py	Tue Jul 05 17:50:43 2005 -0800
+++ b/mercurial/commands.py	Tue Jul 05 17:55:22 2005 -0800
@@ -9,7 +9,7 @@
 demandload(globals(), "os re sys signal")
 demandload(globals(), "fancyopts ui hg util")
 demandload(globals(), "hgweb mdiff random signal time traceback")
-demandload(globals(), "errno socket version")
+demandload(globals(), "errno socket version struct")
 
 class UnknownCommand(Exception): pass
 
@@ -823,9 +823,67 @@
 
 def serve(ui, repo, **opts):
     """export the repository via HTTP"""
+
+    if opts["stdio"]:
+        def getarg():
+            argline = sys.stdin.readline()[:-1]
+            arg, l = argline.split()
+            val = sys.stdin.read(int(l))
+            return arg, val
+        def respond(v):
+            sys.stdout.write("%d\n" % len(v))
+            sys.stdout.write(v)
+            sys.stdout.flush()
+
+        while 1:
+            cmd = sys.stdin.readline()[:-1]
+            if cmd == '':
+                return
+            if cmd == "heads":
+                h = repo.heads()
+                respond(" ".join(map(hg.hex, h)) + "\n")
+            elif cmd == "branches":
+                arg, nodes = getarg()
+                nodes = map(hg.bin, nodes.split(" "))
+                r = []
+                for b in repo.branches(nodes):
+                    r.append(" ".join(map(hg.hex, b)) + "\n")
+                respond("".join(r))
+            elif cmd == "between":
+                arg, pairs = getarg()
+                pairs = [ map(hg.bin, p.split("-")) for p in pairs.split(" ") ]
+                r = []
+                for b in repo.between(pairs):
+                    r.append(" ".join(map(hg.hex, b)) + "\n")
+                respond("".join(r))
+            elif cmd == "changegroup":
+                nodes = []
+                arg, roots = getarg()
+                nodes = map(hg.bin, roots.split(" "))
+
+                b = []
+                t = 0
+                for chunk in repo.changegroup(nodes):
+                    t += len(chunk)
+                    b.append(chunk)
+                    if t > 4096:
+                        sys.stdout.write(struct.pack(">l", t))
+                        for c in b:
+                            sys.stdout.write(c)
+                        t = 0
+                        b = []
+
+                sys.stdout.write(struct.pack(">l", t))
+                for c in b:
+                    sys.stdout.write(c)
+
+                sys.stdout.write(struct.pack(">l", -1))
+                sys.stdout.flush()
+
     def openlog(opt, default):
         if opts[opt] and opts[opt] != '-': return open(opts[opt], 'w')
         else: return default
+
     httpd = hgweb.create_server(repo.root, opts["name"], opts["templates"],
                                 opts["address"], opts["port"],
                                 openlog('accesslog', sys.stdout),
@@ -1017,6 +1075,7 @@
                        ('p', 'port', 8000, 'listen port'),
                        ('a', 'address', '', 'interface address'),
                        ('n', 'name', os.getcwd(), 'repository name'),
+                       ('', 'stdio', None, 'for remote clients'),
                        ('t', 'templates', "", 'template map')],
               "hg serve [options]"),
     "^status": (status, [], 'hg status'),
--- a/mercurial/hg.py	Tue Jul 05 17:50:43 2005 -0800
+++ b/mercurial/hg.py	Tue Jul 05 17:55:22 2005 -0800
@@ -1592,6 +1592,88 @@
             yield zd.decompress(d)
         self.ui.note("%d bytes of data transfered\n" % bytes)
 
+class sshrepository:
+    def __init__(self, ui, path):
+        self.url = path
+        self.ui = ui
+
+        m = re.match(r'ssh://(([^@]+)@)?([^:/]+)(:(\d+))?(/(.*))?', path)
+        if not m:
+            raise RepoError("couldn't parse destination %s\n" % path)
+
+        self.user = m.group(2)
+        self.host = m.group(3)
+        self.port = m.group(5)
+        self.path = m.group(7)
+
+        args = self.user and ("%s@%s" % (self.user, self.host)) or self.host
+        args = self.port and ("%s -p %s") % (args, self.port) or args
+        path = self.path or ""
+
+        cmd = "ssh %s 'hg -R %s serve --stdio'"
+        cmd = cmd % (args, path)
+
+        self.pipeo, self.pipei = os.popen2(cmd)
+
+    def __del__(self):
+        self.pipeo.close()
+        self.pipei.close()
+
+    def do_cmd(self, cmd, **args):
+        self.ui.debug("sending %s command\n" % cmd)
+        self.pipeo.write("%s\n" % cmd)
+        for k, v in args.items():
+            self.pipeo.write("%s %d\n" % (k, len(v)))
+            self.pipeo.write(v)
+        self.pipeo.flush()
+
+        return self.pipei
+
+    def call(self, cmd, **args):
+        r = self.do_cmd(cmd, **args)
+        l = int(r.readline())
+        return r.read(l)
+
+    def heads(self):
+        d = self.call("heads")
+        try:
+            return map(bin, d[:-1].split(" "))
+        except:
+            self.ui.warn("unexpected response:\n" + d[:400] + "\n...\n")
+            raise
+
+    def branches(self, nodes):
+        n = " ".join(map(hex, nodes))
+        d = self.call("branches", nodes=n)
+        try:
+            br = [ tuple(map(bin, b.split(" "))) for b in d.splitlines() ]
+            return br
+        except:
+            self.ui.warn("unexpected response:\n" + d[:400] + "\n...\n")
+            raise
+
+    def between(self, pairs):
+        n = "\n".join(["-".join(map(hex, p)) for p in pairs])
+        d = self.call("between", pairs=n)
+        try:
+            p = [ l and map(bin, l.split(" ")) or [] for l in d.splitlines() ]
+            return p
+        except:
+            self.ui.warn("unexpected response:\n" + d[:400] + "\n...\n")
+            raise
+
+    def changegroup(self, nodes):
+        n = " ".join(map(hex, nodes))
+        f = self.do_cmd("changegroup", roots=n)
+        bytes = 0
+        while 1:
+            l = struct.unpack(">l", f.read(4))[0]
+            if l == -1: break
+            d = f.read(l)
+            bytes += len(d)
+            yield d
+        self.ui.note("%d bytes of data transfered\n" % bytes)
+
 def repository(ui, path=None, create=0):
     if path:
         if path.startswith("http://"):
@@ -1600,5 +1682,7 @@
             return httprepository(ui, path.replace("hg://", "http://"))
         if path.startswith("old-http://"):
             return localrepository(ui, path.replace("old-http://", "http://"))
+        if path.startswith("ssh://"):
+            return sshrepository(ui, path)
 
     return localrepository(ui, path, create)