Prevent numeric overflows in parallel numeric aggregates.
authorDean Rasheed <[email protected]>
Mon, 5 Jul 2021 09:16:42 +0000 (10:16 +0100)
committerDean Rasheed <[email protected]>
Mon, 5 Jul 2021 09:16:42 +0000 (10:16 +0100)
Formerly various numeric aggregate functions supported parallel
aggregation by having each worker convert partial aggregate values to
Numeric and use numeric_send() as part of serializing their state.
That's problematic, since the range of Numeric is smaller than that of
NumericVar, so it's possible for it to overflow (on either side of the
decimal point) in cases that would succeed in non-parallel mode.

Fix by serializing NumericVars instead, to avoid the overflow risk and
ensure that parallel and non-parallel modes work the same.

A side benefit is that this improves the efficiency of the
serialization/deserialization code, which can make a noticeable
difference to performance with large numbers of parallel workers.

No back-patch due to risk from changing the binary format of the
aggregate serialization states, as well as lack of prior field
complaints and low probability of such overflows in practice.

Patch by me. Thanks to David Rowley for review and performance
testing, and Ranier Vilela for an additional suggestion.

Discussion: https://postgr.es/m/CAEZATCUmeFWCrq2dNzZpRj5+6LfN85jYiDoqm+ucSXhb9U2TbA@mail.gmail.com

src/backend/utils/adt/numeric.c
src/test/regress/expected/numeric.out
src/test/regress/sql/numeric.sql

index eb78f0b9c2a14fb9e4196bd0ccca9fca9e99efbc..bc71326fc8af59bbeb2a4e774d9b173439958f16 100644 (file)
@@ -515,6 +515,9 @@ static void set_var_from_var(const NumericVar *value, NumericVar *dest);
 static char *get_str_from_var(const NumericVar *var);
 static char *get_str_from_var_sci(const NumericVar *var, int rscale);
 
+static void numericvar_serialize(StringInfo buf, const NumericVar *var);
+static void numericvar_deserialize(StringInfo buf, NumericVar *var);
+
 static Numeric duplicate_numeric(Numeric num);
 static Numeric make_result(const NumericVar *var);
 static Numeric make_result_opt_error(const NumericVar *var, bool *error);
@@ -4943,8 +4946,6 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
 {
    NumericAggState *state;
    StringInfoData buf;
-   Datum       temp;
-   bytea      *sumX;
    bytea      *result;
    NumericVar  tmp_var;
 
@@ -4954,19 +4955,7 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
 
    state = (NumericAggState *) PG_GETARG_POINTER(0);
 
-   /*
-    * This is a little wasteful since make_result converts the NumericVar
-    * into a Numeric and numeric_send converts it back again. Is it worth
-    * splitting the tasks in numeric_send into separate functions to stop
-    * this? Doing so would also remove the fmgr call overhead.
-    */
    init_var(&tmp_var);
-   accum_sum_final(&state->sumX, &tmp_var);
-
-   temp = DirectFunctionCall1(numeric_send,
-                              NumericGetDatum(make_result(&tmp_var)));
-   sumX = DatumGetByteaPP(temp);
-   free_var(&tmp_var);
 
    pq_begintypsend(&buf);
 
@@ -4974,7 +4963,8 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
    pq_sendint64(&buf, state->N);
 
    /* sumX */
-   pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
+   accum_sum_final(&state->sumX, &tmp_var);
+   numericvar_serialize(&buf, &tmp_var);
 
    /* maxScale */
    pq_sendint32(&buf, state->maxScale);
@@ -4993,6 +4983,8 @@ numeric_avg_serialize(PG_FUNCTION_ARGS)
 
    result = pq_endtypsend(&buf);
 
+   free_var(&tmp_var);
+
    PG_RETURN_BYTEA_P(result);
 }
 
@@ -5006,15 +4998,16 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS)
 {
    bytea      *sstate;
    NumericAggState *result;
-   Datum       temp;
-   NumericVar  tmp_var;
    StringInfoData buf;
+   NumericVar  tmp_var;
 
    if (!AggCheckCallContext(fcinfo, NULL))
        elog(ERROR, "aggregate function called in non-aggregate context");
 
    sstate = PG_GETARG_BYTEA_PP(0);
 
+   init_var(&tmp_var);
+
    /*
     * Copy the bytea into a StringInfo so that we can "receive" it using the
     * standard recv-function infrastructure.
@@ -5029,11 +5022,7 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS)
    result->N = pq_getmsgint64(&buf);
 
    /* sumX */
-   temp = DirectFunctionCall3(numeric_recv,
-                              PointerGetDatum(&buf),
-                              ObjectIdGetDatum(InvalidOid),
-                              Int32GetDatum(-1));
-   init_var_from_num(DatumGetNumeric(temp), &tmp_var);
+   numericvar_deserialize(&buf, &tmp_var);
    accum_sum_add(&(result->sumX), &tmp_var);
 
    /* maxScale */
@@ -5054,6 +5043,8 @@ numeric_avg_deserialize(PG_FUNCTION_ARGS)
    pq_getmsgend(&buf);
    pfree(buf.data);
 
+   free_var(&tmp_var);
+
    PG_RETURN_POINTER(result);
 }
 
