diff mercurial/wireprotoframing.py @ 37058:c5e9c3b47366

wireproto: support for receiving multiple requests

Now that we have request IDs on each frame and a specification that allows multiple requests to be issued simultaneously, possibly interleaved, let's teach the server to deal with that. Instead of tracking the state for *the* active command request, we instead track the state of each receiving command by its request ID. The multiple states in our state machine for processing each command's state have been collapsed into a single state for "receiving commands." Tests have been added so our branch coverage covers all meaningful branches. However, we did lose some logical coverage. The implementation of this new feature opens up the door to a server having partial command requests when end of input is reached. We will probably want a mechanism to deal with partial requests. For now, I've tracked that as a known issue in the class docstring. I've also noted an abuse vector that becomes a little bit easier to exploit with this feature. Differential Revision: https://phab.mercurial-scm.org/D2870
author Gregory Szorc <gregory.szorc@gmail.com>
date Wed, 14 Mar 2018 16:53:30 -0700
parents 2ec1fb9de638
children 0a6c5cc09a88
line wrap: on
line diff
--- a/mercurial/wireprotoframing.py	Wed Mar 14 16:51:34 2018 -0700
+++ b/mercurial/wireprotoframing.py	Wed Mar 14 16:53:30 2018 -0700
@@ -327,6 +327,23 @@
 
     noop
        Indicates no additional action is required.
+
+    Known Issues
+    ------------
+
+    There are no limits to the number of partially received commands or their
+    size. A malicious client could stream command request data and exhaust the
+    server's memory.
+
+    Partially received commands are not acted upon when end of input is
+    reached. Should the server error if it receives a partial request?
+    Should the client send a message to abort a partially transmitted request
+    to facilitate graceful shutdown?
+
+    Active requests that haven't been responded to aren't tracked. This means
+    that if we receive a command and instruct its dispatch, another command
+    with the same request ID can come in over the wire and there will be a
+    race between who responds to what.
     """
 
     def __init__(self, deferoutput=False):
@@ -342,14 +359,8 @@
         self._deferoutput = deferoutput
         self._state = 'idle'
         self._bufferedframegens = []
-        self._activerequestid = None
-        self._activecommand = None
-        self._activeargs = None
-        self._activedata = None
-        self._expectingargs = None
-        self._expectingdata = None
-        self._activeargname = None
-        self._activeargchunks = None
+        # request id -> dict of commands that are actively being received.
+        self._receivingcommands = {}
 
     def onframerecv(self, requestid, frametype, frameflags, payload):
         """Process a frame that has been received off the wire.
@@ -359,8 +370,7 @@
         """
         handlers = {
             'idle': self._onframeidle,
-            'command-receiving-args': self._onframereceivingargs,
-            'command-receiving-data': self._onframereceivingdata,
+            'command-receiving': self._onframecommandreceiving,
             'errored': self._onframeerrored,
         }
 
@@ -391,6 +401,8 @@
         No more frames will be received. All pending activity should be
         completed.
         """
+        # TODO should we do anything about in-flight commands?
+
         if not self._deferoutput or not self._bufferedframegens:
             return 'noop', {}
 
@@ -414,12 +426,20 @@
             'message': msg,
         }
 
-    def _makeruncommandresult(self):
+    def _makeruncommandresult(self, requestid):
+        entry = self._receivingcommands[requestid]
+        del self._receivingcommands[requestid]
+
+        if self._receivingcommands:
+            self._state = 'command-receiving'
+        else:
+            self._state = 'idle'
+
         return 'runcommand', {
-            'requestid': self._activerequestid,
-            'command': self._activecommand,
-            'args': self._activeargs,
-            'data': self._activedata.getvalue() if self._activedata else None,
+            'requestid': requestid,
+            'command': entry['command'],
+            'args': entry['args'],
+            'data': entry['data'].getvalue() if entry['data'] else None,
         }
 
     def _makewantframeresult(self):
@@ -435,34 +455,76 @@
             return self._makeerrorresult(
                 _('expected command frame; got %d') % frametype)
 
-        self._activerequestid = requestid
-        self._activecommand = payload
-        self._activeargs = {}
-        self._activedata = None
+        if requestid in self._receivingcommands:
+            self._state = 'errored'
+            return self._makeerrorresult(
+                _('request with ID %d already received') % requestid)
+
+        expectingargs = bool(frameflags & FLAG_COMMAND_NAME_HAVE_ARGS)
+        expectingdata = bool(frameflags & FLAG_COMMAND_NAME_HAVE_DATA)
+
+        self._receivingcommands[requestid] = {
+            'command': payload,
+            'args': {},
+            'data': None,
+            'expectingargs': expectingargs,
+            'expectingdata': expectingdata,
+        }
 
         if frameflags & FLAG_COMMAND_NAME_EOS:
