wireproto: introduce type for raw byte responses (API)
Right now we simply return a str/bytes instance for simple
responses. I want all wire protocol response types to be strongly
typed. So let's invent and use a type for raw bytes responses.
.. api::
Wire protocol command handlers now return a
wireprototypes.bytesresponse instead of a raw bytes instance.
Protocol handlers will continue handling bytes instances. However,
any extensions wrapping wire protocol commands will need to handle
the new type.
Differential Revision: https://phab.mercurial-scm.org/D2089
--- a/hgext/largefiles/proto.py Wed Feb 07 16:29:05 2018 -0800
+++ b/hgext/largefiles/proto.py Wed Feb 07 20:27:36 2018 -0800
@@ -14,6 +14,7 @@
httppeer,
util,
wireproto,
+ wireprototypes,
)
from . import (
@@ -85,8 +86,8 @@
server side.'''
filename = lfutil.findfile(repo, sha)
if not filename:
- return '2\n'
- return '0\n'
+ return wireprototypes.bytesresponse('2\n')
+ return wireprototypes.bytesresponse('0\n')
def wirereposetup(ui, repo):
class lfileswirerepository(repo.__class__):
--- a/mercurial/wireproto.py Wed Feb 07 16:29:05 2018 -0800
+++ b/mercurial/wireproto.py Wed Feb 07 20:27:36 2018 -0800
@@ -37,6 +37,7 @@
urlerr = util.urlerr
urlreq = util.urlreq
+bytesresponse = wireprototypes.bytesresponse
ooberror = wireprototypes.ooberror
pushres = wireprototypes.pushres
pusherr = wireprototypes.pusherr
@@ -696,8 +697,15 @@
result = func(repo, proto)
if isinstance(result, ooberror):
return result
+
+ # For now, all batchable commands must return bytesresponse or
+ # raw bytes (for backwards compatibility).
+ assert isinstance(result, (bytesresponse, bytes))
+ if isinstance(result, bytesresponse):
+ result = result.data
res.append(escapearg(result))
- return ';'.join(res)
+
+ return bytesresponse(';'.join(res))
@wireprotocommand('between', 'pairs')
def between(repo, proto, pairs):
@@ -705,7 +713,8 @@
r = []
for b in repo.between(pairs):
r.append(encodelist(b) + "\n")
- return "".join(r)
+
+ return bytesresponse(''.join(r))
@wireprotocommand('branchmap')
def branchmap(repo, proto):
@@ -715,7 +724,8 @@
branchname = urlreq.quote(encoding.fromlocal(branch))
branchnodes = encodelist(nodes)
heads.append('%s %s' % (branchname, branchnodes))
- return '\n'.join(heads)
+
+ return bytesresponse('\n'.join(heads))
@wireprotocommand('branches', 'nodes')
def branches(repo, proto, nodes):
@@ -723,7 +733,8 @@
r = []
for b in repo.branches(nodes):
r.append(encodelist(b) + "\n")
- return "".join(r)
+
+ return bytesresponse(''.join(r))
@wireprotocommand('clonebundles', '')
def clonebundles(repo, proto):
@@ -735,7 +746,7 @@
depending on the request. e.g. you could advertise URLs for the closest
data center given the client's IP address.
"""
- return repo.vfs.tryread('clonebundles.manifest')
+ return bytesresponse(repo.vfs.tryread('clonebundles.manifest'))
wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
'known', 'getbundle', 'unbundlehash', 'batch']
@@ -789,7 +800,7 @@
# `_capabilities` instead.
@wireprotocommand('capabilities')
def capabilities(repo, proto):
- return ' '.join(_capabilities(repo, proto))
+ return bytesresponse(' '.join(_capabilities(repo, proto)))
@wireprotocommand('changegroup', 'roots')
def changegroup(repo, proto, roots):
@@ -814,7 +825,8 @@
def debugwireargs(repo, proto, one, two, others):
# only accept optional args from the known set
opts = options('debugwireargs', ['three', 'four'], others)
- return repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
+ return bytesresponse(repo.debugwireargs(one, two,
+ **pycompat.strkwargs(opts)))
@wireprotocommand('getbundle', '*')
def getbundle(repo, proto, others):
@@ -885,7 +897,7 @@
@wireprotocommand('heads')
def heads(repo, proto):
h = repo.heads()
- return encodelist(h) + "\n"
+ return bytesresponse(encodelist(h) + '\n')
@wireprotocommand('hello')
def hello(repo, proto):
@@ -896,12 +908,13 @@
capabilities: space separated list of tokens
'''
- return "capabilities: %s\n" % (capabilities(repo, proto))
+ caps = capabilities(repo, proto).data
+ return bytesresponse('capabilities: %s\n' % caps)
@wireprotocommand('listkeys', 'namespace')
def listkeys(repo, proto, namespace):
d = repo.listkeys(encoding.tolocal(namespace)).items()
- return pushkeymod.encodekeys(d)
+ return bytesresponse(pushkeymod.encodekeys(d))
@wireprotocommand('lookup', 'key')
def lookup(repo, proto, key):
@@ -913,11 +926,12 @@
except Exception as inst:
r = str(inst)
success = 0
- return "%d %s\n" % (success, r)
+ return bytesresponse('%d %s\n' % (success, r))
@wireprotocommand('known', 'nodes *')
def known(repo, proto, nodes, others):
- return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
+ v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
+ return bytesresponse(v)
@wireprotocommand('pushkey', 'namespace key old new')
def pushkey(repo, proto, namespace, key, old, new):
@@ -938,7 +952,7 @@
encoding.tolocal(old), new) or False
output = output.getvalue() if output else ''
- return '%s\n%s' % (int(r), output)
+ return bytesresponse('%s\n%s' % (int(r), output))
@wireprotocommand('stream_out')
def stream(repo, proto):
--- a/mercurial/wireprotoserver.py Wed Feb 07 16:29:05 2018 -0800
+++ b/mercurial/wireprotoserver.py Wed Feb 07 20:27:36 2018 -0800
@@ -274,6 +274,9 @@
if isinstance(rsp, bytes):
req.respond(HTTP_OK, HGTYPE, body=rsp)
return []
+ elif isinstance(rsp, wireprototypes.bytesresponse):
+ req.respond(HTTP_OK, HGTYPE, body=rsp.data)
+ return []
elif isinstance(rsp, wireprototypes.streamreslegacy):
gen = rsp.gen
req.respond(HTTP_OK, HGTYPE)
@@ -435,6 +438,8 @@
if isinstance(rsp, bytes):
_sshv1respondbytes(self._fout, rsp)
+ elif isinstance(rsp, wireprototypes.bytesresponse):
+ _sshv1respondbytes(self._fout, rsp.data)
elif isinstance(rsp, wireprototypes.streamres):
_sshv1respondstream(self._fout, rsp)
elif isinstance(rsp, wireprototypes.streamreslegacy):
--- a/mercurial/wireprototypes.py Wed Feb 07 16:29:05 2018 -0800
+++ b/mercurial/wireprototypes.py Wed Feb 07 20:27:36 2018 -0800
@@ -5,6 +5,11 @@
from __future__ import absolute_import
+class bytesresponse(object):
+ """A wire protocol response consisting of raw bytes."""
+ def __init__(self, data):
+ self.data = data
+
class ooberror(object):
"""wireproto reply: failure of a batch of operation
--- a/tests/sshprotoext.py Wed Feb 07 16:29:05 2018 -0800
+++ b/tests/sshprotoext.py Wed Feb 07 20:27:36 2018 -0800
@@ -49,7 +49,7 @@
l = self._fin.readline()
assert l == b'between\n'
rsp = wireproto.dispatch(self._repo, self._proto, b'between')
- wireprotoserver._sshv1respondbytes(self._fout, rsp)
+ wireprotoserver._sshv1respondbytes(self._fout, rsp.data)
super(prehelloserver, self).serve_forever()
@@ -74,7 +74,7 @@
# Send the upgrade response.
self._fout.write(b'upgraded %s %s\n' % (token, name))
servercaps = wireproto.capabilities(self._repo, self._proto)
- rsp = b'capabilities: %s' % servercaps
+ rsp = b'capabilities: %s' % servercaps.data
self._fout.write(b'%d\n' % len(rsp))
self._fout.write(rsp)
self._fout.write(b'\n')
--- a/tests/test-wireproto.py Wed Feb 07 16:29:05 2018 -0800
+++ b/tests/test-wireproto.py Wed Feb 07 20:27:36 2018 -0800
@@ -1,8 +1,10 @@
from __future__ import absolute_import, print_function
from mercurial import (
+ error,
util,
wireproto,
+ wireprototypes,
)
stringio = util.stringio
@@ -42,7 +44,13 @@
return ['batch']
def _call(self, cmd, **args):
- return wireproto.dispatch(self.serverrepo, proto(args), cmd)
+ res = wireproto.dispatch(self.serverrepo, proto(args), cmd)
+ if isinstance(res, wireprototypes.bytesresponse):
+ return res.data
+ elif isinstance(res, bytes):
+ return res
+ else:
+ raise error.Abort('dummy client does not support response type')
def _callstream(self, cmd, **args):
return stringio(self._call(cmd, **args))