contrib/python-zstandard/zstd/compress/hist.c
changeset 42070 675775c33ab6
parent 40122 73fef626dae3
--- a/contrib/python-zstandard/zstd/compress/hist.c	Thu Apr 04 15:24:03 2019 -0700
+++ b/contrib/python-zstandard/zstd/compress/hist.c	Thu Apr 04 17:34:43 2019 -0700
@@ -73,6 +73,7 @@
     return largestCount;
 }
 
+typedef enum { trustInput, checkMaxSymbolValue } HIST_checkInput_e;
 
 /* HIST_count_parallel_wksp() :
  * store histogram into 4 intermediate tables, recombined at the end.
@@ -85,8 +86,8 @@
 static size_t HIST_count_parallel_wksp(
                                 unsigned* count, unsigned* maxSymbolValuePtr,
                                 const void* source, size_t sourceSize,
-                                unsigned checkMax,
-                                unsigned* const workSpace)
+                                HIST_checkInput_e check,
+                                U32* const workSpace)
 {
     const BYTE* ip = (const BYTE*)source;
     const BYTE* const iend = ip+sourceSize;
@@ -137,7 +138,7 @@
     /* finish last symbols */
     while (ip<iend) Counting1[*ip++]++;
 
-    if (checkMax) {   /* verify stats will fit into destination table */
+    if (check) {   /* verify stats will fit into destination table */
         U32 s; for (s=255; s>maxSymbolValue; s--) {
             Counting1[s] += Counting2[s] + Counting3[s] + Counting4[s];
             if (Counting1[s]) return ERROR(maxSymbolValue_tooSmall);
@@ -157,14 +158,18 @@
 
 /* HIST_countFast_wksp() :
  * Same as HIST_countFast(), but using an externally provided scratch buffer.
- * `workSpace` size must be table of >= HIST_WKSP_SIZE_U32 unsigned */
+ * `workSpace` is a writable buffer which must be 4-bytes aligned,
+ * `workSpaceSize` must be >= HIST_WKSP_SIZE
+ */
 size_t HIST_countFast_wksp(unsigned* count, unsigned* maxSymbolValuePtr,
                           const void* source, size_t sourceSize,
-                          unsigned* workSpace)
+                          void* workSpace, size_t workSpaceSize)
 {
     if (sourceSize < 1500) /* heuristic threshold */
         return HIST_count_simple(count, maxSymbolValuePtr, source, sourceSize);
-    return HIST_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, 0, workSpace);
+    if ((size_t)workSpace & 3) return ERROR(GENERIC);  /* must be aligned on 4-bytes boundaries */
+    if (workSpaceSize < HIST_WKSP_SIZE) return ERROR(workSpace_tooSmall);
+    return HIST_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, trustInput, (U32*)workSpace);
 }
 
 /* fast variant (unsafe : won't check if src contains values beyond count[] limit) */
@@ -172,24 +177,27 @@
                      const void* source, size_t sourceSize)
 {
     unsigned tmpCounters[HIST_WKSP_SIZE_U32];
-    return HIST_countFast_wksp(count, maxSymbolValuePtr, source, sourceSize, tmpCounters);
+    return HIST_countFast_wksp(count, maxSymbolValuePtr, source, sourceSize, tmpCounters, sizeof(tmpCounters));
 }
 
 /* HIST_count_wksp() :
  * Same as HIST_count(), but using an externally provided scratch buffer.
  * `workSpace` size must be table of >= HIST_WKSP_SIZE_U32 unsigned */
 size_t HIST_count_wksp(unsigned* count, unsigned* maxSymbolValuePtr,
-                 const void* source, size_t sourceSize, unsigned* workSpace)
+                       const void* source, size_t sourceSize,
+                       void* workSpace, size_t workSpaceSize)
 {
+    if ((size_t)workSpace & 3) return ERROR(GENERIC);  /* must be aligned on 4-bytes boundaries */
+    if (workSpaceSize < HIST_WKSP_SIZE) return ERROR(workSpace_tooSmall);
     if (*maxSymbolValuePtr < 255)
-        return HIST_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, 1, workSpace);
+        return HIST_count_parallel_wksp(count, maxSymbolValuePtr, source, sourceSize, checkMaxSymbolValue, (U32*)workSpace);
     *maxSymbolValuePtr = 255;
-    return HIST_countFast_wksp(count, maxSymbolValuePtr, source, sourceSize, workSpace);
+    return HIST_countFast_wksp(count, maxSymbolValuePtr, source, sourceSize, workSpace, workSpaceSize);
 }
 
 size_t HIST_count(unsigned* count, unsigned* maxSymbolValuePtr,
                  const void* src, size_t srcSize)
 {
     unsigned tmpCounters[HIST_WKSP_SIZE_U32];
-    return HIST_count_wksp(count, maxSymbolValuePtr, src, srcSize, tmpCounters);
+    return HIST_count_wksp(count, maxSymbolValuePtr, src, srcSize, tmpCounters, sizeof(tmpCounters));
 }