-            return self._makeruncommandresult()
-
-        self._expectingargs = bool(frameflags & FLAG_COMMAND_NAME_HAVE_ARGS)
-        self._expectingdata = bool(frameflags & FLAG_COMMAND_NAME_HAVE_DATA)
+            return self._makeruncommandresult(requestid)
 
-        if self._expectingargs:
-            self._state = 'command-receiving-args'
-            return self._makewantframeresult()
-        elif self._expectingdata:
-            self._activedata = util.bytesio()
-            self._state = 'command-receiving-data'
+        if expectingargs or expectingdata:
+            self._state = 'command-receiving'
             return self._makewantframeresult()
         else:
             self._state = 'errored'
             return self._makeerrorresult(_('missing frame flags on '
                                            'command frame'))
 
-    def _onframereceivingargs(self, requestid, frametype, frameflags, payload):
-        if frametype != FRAME_TYPE_COMMAND_ARGUMENT:
+    def _onframecommandreceiving(self, requestid, frametype, frameflags,
+                                 payload):
+        # It could be a new command request. Process it as such.
+        if frametype == FRAME_TYPE_COMMAND_NAME:
+            return self._onframeidle(requestid, frametype, frameflags, payload)
+
+        # All other frames should be related to a command that is currently
+        # receiving.
+        if requestid not in self._receivingcommands:
             self._state = 'errored'
-            return self._makeerrorresult(_('expected command argument '
-                                           'frame; got %d') % frametype)
+            return self._makeerrorresult(
+                _('received frame for request that is not receiving: %d') %
+                  requestid)
+
+        entry = self._receivingcommands[requestid]
+
+        if frametype == FRAME_TYPE_COMMAND_ARGUMENT:
+            if not entry['expectingargs']:
+                self._state = 'errored'
+                return self._makeerrorresult(_(
+                    'received command argument frame for request that is not '
+                    'expecting arguments: %d') % requestid)
+
+            return self._handlecommandargsframe(requestid, entry, frametype,
+                                                frameflags, payload)
+
+        elif frametype == FRAME_TYPE_COMMAND_DATA:
+            if not entry['expectingdata']:
+                self._state = 'errored'
+                return self._makeerrorresult(_(
+                    'received command data frame for request that is not '
+                    'expecting data: %d') % requestid)
+
+            if entry['data'] is None:
+                entry['data'] = util.bytesio()
+
+            return self._handlecommanddataframe(requestid, entry, frametype,
+                                                frameflags, payload)
+
+    def _handlecommandargsframe(self, requestid, entry, frametype, frameflags,
+                                payload):
+        # The frame and state of command should have already been validated.
+        assert frametype == FRAME_TYPE_COMMAND_ARGUMENT
 
         offset = 0
         namesize, valuesize = ARGUMENT_FRAME_HEADER.unpack_from(payload)
@@ -483,10 +545,6 @@
         # and wait for the next frame.
         if frameflags & FLAG_COMMAND_ARGUMENT_CONTINUATION:
             raise error.ProgrammingError('not yet implemented')
-            self._activeargname = argname
-            self._activeargchunks = [argvalue]
-            self._state = 'command-arg-continuation'
-            return self._makewantframeresult()
 
         # Common case: the argument value is completely contained in this
         # frame.
@@ -496,36 +554,30 @@
             return self._makeerrorresult(_('malformed argument frame: '
                                            'partial argument value'))
 
-        self._activeargs[argname] = argvalue
+        entry['args'][argname] = argvalue
 
         if frameflags & FLAG_COMMAND_ARGUMENT_EOA:
-            if self._expectingdata:
-                self._state = 'command-receiving-data'
-                self._activedata = util.bytesio()
+            if entry['expectingdata']:
                 # TODO signal request to run a command once we don't
                 # buffer data frames.
                 return self._makewantframeresult()
             else:
-                self._state = 'waiting'
-                return self._makeruncommandresult()
+                return self._makeruncommandresult(requestid)
         else:
             return self._makewantframeresult()
 
-    def _onframereceivingdata(self, requestid, frametype, frameflags, payload):
-        if frametype != FRAME_TYPE_COMMAND_DATA:
-            self._state = 'errored'
-            return self._makeerrorresult(_('expected command data frame; '
-                                           'got %d') % frametype)
+    def _handlecommanddataframe(self, requestid, entry, frametype, frameflags,
+                                payload):
+        assert frametype == FRAME_TYPE_COMMAND_DATA
 
         # TODO support streaming data instead of buffering it.
-        self._activedata.write(payload)
+        entry['data'].write(payload)
 
         if frameflags & FLAG_COMMAND_DATA_CONTINUATION:
             return self._makewantframeresult()
         elif frameflags & FLAG_COMMAND_DATA_EOS:
-            self._activedata.seek(0)
-            self._state = 'idle'
-            return self._makeruncommandresult()
+            entry['data'].seek(0)
+            return self._makeruncommandresult(requestid)
         else:
             self._state = 'errored'
             return self._makeerrorresult(_('command data frame without '