rust: speed up zstd decompression by re-using the decompression context
authorArseniy Alekseyev <aalekseyev@janestreet.com>
Thu, 18 May 2023 17:18:54 +0100
changeset 50540 74d8a1b03960
parent 50539 32b4c2bbdb94
child 50541 d1cab48354bc
rust: speed up zstd decompression by re-using the decompression context Admittedly, zstd is already pretty fast, but this change makes it a bit faster yet: it saves ~5% of time it takes to read our large repo. The actual motivating use case is treemanifest: in treemanifest we end up reading *lots* of small directories, and many of them need decompression, and there the saving for [rhg files] is >10%. (which also seems unreasonable, we should probably keep things uncompressed more)
rust/hg-core/src/revlog/mod.rs
--- a/rust/hg-core/src/revlog/mod.rs	Tue May 16 10:44:25 2023 +0200
+++ b/rust/hg-core/src/revlog/mod.rs	Thu May 18 17:18:54 2023 +0100
@@ -23,6 +23,7 @@
 
 use flate2::read::ZlibDecoder;
 use sha1::{Digest, Sha1};
+use std::cell::RefCell;
 use zstd;
 
 use self::node::{NODE_BYTES_LENGTH, NULL_NODE};
@@ -413,6 +414,21 @@
     hash: Node,
 }
 
+thread_local! {
+  // seems fine to [unwrap] here: this can only fail due to memory allocation
+  // failing, and it's normal for that to cause panic.
+  static ZSTD_DECODER : RefCell<zstd::bulk::Decompressor<'static>> =
+      RefCell::new(zstd::bulk::Decompressor::new().ok().unwrap());
+}
+
+fn zstd_decompress_to_buffer(
+    bytes: &[u8],
+    buf: &mut Vec<u8>,
+) -> Result<usize, std::io::Error> {
+    ZSTD_DECODER
+        .with(|decoder| decoder.borrow_mut().decompress_to_buffer(bytes, buf))
+}
+
 impl<'revlog> RevlogEntry<'revlog> {
     pub fn revision(&self) -> Revision {
         self.rev
@@ -588,7 +604,7 @@
         } else {
             let cap = self.uncompressed_len.max(0) as usize;
             let mut buf = vec![0; cap];
-            let len = zstd::bulk::decompress_to_buffer(self.bytes, &mut buf)
+            let len = zstd_decompress_to_buffer(self.bytes, &mut buf)
                 .map_err(|e| corrupted(e.to_string()))?;
             if len != self.uncompressed_len as usize {
                 Err(corrupted("uncompressed length does not match"))