hgweb: use a multidict for holding query string parameters
authorGregory Szorc <gregory.szorc@gmail.com>
Sat, 10 Mar 2018 12:35:38 -0800
changeset 36862 ec0af9c59270
parent 36861 a88d68dc3ee8
child 36863 1a1972b1a1ff
hgweb: use a multidict for holding query string parameters My intention with refactoring the WSGI code was to make it easier to read. I initially wanted to vendor and use WebOb, because it seems to be a pretty reasonable abstraction layer for WSGI. However, it isn't using relative imports and I didn't want to deal with the hassle of patching it. But that doesn't mean we can't use good ideas from WebOb. WebOb has a "multidict" data structure for holding parsed query string and POST form data. It quacks like a dict but allows you to store multiple values for each key. It offers mechanisms to return just one value, all values, or return 1 value asserting that only 1 value is set. I quite like its API. This commit implements a read-only "multidict" in the spirit of WebOb's multidict. We replace the query string attributes of our parsed request with an instance of it. Differential Revision: https://phab.mercurial-scm.org/D2776
mercurial/hgweb/request.py
mercurial/wireprotoserver.py
--- a/mercurial/hgweb/request.py	Sat Mar 10 11:23:05 2018 -0800
+++ b/mercurial/hgweb/request.py	Sat Mar 10 12:35:38 2018 -0800
@@ -28,6 +28,90 @@
     util,
 )
 
+class multidict(object):
+    """A dict like object that can store multiple values for a key.
+
+    Used to store parsed request parameters.
+
+    This is inspired by WebOb's class of the same name.
+    """
+    def __init__(self):
+        # Stores (key, value) 2-tuples. This isn't the most efficient. But we
+        # don't rely on parameters that much, so it shouldn't be a perf issue.
+        # we can always add dict for fast lookups.
+        self._items = []
+
+    def __getitem__(self, key):
+        """Returns the last set value for a key."""
+        for k, v in reversed(self._items):
+            if k == key:
+                return v
+
+        raise KeyError(key)
+
+    def __setitem__(self, key, value):
+        """Replace a values for a key with a new value."""
+        try:
+            del self[key]
+        except KeyError:
+            pass
+
+        self._items.append((key, value))
+
+    def __delitem__(self, key):
+        """Delete all values for a key."""
+        oldlen = len(self._items)
+
+        self._items[:] = [(k, v) for k, v in self._items if k != key]
+
+        if oldlen == len(self._items):
+            raise KeyError(key)
+
+    def __contains__(self, key):
+        return any(k == key for k, v in self._items)
+
+    def __len__(self):
+        return len(self._items)
+
+    def get(self, key, default=None):
+        try:
+            return self.__getitem__(key)
+        except KeyError:
+            return default
+
+    def add(self, key, value):
+        """Add a new value for a key. Does not replace existing values."""
+        self._items.append((key, value))
+
+    def getall(self, key):
+        """Obtains all values for a key."""
+        return [v for k, v in self._items if k == key]
+
+    def getone(self, key):
+        """Obtain a single value for a key.
+
+        Raises KeyError if key not defined or it has multiple values set.
+        """
+        vals = self.getall(key)
+
+        if not vals:
+            raise KeyError(key)
+
+        if len(vals) > 1:
+            raise KeyError('multiple values for %r' % key)
+
+        return vals[0]
+
+    def asdictoflists(self):
+        d = {}
+        for k, v in self._items:
+            if k in d:
+                d[k].append(v)
+            else:
+                d[k] = [v]
+
+        return d
+
 @attr.s(frozen=True)
 class parsedrequest(object):
     """Represents a parsed WSGI request.
@@ -56,10 +140,8 @@
     havepathinfo = attr.ib()
     # Raw query string (part after "?" in URL).
     querystring = attr.ib()
-    # List of 2-tuples of query string arguments.
-    querystringlist = attr.ib()
-    # Dict of query string arguments. Values are lists with at least 1 item.
-    querystringdict = attr.ib()
+    # multidict of query string parameters.
+    qsparams = attr.ib()
     # wsgiref.headers.Headers instance. Operates like a dict with case
     # insensitive keys.
     headers = attr.ib()
@@ -157,14 +239,9 @@
 
     # We store as a list so we have ordering information. We also store as
     # a dict to facilitate fast lookup.
-    querystringlist = util.urlreq.parseqsl(querystring, keep_blank_values=True)
-
-    querystringdict = {}
-    for k, v in querystringlist:
-        if k in querystringdict:
-            querystringdict[k].append(v)
-        else:
-            querystringdict[k] = [v]
+    qsparams = multidict()
+    for k, v in util.urlreq.parseqsl(querystring, keep_blank_values=True):
+        qsparams.add(k, v)
 
     # HTTP_* keys contain HTTP request headers. The Headers structure should
     # perform case normalization for us. We just rewrite underscore to dash
@@ -197,8 +274,7 @@
                          dispatchparts=dispatchparts, dispatchpath=dispatchpath,
                          havepathinfo='PATH_INFO' in env,
                          querystring=querystring,
-                         querystringlist=querystringlist,
-                         querystringdict=querystringdict,
+                         qsparams=qsparams,
                          headers=headers,
                          bodyfh=bodyfh)
 
@@ -350,7 +426,7 @@
         self.run_once = wsgienv[r'wsgi.run_once']
         self.env = wsgienv
         self.req = parserequestfromenv(wsgienv, inp)
-        self.form = self.req.querystringdict
+        self.form = self.req.qsparams.asdictoflists()
         self.res = wsgiresponse(self.req, start_response)
         self._start_response = start_response
         self.server_write = None
--- a/mercurial/wireprotoserver.py	Sat Mar 10 11:23:05 2018 -0800
+++ b/mercurial/wireprotoserver.py	Sat Mar 10 12:35:38 2018 -0800
@@ -79,7 +79,7 @@
         return [data[k] for k in keys]
 
     def _args(self):
-        args = util.rapply(pycompat.bytesurl, self._wsgireq.form.copy())
+        args = self._req.qsparams.asdictoflists()
         postlen = int(self._req.headers.get(b'X-HgArgs-Post', 0))
         if postlen:
             args.update(urlreq.parseqs(
@@ -170,10 +170,10 @@
     # HTTP version 1 wire protocol requests are denoted by a "cmd" query
     # string parameter. If it isn't present, this isn't a wire protocol
     # request.
-    if 'cmd' not in req.querystringdict:
+    if 'cmd' not in req.qsparams:
         return False
 
-    cmd = req.querystringdict['cmd'][0]
+    cmd = req.qsparams['cmd']
 
     # The "cmd" request parameter is used by both the wire protocol and hgweb.
     # While not all wire protocol commands are available for all transports,