wireprotoserver: access headers through parsed request
authorGregory Szorc <gregory.szorc@gmail.com>
Thu, 08 Mar 2018 16:38:01 -0800
changeset 36852 14f70c44af6c
parent 36851 31581528f242
child 36853 ed0456fde625
wireprotoserver: access headers through parsed request Now that we can access headers via the parsed request object, let's do that. Since the new object uses bytes, hyphens, and is case-insensitive, a bit of code around normalizing values has been removed. I think the new code is much more intuitive because it more closely matches what is going out over the wire. Differential Revision: https://phab.mercurial-scm.org/D2743
mercurial/wireprotoserver.py
--- a/mercurial/wireprotoserver.py	Mon Mar 12 22:47:33 2018 +0900
+++ b/mercurial/wireprotoserver.py	Thu Mar 08 16:38:01 2018 -0800
@@ -36,16 +36,15 @@
 SSHV1 = wireprototypes.SSHV1
 SSHV2 = wireprototypes.SSHV2
 
-def decodevaluefromheaders(wsgireq, headerprefix):
+def decodevaluefromheaders(req, headerprefix):
     """Decode a long value from multiple HTTP request headers.
 
     Returns the value as a bytes, not a str.
     """
     chunks = []
     i = 1
-    prefix = headerprefix.upper().replace(r'-', r'_')
     while True:
-        v = wsgireq.env.get(r'HTTP_%s_%d' % (prefix, i))
+        v = req.headers.get(b'%s-%d' % (headerprefix, i))
         if v is None:
             break
         chunks.append(pycompat.bytesurl(v))
@@ -54,8 +53,9 @@
     return ''.join(chunks)
 
 class httpv1protocolhandler(wireprototypes.baseprotocolhandler):
-    def __init__(self, wsgireq, ui, checkperm):
+    def __init__(self, wsgireq, req, ui, checkperm):
         self._wsgireq = wsgireq
+        self._req = req
         self._ui = ui
         self._checkperm = checkperm
 
@@ -80,24 +80,24 @@
 
     def _args(self):
         args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
-        postlen = int(self._wsgireq.env.get(r'HTTP_X_HGARGS_POST', 0))
+        postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
         if postlen:
             args.update(urlreq.parseqs(
                 self._wsgireq.read(postlen), keep_blank_values=True))
             return args
 
-        argvalue = decodevaluefromheaders(self._wsgireq, r'X-HgArg')
+        argvalue = decodevaluefromheaders(self._req, b'X-HgArg')
         args.update(urlreq.parseqs(argvalue, keep_blank_values=True))
         return args
 
     def forwardpayload(self, fp):
-        if r'HTTP_CONTENT_LENGTH' in self._wsgireq.env:
-            length = int(self._wsgireq.env[r'HTTP_CONTENT_LENGTH'])
+        if b'Content-Length' in self._req.headers:
+            length = int(self._req.headers[b'Content-Length'])
         else:
             length = int(self._wsgireq.env[r'CONTENT_LENGTH'])
         # If httppostargs is used, we need to read Content-Length
         # minus the amount that was consumed by args.
-        length -= int(self._wsgireq.env.get(r'HTTP_X_HGARGS_POST', 0))
+        length -= int(self._req.headers.get(b'X-HgArgs-Post', 0))
         for s in util.filechunkiter(self._wsgireq, limit=length):
             fp.write(s)
 
@@ -193,11 +193,11 @@
     if req.dispatchpath:
         res = _handlehttperror(
             hgwebcommon.ErrorResponse(hgwebcommon.HTTP_NOT_FOUND), wsgireq,
-            cmd)
+            req, cmd)
 
         return True, res
 
-    proto = httpv1protocolhandler(wsgireq, repo.ui,
+    proto = httpv1protocolhandler(wsgireq, req, repo.ui,
                                   lambda perm: checkperm(rctx, wsgireq, perm))
 
     # The permissions checker should be the only thing that can raise an
@@ -205,20 +205,20 @@
     # exception here. So consider refactoring into a exception type that
     # is associated with the wire protocol.
     try:
-        res = _callhttp(repo, wsgireq, proto, cmd)
+        res = _callhttp(repo, wsgireq, req, proto, cmd)
     except hgwebcommon.ErrorResponse as e:
-        res = _handlehttperror(e, wsgireq, cmd)
+        res = _handlehttperror(e, wsgireq, req, cmd)
 
     return True, res
 
-def _httpresponsetype(ui, wsgireq, prefer_uncompressed):
+def _httpresponsetype(ui, req, prefer_uncompressed):
     """Determine the appropriate response type and compression settings.
 
     Returns a tuple of (mediatype, compengine, engineopts).
     """
     # Determine the response media type and compression engine based
     # on the request parameters.
-    protocaps = decodevaluefromheaders(wsgireq, r'X-HgProto').split(' ')
+    protocaps = decodevaluefromheaders(req, 'X-HgProto').split(' ')
 
     if '0.2' in protocaps:
         # All clients are expected to support uncompressed data.
@@ -251,7 +251,7 @@
     opts = {'level': ui.configint('server', 'zliblevel')}
     return HGTYPE, util.compengines['zlib'], opts
 
-def _callhttp(repo, wsgireq, proto, cmd):
+def _callhttp(repo, wsgireq, req, proto, cmd):
     def genversion2(gen, engine, engineopts):
         # application/mercurial-0.2 always sends a payload header
         # identifying the compression engine.
@@ -289,7 +289,7 @@
         # This code for compression should not be streamres specific. It
         # is here because we only compress streamres at the moment.
         mediatype, engine, engineopts = _httpresponsetype(
-            repo.ui, wsgireq, rsp.prefer_uncompressed)
+            repo.ui, req, rsp.prefer_uncompressed)
         gen = engine.compressstream(gen, engineopts)
 
         if mediatype == HGTYPE2:
@@ -314,7 +314,7 @@
         return []
     raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
 
-def _handlehttperror(e, wsgireq, cmd):
+def _handlehttperror(e, wsgireq, req, cmd):
     """Called when an ErrorResponse is raised during HTTP request processing."""
 
     # Clients using Python's httplib are stateful: the HTTP client
@@ -327,8 +327,7 @@
 
     if (wsgireq.env[r'REQUEST_METHOD'] == r'POST' and
         # But not if Expect: 100-continue is being used.
-        (wsgireq.env.get('HTTP_EXPECT',
-                         '').lower() != '100-continue')):
+        (req.headers.get('Expect', '').lower() != '100-continue')):
         wsgireq.drain()
     else:
         wsgireq.headers.append((r'Connection', r'Close'))