localrepo: acquire lock on stream_in
authorAdrian Buehlmann <adrian@cadifra.com>
Fri, 28 Jan 2011 13:34:07 +0100
changeset 13390 327719a44b6a
parent 13389 3efc99ac2ac4
child 13391 d00bbff8600e
localrepo: acquire lock on stream_in
mercurial/localrepo.py
--- a/mercurial/localrepo.py	Fri Jan 28 13:54:38 2011 +0100
+++ b/mercurial/localrepo.py	Fri Jan 28 13:34:07 2011 +0100
@@ -1909,59 +1909,63 @@
 
 
     def stream_in(self, remote, requirements):
-        fp = remote.stream_out()
-        l = fp.readline()
+        lock = self.lock()
         try:
-            resp = int(l)
-        except ValueError:
-            raise error.ResponseError(
-                _('Unexpected response from remote server:'), l)
-        if resp == 1:
-            raise util.Abort(_('operation forbidden by server'))
-        elif resp == 2:
-            raise util.Abort(_('locking the remote repository failed'))
-        elif resp != 0:
-            raise util.Abort(_('the server sent an unknown error code'))
-        self.ui.status(_('streaming all changes\n'))
-        l = fp.readline()
-        try:
-            total_files, total_bytes = map(int, l.split(' ', 1))
-        except (ValueError, TypeError):
-            raise error.ResponseError(
-                _('Unexpected response from remote server:'), l)
-        self.ui.status(_('%d files to transfer, %s of data\n') %
-                       (total_files, util.bytecount(total_bytes)))
-        start = time.time()
-        for i in xrange(total_files):
-            # XXX doesn't support '\n' or '\r' in filenames
+            fp = remote.stream_out()
             l = fp.readline()
             try:
-                name, size = l.split('\0', 1)
-                size = int(size)
+                resp = int(l)
+            except ValueError:
+                raise error.ResponseError(
+                    _('Unexpected response from remote server:'), l)
+            if resp == 1:
+                raise util.Abort(_('operation forbidden by server'))
+            elif resp == 2:
+                raise util.Abort(_('locking the remote repository failed'))
+            elif resp != 0:
+                raise util.Abort(_('the server sent an unknown error code'))
+            self.ui.status(_('streaming all changes\n'))
+            l = fp.readline()
+            try:
+                total_files, total_bytes = map(int, l.split(' ', 1))
             except (ValueError, TypeError):
                 raise error.ResponseError(
                     _('Unexpected response from remote server:'), l)
-            self.ui.debug('adding %s (%s)\n' % (name, util.bytecount(size)))
-            # for backwards compat, name was partially encoded
-            ofp = self.sopener(store.decodedir(name), 'w')
-            for chunk in util.filechunkiter(fp, limit=size):
-                ofp.write(chunk)
-            ofp.close()
-        elapsed = time.time() - start
-        if elapsed <= 0:
-            elapsed = 0.001
-        self.ui.status(_('transferred %s in %.1f seconds (%s/sec)\n') %
-                       (util.bytecount(total_bytes), elapsed,
-                        util.bytecount(total_bytes / elapsed)))
+            self.ui.status(_('%d files to transfer, %s of data\n') %
+                           (total_files, util.bytecount(total_bytes)))
+            start = time.time()
+            for i in xrange(total_files):
+                # XXX doesn't support '\n' or '\r' in filenames
+                l = fp.readline()
+                try:
+                    name, size = l.split('\0', 1)
+                    size = int(size)
+                except (ValueError, TypeError):
+                    raise error.ResponseError(
+                        _('Unexpected response from remote server:'), l)
+                self.ui.debug('adding %s (%s)\n' % (name, util.bytecount(size)))
+                # for backwards compat, name was partially encoded
+                ofp = self.sopener(store.decodedir(name), 'w')
+                for chunk in util.filechunkiter(fp, limit=size):
+                    ofp.write(chunk)
+                ofp.close()
+            elapsed = time.time() - start
+            if elapsed <= 0:
+                elapsed = 0.001
+            self.ui.status(_('transferred %s in %.1f seconds (%s/sec)\n') %
+                           (util.bytecount(total_bytes), elapsed,
+                            util.bytecount(total_bytes / elapsed)))
 
-        # new requirements = old non-format requirements + new format-related
-        # requirements from the streamed-in repository
-        requirements.update(set(self.requirements) - self.supportedformats)
-        self._applyrequirements(requirements)
-        self._writerequirements()
+            # new requirements = old non-format requirements + new format-related
+            # requirements from the streamed-in repository
+            requirements.update(set(self.requirements) - self.supportedformats)
+            self._applyrequirements(requirements)
+            self._writerequirements()
 
-        self.invalidate()
-        return len(self.heads()) + 1
+            self.invalidate()
+            return len(self.heads()) + 1
+        finally:
+            lock.release()
 
     def clone(self, remote, heads=[], stream=False):
         '''clone remote repository.