@@ -5067,11 +5058,8 @@ numeric_serialize(PG_FUNCTION_ARGS)
 {
    NumericAggState *state;
    StringInfoData buf;
-   Datum       temp;
-   bytea      *sumX;
-   NumericVar  tmp_var;
-   bytea      *sumX2;
    bytea      *result;
+   NumericVar  tmp_var;
 
    /* Ensure we disallow calling when not in aggregate context */
    if (!AggCheckCallContext(fcinfo, NULL))
@@ -5079,36 +5067,20 @@ numeric_serialize(PG_FUNCTION_ARGS)
 
    state = (NumericAggState *) PG_GETARG_POINTER(0);
 
-   /*
-    * This is a little wasteful since make_result converts the NumericVar
-    * into a Numeric and numeric_send converts it back again. Is it worth
-    * splitting the tasks in numeric_send into separate functions to stop
-    * this? Doing so would also remove the fmgr call overhead.
-    */
    init_var(&tmp_var);
 
-   accum_sum_final(&state->sumX, &tmp_var);
-   temp = DirectFunctionCall1(numeric_send,
-                              NumericGetDatum(make_result(&tmp_var)));
-   sumX = DatumGetByteaPP(temp);
-
-   accum_sum_final(&state->sumX2, &tmp_var);
-   temp = DirectFunctionCall1(numeric_send,
-                              NumericGetDatum(make_result(&tmp_var)));
-   sumX2 = DatumGetByteaPP(temp);
-
-   free_var(&tmp_var);
-
    pq_begintypsend(&buf);
 
    /* N */
    pq_sendint64(&buf, state->N);
 
    /* sumX */
-   pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
+   accum_sum_final(&state->sumX, &tmp_var);
+   numericvar_serialize(&buf, &tmp_var);
 
    /* sumX2 */
-   pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
+   accum_sum_final(&state->sumX2, &tmp_var);
+   numericvar_serialize(&buf, &tmp_var);
 
    /* maxScale */
    pq_sendint32(&buf, state->maxScale);
@@ -5127,6 +5099,8 @@ numeric_serialize(PG_FUNCTION_ARGS)
 
    result = pq_endtypsend(&buf);
 
+   free_var(&tmp_var);
+
    PG_RETURN_BYTEA_P(result);
 }
 
@@ -5140,16 +5114,16 @@ numeric_deserialize(PG_FUNCTION_ARGS)
 {
    bytea      *sstate;
    NumericAggState *result;
-   Datum       temp;
-   NumericVar  sumX_var;
-   NumericVar  sumX2_var;
    StringInfoData buf;
+   NumericVar  tmp_var;
 
    if (!AggCheckCallContext(fcinfo, NULL))
        elog(ERROR, "aggregate function called in non-aggregate context");
 
    sstate = PG_GETARG_BYTEA_PP(0);
 
+   init_var(&tmp_var);
+
    /*
     * Copy the bytea into a StringInfo so that we can "receive" it using the
     * standard recv-function infrastructure.
@@ -5164,20 +5138,12 @@ numeric_deserialize(PG_FUNCTION_ARGS)
    result->N = pq_getmsgint64(&buf);
 
    /* sumX */
-   temp = DirectFunctionCall3(numeric_recv,
-                              PointerGetDatum(&buf),
-                              ObjectIdGetDatum(InvalidOid),
-                              Int32GetDatum(-1));
-   init_var_from_num(DatumGetNumeric(temp), &sumX_var);
-   accum_sum_add(&(result->sumX), &sumX_var);
+   numericvar_deserialize(&buf, &tmp_var);
+   accum_sum_add(&(result->sumX), &tmp_var);
 
    /* sumX2 */
-   temp = DirectFunctionCall3(numeric_recv,
-                              PointerGetDatum(&buf),
-                              ObjectIdGetDatum(InvalidOid),
-                              Int32GetDatum(-1));
-   init_var_from_num(DatumGetNumeric(temp), &sumX2_var);
-   accum_sum_add(&(result->sumX2), &sumX2_var);
+   numericvar_deserialize(&buf, &tmp_var);
+   accum_sum_add(&(result->sumX2), &tmp_var);
 
    /* maxScale */
    result->maxScale = pq_getmsgint(&buf, 4);
@@ -5197,6 +5163,8 @@ numeric_deserialize(PG_FUNCTION_ARGS)
    pq_getmsgend(&buf);
    pfree(buf.data);
 
+   free_var(&tmp_var);
+
    PG_RETURN_POINTER(result);
 }
 
