Check for interrupts and stack overflow in TParserGet().
authorTom Lane <[email protected]>
Sat, 24 Jun 2023 21:18:08 +0000 (17:18 -0400)
committerTom Lane <[email protected]>
Sat, 24 Jun 2023 21:18:08 +0000 (17:18 -0400)
TParserGet() recurses for some token types, meaning it's possible
to drive it to stack overflow.  Since this is a minority behavior,
I chose to add the check_stack_depth() call to the two places that
recurse rather than doing it during every single call.

While at it, add CHECK_FOR_INTERRUPTS(), because this can run
unpleasantly long for long inputs.

Per bug #17995 from Zuming Jiang.  This is old, so back-patch
to all supported branches.

Discussion: https://postgr.es/m/17995-9f20ff3e6389db4c@postgresql.org

src/backend/tsearch/wparser_def.c

index 23e4e9d98a9efcae9ec9ade10208e53993910c6a..fb80fdd63f23706aebf5500ad876e8c73bbe5bf6 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "catalog/pg_collation.h"
 #include "commands/defrem.h"
+#include "miscadmin.h"
 #include "tsearch/ts_locale.h"
 #include "tsearch/ts_public.h"
 #include "tsearch/ts_type.h"
@@ -631,6 +632,12 @@ p_ishost(TParser *prs)
 
    tmpprs->wanthost = true;
 
+   /*
+    * Check stack depth before recursing.  (Since TParserGet() doesn't
+    * normally recurse, we put the cost of checking here not there.)
+    */
+   check_stack_depth();
+
    if (TParserGet(tmpprs) && tmpprs->type == HOST)
    {
        prs->state->posbyte += tmpprs->lenbytetoken;
@@ -654,6 +661,12 @@ p_isURLPath(TParser *prs)
    tmpprs->state = newTParserPosition(tmpprs->state);
    tmpprs->state->state = TPS_InURLPathFirst;
 
+   /*
+    * Check stack depth before recursing.  (Since TParserGet() doesn't
+    * normally recurse, we put the cost of checking here not there.)
+    */
+   check_stack_depth();
+
    if (TParserGet(tmpprs) && tmpprs->type == URLPATH)
    {
        prs->state->posbyte += tmpprs->lenbytetoken;
@@ -1697,6 +1710,8 @@ TParserGet(TParser *prs)
 {
    const TParserStateActionItem *item = NULL;
 
+   CHECK_FOR_INTERRUPTS();
+
    Assert(prs->state);
 
    if (prs->state->posbyte >= prs->lenstr)