protocol: wrap non-string protocol responses in classes
authorDirkjan Ochtman <dirkjan@ochtman.nl>
Tue, 20 Jul 2010 20:53:33 +0200
changeset 11625 cdeb861335d5
parent 11624 67260651d09d
child 11626 2f8adc60e013
protocol: wrap non-string protocol responses in classes
mercurial/hgweb/protocol.py
mercurial/sshserver.py
mercurial/wireproto.py
--- a/mercurial/hgweb/protocol.py	Fri Jul 16 22:20:19 2010 +0200
+++ b/mercurial/hgweb/protocol.py	Tue Jul 20 20:53:33 2010 +0200
@@ -48,13 +48,20 @@
         self.response = s
     def sendstream(self, source):
         self.req.respond(HTTP_OK, HGTYPE)
-        for chunk in source:
-            self.req.write(str(chunk))
-    def sendpushresponse(self, ret):
+        for chunk in source.gen:
+            self.req.write(chunk)
+    def sendpushresponse(self, rsp):
         val = sys.stdout.getvalue()
         sys.stdout, sys.stderr = self.oldio
         self.req.respond(HTTP_OK, HGTYPE)
-        self.response = '%d\n%s' % (ret, val)
+        self.response = '%d\n%s' % (rsp.res, val)
+
+    handlers = {
+        str: sendresponse,
+        wireproto.streamres: sendstream,
+        wireproto.pushres: sendpushresponse,
+    }
+
     def _client(self):
         return 'remote:%s:%s:%s' % (
             self.req.env.get('wsgi.url_scheme') or 'http',
@@ -66,5 +73,6 @@
 
 def call(repo, req, cmd):
     p = webproto(req)
-    wireproto.dispatch(repo, p, cmd)
-    yield p.response
+    rsp = wireproto.dispatch(repo, p, cmd)
+    webproto.handlers[rsp.__class__](p, rsp)
+    return [p.response]
--- a/mercurial/sshserver.py	Fri Jul 16 22:20:19 2010 +0200
+++ b/mercurial/sshserver.py	Tue Jul 20 20:53:33 2010 +0200
@@ -72,13 +72,13 @@
         self.fout.flush()
 
     def sendstream(self, source):
-        for chunk in source:
+        for chunk in source.gen:
             self.fout.write(chunk)
         self.fout.flush()
 
-    def sendpushresponse(self, ret):
+    def sendpushresponse(self, rsp):
         self.sendresponse('')
-        self.sendresponse(str(ret))
+        self.sendresponse(str(rsp.res))
 
     def serve_forever(self):
         try:
@@ -89,10 +89,17 @@
                 self.lock.release()
         sys.exit(0)
 
+    handlers = {
+        str: sendresponse,
+        wireproto.streamres: sendstream,
+        wireproto.pushres: sendpushresponse,
+    }
+
     def serve_one(self):
         cmd = self.fin.readline()[:-1]
         if cmd and cmd in wireproto.commands:
-            wireproto.dispatch(self.repo, self, cmd)
+            rsp = wireproto.dispatch(self.repo, self, cmd)
+            self.handlers[rsp.__class__](self, rsp)
         elif cmd:
             impl = getattr(self, 'do_' + cmd, None)
             if impl:
--- a/mercurial/wireproto.py	Fri Jul 16 22:20:19 2010 +0200
+++ b/mercurial/wireproto.py	Tue Jul 20 20:53:33 2010 +0200
@@ -133,12 +133,18 @@
 
 # server side
 
+class streamres(object):
+    def __init__(self, gen):
+        self.gen = gen
+
+class pushres(object):
+    def __init__(self, res):
+        self.res = res
+
 def dispatch(repo, proto, command):
     func, spec = commands[command]
     args = proto.getargs(spec)
-    r = func(repo, proto, *args)
-    if r != None:
-        proto.sendresponse(r)
+    return func(repo, proto, *args)
 
 def between(repo, proto, pairs):
     pairs = [decodelist(p, '-') for p in pairs.split(" ")]
@@ -173,13 +179,13 @@
 def changegroup(repo, proto, roots):
     nodes = decodelist(roots)
     cg = repo.changegroup(nodes, 'serve')
-    proto.sendstream(proto.groupchunks(cg))
+    return streamres(proto.groupchunks(cg))
 
 def changegroupsubset(repo, proto, bases, heads):
     bases = decodelist(bases)
     heads = decodelist(heads)
     cg = repo.changegroupsubset(bases, heads, 'serve')
-    proto.sendstream(proto.groupchunks(cg))
+    return streamres(proto.groupchunks(cg))
 
 def heads(repo, proto):
     h = repo.heads()
@@ -215,7 +221,7 @@
     return '%s\n' % int(r)
 
 def stream(repo, proto):
-    proto.sendstream(streamclone.stream_out(repo))
+    return streamres(streamclone.stream_out(repo))
 
 def unbundle(repo, proto, heads):
     their_heads = decodelist(heads)
@@ -259,7 +265,7 @@
                 sys.stderr.write("abort: %s\n" % inst)
         finally:
             lock.release()
-            proto.sendpushresponse(r)
+            return pushres(r)
 
     finally:
         fp.close()