@@ -5459,9 +5427,8 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
 {
    PolyNumAggState *state;
    StringInfoData buf;
-   bytea      *sumX;
-   bytea      *sumX2;
    bytea      *result;
+   NumericVar  tmp_var;
 
    /* Ensure we disallow calling when not in aggregate context */
    if (!AggCheckCallContext(fcinfo, NULL))
@@ -5477,32 +5444,8 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
     * day we might like to send these over to another server for further
     * processing and we want a standard format to work with.
     */
-   {
-       Datum       temp;
-       NumericVar  num;
-
-       init_var(&num);
-
-#ifdef HAVE_INT128
-       int128_to_numericvar(state->sumX, &num);
-#else
-       accum_sum_final(&state->sumX, &num);
-#endif
-       temp = DirectFunctionCall1(numeric_send,
-                                  NumericGetDatum(make_result(&num)));
-       sumX = DatumGetByteaPP(temp);
 
-#ifdef HAVE_INT128
-       int128_to_numericvar(state->sumX2, &num);
-#else
-       accum_sum_final(&state->sumX2, &num);
-#endif
-       temp = DirectFunctionCall1(numeric_send,
-                                  NumericGetDatum(make_result(&num)));
-       sumX2 = DatumGetByteaPP(temp);
-
-       free_var(&num);
-   }
+   init_var(&tmp_var);
 
    pq_begintypsend(&buf);
 
@@ -5510,13 +5453,25 @@ numeric_poly_serialize(PG_FUNCTION_ARGS)
    pq_sendint64(&buf, state->N);
 
    /* sumX */
-   pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
+#ifdef HAVE_INT128
+   int128_to_numericvar(state->sumX, &tmp_var);
+#else
+   accum_sum_final(&state->sumX, &tmp_var);
+#endif
+   numericvar_serialize(&buf, &tmp_var);
 
    /* sumX2 */
-   pq_sendbytes(&buf, VARDATA_ANY(sumX2), VARSIZE_ANY_EXHDR(sumX2));
+#ifdef HAVE_INT128
+   int128_to_numericvar(state->sumX2, &tmp_var);
+#else
+   accum_sum_final(&state->sumX2, &tmp_var);
+#endif
+   numericvar_serialize(&buf, &tmp_var);
 
    result = pq_endtypsend(&buf);
 
+   free_var(&tmp_var);
+
    PG_RETURN_BYTEA_P(result);
 }
 
@@ -5530,17 +5485,16 @@ numeric_poly_deserialize(PG_FUNCTION_ARGS)
 {
    bytea      *sstate;
    PolyNumAggState *result;
-   Datum       sumX;
-   NumericVar  sumX_var;
-   Datum       sumX2;
-   NumericVar  sumX2_var;
    StringInfoData buf;
+   NumericVar  tmp_var;
 
    if (!AggCheckCallContext(fcinfo, NULL))
        elog(ERROR, "aggregate function called in non-aggregate context");
 
    sstate = PG_GETARG_BYTEA_PP(0);
 
+   init_var(&tmp_var);
+
    /*
     * Copy the bytea into a StringInfo so that we can "receive" it using the
     * standard recv-function infrastructure.
@@ -5555,34 +5509,26 @@ numeric_poly_deserialize(PG_FUNCTION_ARGS)
    result->N = pq_getmsgint64(&buf);
 
    /* sumX */
-   sumX = DirectFunctionCall3(numeric_recv,
-                              PointerGetDatum(&buf),
-                              ObjectIdGetDatum(InvalidOid),
-                              Int32GetDatum(-1));
-
-   /* sumX2 */
-   sumX2 = DirectFunctionCall3(numeric_recv,
-                               PointerGetDatum(&buf),
-                               ObjectIdGetDatum(InvalidOid),
-                               Int32GetDatum(-1));
-
-   init_var_from_num(DatumGetNumeric(sumX), &sumX_var);
+   numericvar_deserialize(&buf, &tmp_var);
 #ifdef HAVE_INT128
-   numericvar_to_int128(&sumX_var, &result->sumX);
+   numericvar_to_int128(&tmp_var, &result->sumX);
 #else
-   accum_sum_add(&result->sumX, &sumX_var);
+   accum_sum_add(&result->sumX, &tmp_var);
 #endif
 
-   init_var_from_num(DatumGetNumeric(sumX2), &sumX2_var);
+   /* sumX2 */
+   numericvar_deserialize(&buf, &tmp_var);
 #ifdef HAVE_INT128
-   numericvar_to_int128(&sumX2_var, &result->sumX2);
+   numericvar_to_int128(&tmp_var, &result->sumX2);
 #else
-   accum_sum_add(&result->sumX2, &sumX2_var);
+   accum_sum_add(&result->sumX2, &tmp_var);
 #endif
 
    pq_getmsgend(&buf);
    pfree(buf.data);
 
+   free_var(&tmp_var);
+
    PG_RETURN_POINTER(result);
 }
 
@@ -5681,8 +5627,8 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
 {
    PolyNumAggState *state;
    StringInfoData buf;
-   bytea      *sumX;
    bytea      *result;
+   NumericVar  tmp_var;
 
    /* Ensure we disallow calling when not in aggregate context */
    if (!AggCheckCallContext(fcinfo, NULL))
@@ -5698,23 +5644,8 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
     * like to send these over to another server for further processing and we
     * want a standard format to work with.
     */
-   {
-       Datum       temp;
-       NumericVar  num;
-
-       init_var(&num);
 
-#ifdef HAVE_INT128
-       int128_to_numericvar(state->sumX, &num);
-#else
-       accum_sum_final(&state->sumX, &num);
-#endif
-       temp = DirectFunctionCall1(numeric_send,
-                                  NumericGetDatum(make_result(&num)));
-       sumX = DatumGetByteaPP(temp);
-
-       free_var(&num);
-   }
+   init_var(&tmp_var);
 
    pq_begintypsend(&buf);
 
@@ -5722,10 +5653,17 @@ int8_avg_serialize(PG_FUNCTION_ARGS)
    pq_sendint64(&buf, state->N);
 
    /* sumX */
-   pq_sendbytes(&buf, VARDATA_ANY(sumX), VARSIZE_ANY_EXHDR(sumX));
+#ifdef HAVE_INT128
+   int128_to_numericvar(state->sumX, &tmp_var);
+#else
+   accum_sum_final(&state->sumX, &tmp_var);
+#endif
+   numericvar_serialize(&buf, &tmp_var);
 
    result = pq_endtypsend(&buf);
 
+   free_var(&tmp_var);
+
    PG_RETURN_BYTEA_P(result);
 }
 
@@ -5739,14 +5677,15 @@ int8_avg_deserialize(PG_FUNCTION_ARGS)
    bytea      *sstate;
    PolyNumAggState *result;
    StringInfoData buf;
-   Datum       temp;
-   NumericVar  num;
+   NumericVar  tmp_var;
 
    if (!AggCheckCallContext(fcinfo, NULL))
        elog(ERROR, "aggregate function called in non-aggregate context");
 
    sstate = PG_GETARG_BYTEA_PP(0);
 
+   init_var(&tmp_var);
+
    /*
     * Copy the bytea into a StringInfo so that we can "receive" it using the
     * standard recv-function infrastructure.
@@ -5761,20 +5700,18 @@ int8_avg_deserialize(PG_FUNCTION_ARGS)
    result->N = pq_getmsgint64(&buf);
 
    /* sumX */
-   temp = DirectFunctionCall3(numeric_recv,
-                              PointerGetDatum(&buf),
-                              ObjectIdGetDatum(InvalidOid),
-                              Int32GetDatum(-1));
-   init_var_from_num(DatumGetNumeric(temp), &num);
+   numericvar_deserialize(&buf, &tmp_var);
 #ifdef HAVE_INT128
-   numericvar_to_int128(&num, &result->sumX);
+   numericvar_to_int128(&tmp_var, &result->sumX);
 #else
-   accum_sum_add(&result->sumX, &num);
+   accum_sum_add(&result->sumX, &tmp_var);
 #endif
 
    pq_getmsgend(&buf);
    pfree(buf.data);
 
+   free_var(&tmp_var);
+
    PG_RETURN_POINTER(result);
 }
 
@@ -7286,6 +7223,48 @@ get_str_from_var_sci(const NumericVar *var, int rscale)
 }
 
 
+/*
+ * numericvar_serialize - serialize NumericVar to binary format
+ *
+ * At variable level, no checks are performed on the weight or dscale, allowing
+ * us to pass around intermediate values with higher precision than supported
+ * by the numeric type.  Note: this is incompatible with numeric_send/recv(),
+ * which use 16-bit integers for these fields.
+ */
+static void
+numericvar_serialize(StringInfo buf, const NumericVar *var)
+{
+   int         i;
+
+   pq_sendint32(buf, var->ndigits);
+   pq_sendint32(buf, var->weight);
+   pq_sendint32(buf, var->sign);
+   pq_sendint32(buf, var->dscale);
+   for (i = 0; i < var->ndigits; i++)
+       pq_sendint16(buf, var->digits[i]);
+}
+
+/*
+ * numericvar_deserialize - deserialize binary format to NumericVar
+ */
+static void
+numericvar_deserialize(StringInfo buf, NumericVar *var)
+{
+   int         len,
+               i;
+
+   len = pq_getmsgint(buf, sizeof(int32));
+
+   alloc_var(var, len);        /* sets var->ndigits */
+
+   var->weight = pq_getmsgint(buf, sizeof(int32));
+   var->sign = pq_getmsgint(buf, sizeof(int32));
+   var->dscale = pq_getmsgint(buf, sizeof(int32));
+   for (i = 0; i < len; i++)
+       var->digits[i] = pq_getmsgint(buf, sizeof(int16));
+}
+
+
 /*
  * duplicate_numeric() - copy a packed-format Numeric
  *
index 30a5642b95849ef06e3ce9982965698697303322..4ad485130bda0965b17d29f656af069665edd378 100644 (file)
@@ -2966,6 +2966,56 @@ SELECT SUM((-9999)::numeric) FROM generate_series(1, 100000);
  -999900000
 (1 row)
 
+--
+-- Tests for VARIANCE()
+--
+CREATE TABLE num_variance (a numeric);
+INSERT INTO num_variance VALUES (0);
+INSERT INTO num_variance VALUES (3e-500);
+INSERT INTO num_variance VALUES (-3e-500);
+INSERT INTO num_variance VALUES (4e-500 - 1e-16383);
+INSERT INTO num_variance VALUES (-4e-500 + 1e-16383);
+-- variance is just under 12.5e-1000 and so should round down to 12e-1000
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+ trim_scale 
+------------
+         12
+(1 row)
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+ trim_scale 
+------------
+         12
+(1 row)
+
+ROLLBACK;
+-- case where sum of squares would overflow but variance does not
+DELETE FROM num_variance;
+INSERT INTO num_variance SELECT 9e131071 + x FROM generate_series(1, 5) x;
+SELECT variance(a) FROM num_variance;
+      variance      
+--------------------
+ 2.5000000000000000
+(1 row)
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT variance(a) FROM num_variance;
+      variance      
+--------------------
+ 2.5000000000000000
+(1 row)
+
+ROLLBACK;
+DROP TABLE num_variance;
 --
 -- Tests for GCD()
 --
index db812c813a39b505682de2bf1aa31c8a58f47854..3784c5253d7ca30a6b2b8dc01965dbc81f84fe9d 100644 (file)
@@ -1277,6 +1277,42 @@ select trim_scale(1e100);
 SELECT SUM(9999::numeric) FROM generate_series(1, 100000);
 SELECT SUM((-9999)::numeric) FROM generate_series(1, 100000);
 
+--
+-- Tests for VARIANCE()
+--
+CREATE TABLE num_variance (a numeric);
+INSERT INTO num_variance VALUES (0);
+INSERT INTO num_variance VALUES (3e-500);
+INSERT INTO num_variance VALUES (-3e-500);
+INSERT INTO num_variance VALUES (4e-500 - 1e-16383);
+INSERT INTO num_variance VALUES (-4e-500 + 1e-16383);
+
+-- variance is just under 12.5e-1000 and so should round down to 12e-1000
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT trim_scale(variance(a) * 1e1000) FROM num_variance;
+ROLLBACK;
+
+-- case where sum of squares would overflow but variance does not
+DELETE FROM num_variance;
+INSERT INTO num_variance SELECT 9e131071 + x FROM generate_series(1, 5) x;
+SELECT variance(a) FROM num_variance;
+
+-- check that parallel execution produces the same result
+BEGIN;
+ALTER TABLE num_variance SET (parallel_workers = 4);
+SET LOCAL parallel_setup_cost = 0;
+SET LOCAL max_parallel_workers_per_gather = 4;
+SELECT variance(a) FROM num_variance;
+ROLLBACK;
+
+DROP TABLE num_variance;
+
 --
 -- Tests for GCD()
 --