Allow length=-1 for NUL-terminated input to pg_strncoll(), etc.
authorJeff Davis <[email protected]>
Tue, 24 Sep 2024 22:15:03 +0000 (15:15 -0700)
committerJeff Davis <[email protected]>
Tue, 24 Sep 2024 22:15:18 +0000 (15:15 -0700)
Like ICU, allow a length of -1 to be specified for NUL-terminated
arguments to pg_strncoll(), pg_strnxfrm(), and pg_strnxfrm_prefix().

Simplifies the code and comments.

Discussion: https://postgr.es/m/2d758e07dff26bcc7cbe2aec57431329bfe3679a[email protected]

src/backend/utils/adt/pg_locale.c
src/include/utils/pg_locale.h

index 8a7dde21398e78a0f20b72e820295f8ef892256c..f2a28d5ef5a7f5f484891f7ebfbfaa60911be4f0 100644 (file)
@@ -1803,30 +1803,37 @@ get_collation_actual_version(char collprovider, const char *collcollate)
 }
 
 /*
- * pg_strncoll_libc_win32_utf8
+ * strncoll_libc_win32_utf8
  *
  * Win32 does not have UTF-8. Convert UTF8 arguments to wide characters and
  * invoke wcscoll_l().
+ *
+ * An input string length of -1 means that it's NUL-terminated.
  */
 #ifdef WIN32
 static int
-pg_strncoll_libc_win32_utf8(const char *arg1, size_t len1, const char *arg2,
-                           size_t len2, pg_locale_t locale)
+strncoll_libc_win32_utf8(const char *arg1, ssize_t len1, const char *arg2,
+                        ssize_t len2, pg_locale_t locale)
 {
    char        sbuf[TEXTBUFLEN];
    char       *buf = sbuf;
    char       *a1p,
               *a2p;
-   int         a1len = len1 * 2 + 2;
-   int         a2len = len2 * 2 + 2;
+   int         a1len;
+   int         a2len;
    int         r;
    int         result;
 
    Assert(locale->provider == COLLPROVIDER_LIBC);
    Assert(GetDatabaseEncoding() == PG_UTF8);
-#ifndef WIN32
-   Assert(false);
-#endif
+
+   if (len1 == -1)
+       len1 = strlen(arg1);
+   if (len2 == -1)
+       len2 = strlen(arg2);
+
+   a1len = len1 * 2 + 2;
+   a2len = len2 * 2 + 2;
 
    if (a1len + a2len > TEXTBUFLEN)
        buf = palloc(a1len + a2len);
@@ -1875,50 +1882,22 @@ pg_strncoll_libc_win32_utf8(const char *arg1, size_t len1, const char *arg2,
 #endif                         /* WIN32 */
 
 /*
- * pg_strcoll_libc
+ * strncoll_libc
  *
- * Call strcoll_l() or wcscoll_l() as appropriate for the given locale,
- * platform, and database encoding. If the locale is NULL, use the database
- * collation.
+ * NUL-terminate arguments, if necessary, and pass to strcoll_l().
  *
- * Arguments must be encoded in the database encoding and nul-terminated.
+ * An input string length of -1 means that it's already NUL-terminated.
  */
 static int
-pg_strcoll_libc(const char *arg1, const char *arg2, pg_locale_t locale)
-{
-   int         result;
-
-   Assert(locale->provider == COLLPROVIDER_LIBC);
-#ifdef WIN32
-   if (GetDatabaseEncoding() == PG_UTF8)
-   {
-       size_t      len1 = strlen(arg1);
-       size_t      len2 = strlen(arg2);
-
-       result = pg_strncoll_libc_win32_utf8(arg1, len1, arg2, len2, locale);
-   }
-   else
-#endif                         /* WIN32 */
-       result = strcoll_l(arg1, arg2, locale->info.lt);
-
-   return result;
-}
-
-/*
- * pg_strncoll_libc
- *
- * Nul-terminate the arguments and call pg_strcoll_libc().
- */
-static int
-pg_strncoll_libc(const char *arg1, size_t len1, const char *arg2, size_t len2,
-                pg_locale_t locale)
+strncoll_libc(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
+             pg_locale_t locale)
 {
    char        sbuf[TEXTBUFLEN];
    char       *buf = sbuf;
-   size_t      bufsize1 = len1 + 1;
-   size_t      bufsize2 = len2 + 1;
-   char       *arg1n;
-   char       *arg2n;
+   size_t      bufsize1 = (len1 == -1) ? 0 : len1 + 1;
+   size_t      bufsize2 = (len2 == -1) ? 0 : len2 + 1;
+   const char *arg1n;
+   const char *arg2n;
    int         result;
 
    Assert(locale->provider == COLLPROVIDER_LIBC);
@@ -1926,22 +1905,40 @@ pg_strncoll_libc(const char *arg1, size_t len1, const char *arg2, size_t len2,
 #ifdef WIN32
    /* check for this case before doing the work for nul-termination */
    if (GetDatabaseEncoding() == PG_UTF8)
-       return pg_strncoll_libc_win32_utf8(arg1, len1, arg2, len2, locale);
+       return strncoll_libc_win32_utf8(arg1, len1, arg2, len2, locale);
 #endif                         /* WIN32 */
 
    if (bufsize1 + bufsize2 > TEXTBUFLEN)
        buf = palloc(bufsize1 + bufsize2);
 
-   arg1n = buf;
-   arg2n = buf + bufsize1;
+   /* nul-terminate arguments if necessary */
+   if (len1 == -1)
+   {
+       arg1n = arg1;
+   }
+   else
+   {
+       char       *buf1 = buf;
+
+       memcpy(buf1, arg1, len1);
+       buf1[len1] = '\0';
+       arg1n = buf1;
+   }
+
+   if (len2 == -1)
+   {
+       arg2n = arg2;
+   }
+   else
+   {
+       char       *buf2 = buf + bufsize1;
 
-   /* nul-terminate arguments */
-   memcpy(arg1n, arg1, len1);
-   arg1n[len1] = '\0';
-   memcpy(arg2n, arg2, len2);
-   arg2n[len2] = '\0';
+       memcpy(buf2, arg2, len2);
+       buf2[len2] = '\0';
+       arg2n = buf2;
+   }
 
-   result = pg_strcoll_libc(arg1n, arg2n, locale);
+   result = strcoll_l(arg1n, arg2n, locale->info.lt);
 
    if (buf != sbuf)
        pfree(buf);
@@ -1952,7 +1949,7 @@ pg_strncoll_libc(const char *arg1, size_t len1, const char *arg2, size_t len2,
 #ifdef USE_ICU
 
 /*
- * pg_strncoll_icu_no_utf8
+ * strncoll_icu_no_utf8
  *
  * Convert the arguments from the database encoding to UChar strings, then
  * call ucol_strcoll(). An argument length of -1 means that the string is
@@ -1962,8 +1959,8 @@ pg_strncoll_libc(const char *arg1, size_t len1, const char *arg2, size_t len2,
  * caller should call that instead.
  */
 static int
-pg_strncoll_icu_no_utf8(const char *arg1, int32_t len1,
-                       const char *arg2, int32_t len2, pg_locale_t locale)
+strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
+                    const char *arg2, ssize_t len2, pg_locale_t locale)
 {
    char        sbuf[TEXTBUFLEN];
    char       *buf = sbuf;
@@ -2008,17 +2005,15 @@ pg_strncoll_icu_no_utf8(const char *arg1, int32_t len1,
 }
 
 /*
- * pg_strncoll_icu
+ * strncoll_icu
  *
  * Call ucol_strcollUTF8() or ucol_strcoll() as appropriate for the given
  * database encoding. An argument length of -1 means the string is
  * NUL-terminated.
- *
- * Arguments must be encoded in the database encoding.
  */
 static int
-pg_strncoll_icu(const char *arg1, int32_t len1, const char *arg2, int32_t len2,
-               pg_locale_t locale)
+strncoll_icu(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
+            pg_locale_t locale)
 {
    int         result;
 
@@ -2041,7 +2036,7 @@ pg_strncoll_icu(const char *arg1, int32_t len1, const char *arg2, int32_t len2,
    else
 #endif
    {
-       result = pg_strncoll_icu_no_utf8(arg1, len1, arg2, len2, locale);
+       result = strncoll_icu_no_utf8(arg1, len1, arg2, len2, locale);
    }
 
    return result;
@@ -2052,15 +2047,7 @@ pg_strncoll_icu(const char *arg1, int32_t len1, const char *arg2, int32_t len2,
 /*
  * pg_strcoll
  *
- * Call ucol_strcollUTF8(), ucol_strcoll(), strcoll_l() or wcscoll_l() as
- * appropriate for the given locale, platform, and database encoding. If the
- * locale is not specified, use the database collation.
- *
- * Arguments must be encoded in the database encoding and nul-terminated.
- *
- * The caller is responsible for breaking ties if the collation is
- * deterministic; this maintains consistency with pg_strxfrm(), which cannot
- * easily account for deterministic collations.
+ * Like pg_strncoll for NUL-terminated input strings.
  */
 int
 pg_strcoll(const char *arg1, const char *arg2, pg_locale_t locale)
@@ -2068,10 +2055,10 @@ pg_strcoll(const char *arg1, const char *arg2, pg_locale_t locale)
    int         result;
 
    if (locale->provider == COLLPROVIDER_LIBC)
-       result = pg_strcoll_libc(arg1, arg2, locale);
+       result = strncoll_libc(arg1, -1, arg2, -1, locale);
 #ifdef USE_ICU
    else if (locale->provider == COLLPROVIDER_ICU)
-       result = pg_strncoll_icu(arg1, -1, arg2, -1, locale);
+       result = strncoll_icu(arg1, -1, arg2, -1, locale);
 #endif
    else
        /* shouldn't happen */
@@ -2087,27 +2074,24 @@ pg_strcoll(const char *arg1, const char *arg2, pg_locale_t locale)
  * appropriate for the given locale, platform, and database encoding. If the
  * locale is not specified, use the database collation.
  *
- * Arguments must be encoded in the database encoding.
- *
- * This function may need to nul-terminate the arguments for libc functions;
- * so if the caller already has nul-terminated strings, it should call
- * pg_strcoll() instead.
+ * The input strings must be encoded in the database encoding. If an input
+ * string is NUL-terminated, its length may be specified as -1.
  *
  * The caller is responsible for breaking ties if the collation is
  * deterministic; this maintains consistency with pg_strnxfrm(), which cannot
  * easily account for deterministic collations.
  */
 int
-pg_strncoll(const char *arg1, size_t len1, const char *arg2, size_t len2,
+pg_strncoll(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
            pg_locale_t locale)
 {
    int         result;
 
    if (locale->provider == COLLPROVIDER_LIBC)
-       result = pg_strncoll_libc(arg1, len1, arg2, len2, locale);
+       result = strncoll_libc(arg1, len1, arg2, len2, locale);
 #ifdef USE_ICU
    else if (locale->provider == COLLPROVIDER_ICU)
-       result = pg_strncoll_icu(arg1, len1, arg2, len2, locale);
+       result = strncoll_icu(arg1, len1, arg2, len2, locale);
 #endif
    else
        /* shouldn't happen */
@@ -2116,18 +2100,16 @@ pg_strncoll(const char *arg1, size_t len1, const char *arg2, size_t len2,
    return result;
 }
 
-
-static size_t
-pg_strxfrm_libc(char *dest, const char *src, size_t destsize,
-               pg_locale_t locale)
-{
-   Assert(locale->provider == COLLPROVIDER_LIBC);
-   return strxfrm_l(dest, src, destsize, locale->info.lt);
-}
-
+/*
+ * strnxfrm_libc
+ *
+ * NUL-terminate src, if necessary, and pass to strxfrm_l().
+ *
+ * A source length of -1 means that it's already NUL-terminated.
+ */
 static size_t
-pg_strnxfrm_libc(char *dest, const char *src, size_t srclen, size_t destsize,
-                pg_locale_t locale)
+strnxfrm_libc(char *dest, size_t destsize, const char *src, ssize_t srclen,
+             pg_locale_t locale)
 {
    char        sbuf[TEXTBUFLEN];
    char       *buf = sbuf;
@@ -2136,14 +2118,17 @@ pg_strnxfrm_libc(char *dest, const char *src, size_t srclen, size_t destsize,
 
    Assert(locale->provider == COLLPROVIDER_LIBC);
 
+   if (srclen == -1)
+       return strxfrm_l(dest, src, destsize, locale->info.lt);
+
    if (bufsize > TEXTBUFLEN)
        buf = palloc(bufsize);
 
-   /* nul-terminate arguments */
+   /* nul-terminate argument */
    memcpy(buf, src, srclen);
    buf[srclen] = '\0';
 
-   result = pg_strxfrm_libc(dest, buf, destsize, locale);
+   result = strxfrm_l(dest, buf, destsize, locale->info.lt);
 
    if (buf != sbuf)
        pfree(buf);
@@ -2158,8 +2143,8 @@ pg_strnxfrm_libc(char *dest, const char *src, size_t srclen, size_t destsize,
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 static size_t
-pg_strnxfrm_icu(char *dest, const char *src, int32_t srclen, int32_t destsize,
-               pg_locale_t locale)
+strnxfrm_icu(char *dest, size_t destsize, const char *src, ssize_t srclen,
+            pg_locale_t locale)
 {
    char        sbuf[TEXTBUFLEN];
    char       *buf = sbuf;
@@ -2205,8 +2190,9 @@ pg_strnxfrm_icu(char *dest, const char *src, int32_t srclen, int32_t destsize,
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 static size_t
-pg_strnxfrm_prefix_icu_no_utf8(char *dest, const char *src, int32_t srclen,
-                              int32_t destsize, pg_locale_t locale)
+strnxfrm_prefix_icu_no_utf8(char *dest, size_t destsize,
+                           const char *src, ssize_t srclen,
+                           pg_locale_t locale)
 {
    char        sbuf[TEXTBUFLEN];
    char       *buf = sbuf;
@@ -2253,8 +2239,9 @@ pg_strnxfrm_prefix_icu_no_utf8(char *dest, const char *src, int32_t srclen,
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 static size_t
-pg_strnxfrm_prefix_icu(char *dest, const char *src, int32_t srclen,
-                      int32_t destsize, pg_locale_t locale)
+strnxfrm_prefix_icu(char *dest, size_t destsize,
+                   const char *src, ssize_t srclen,
+                   pg_locale_t locale)
 {
    size_t      result;
 
@@ -2281,8 +2268,8 @@ pg_strnxfrm_prefix_icu(char *dest, const char *src, int32_t srclen,
                            u_errorName(status))));
    }
    else
-       result = pg_strnxfrm_prefix_icu_no_utf8(dest, src, srclen, destsize,
-                                               locale);
+       result = strnxfrm_prefix_icu_no_utf8(dest, destsize, src, srclen,
+                                            locale);
 
    return result;
 }
@@ -2324,20 +2311,7 @@ pg_strxfrm_enabled(pg_locale_t locale)
 /*
  * pg_strxfrm
  *
- * Transforms 'src' to a nul-terminated string stored in 'dest' such that
- * ordinary strcmp() on transformed strings is equivalent to pg_strcoll() on
- * untransformed strings.
- *
- * The provided 'src' must be nul-terminated. If 'destsize' is zero, 'dest'
- * may be NULL.
- *
- * Not all providers support pg_strxfrm() safely. The caller should check
- * pg_strxfrm_enabled() first, otherwise this function may return wrong
- * results or an error.
- *
- * Returns the number of bytes needed (or more) to store the transformed
- * string, excluding the terminating nul byte. If the value returned is
- * 'destsize' or greater, the resulting contents of 'dest' are undefined.
+ * Like pg_strnxfrm for a NUL-terminated input string.
  */
 size_t
 pg_strxfrm(char *dest, const char *src, size_t destsize, pg_locale_t locale)
@@ -2345,10 +2319,10 @@ pg_strxfrm(char *dest, const char *src, size_t destsize, pg_locale_t locale)
    size_t      result = 0;     /* keep compiler quiet */
 
    if (locale->provider == COLLPROVIDER_LIBC)
-       result = pg_strxfrm_libc(dest, src, destsize, locale);
+       result = strnxfrm_libc(dest, destsize, src, -1, locale);
 #ifdef USE_ICU
    else if (locale->provider == COLLPROVIDER_ICU)
-       result = pg_strnxfrm_icu(dest, src, -1, destsize, locale);
+       result = strnxfrm_icu(dest, destsize, src, -1, locale);
 #endif
    else
        /* shouldn't happen */
@@ -2364,8 +2338,9 @@ pg_strxfrm(char *dest, const char *src, size_t destsize, pg_locale_t locale)
  * ordinary strcmp() on transformed strings is equivalent to pg_strcoll() on
  * untransformed strings.
  *
- * 'src' does not need to be nul-terminated. If 'destsize' is zero, 'dest' may
- * be NULL.
+ * The input string must be encoded in the database encoding. If the input
+ * string is NUL-terminated, its length may be specified as -1. If 'destsize'
+ * is zero, 'dest' may be NULL.
  *
  * Not all providers support pg_strnxfrm() safely. The caller should check
  * pg_strxfrm_enabled() first, otherwise this function may return wrong
@@ -2374,22 +2349,18 @@ pg_strxfrm(char *dest, const char *src, size_t destsize, pg_locale_t locale)
  * Returns the number of bytes needed (or more) to store the transformed
  * string, excluding the terminating nul byte. If the value returned is
  * 'destsize' or greater, the resulting contents of 'dest' are undefined.
- *
- * This function may need to nul-terminate the argument for libc functions;
- * so if the caller already has a nul-terminated string, it should call
- * pg_strxfrm() instead.
  */
 size_t
-pg_strnxfrm(char *dest, size_t destsize, const char *src, size_t srclen,
+pg_strnxfrm(char *dest, size_t destsize, const char *src, ssize_t srclen,
            pg_locale_t locale)
 {
    size_t      result = 0;     /* keep compiler quiet */
 
    if (locale->provider == COLLPROVIDER_LIBC)
-       result = pg_strnxfrm_libc(dest, src, srclen, destsize, locale);
+       result = strnxfrm_libc(dest, destsize, src, srclen, locale);
 #ifdef USE_ICU
    else if (locale->provider == COLLPROVIDER_ICU)
-       result = pg_strnxfrm_icu(dest, src, srclen, destsize, locale);
+       result = strnxfrm_icu(dest, destsize, src, srclen, locale);
 #endif
    else
        /* shouldn't happen */
@@ -2419,44 +2390,24 @@ pg_strxfrm_prefix_enabled(pg_locale_t locale)
 /*
  * pg_strxfrm_prefix
  *
- * Transforms 'src' to a byte sequence stored in 'dest' such that ordinary
- * memcmp() on the byte sequence is equivalent to pg_strcoll() on
- * untransformed strings. The result is not nul-terminated.
- *
- * The provided 'src' must be nul-terminated.
- *
- * Not all providers support pg_strxfrm_prefix() safely. The caller should
- * check pg_strxfrm_prefix_enabled() first, otherwise this function may return
- * wrong results or an error.
- *
- * If destsize is not large enough to hold the resulting byte sequence, stores
- * only the first destsize bytes in 'dest'. Returns the number of bytes
- * actually copied to 'dest'.
+ * Like pg_strnxfrm_prefix for a NUL-terminated input string.
  */
 size_t
 pg_strxfrm_prefix(char *dest, const char *src, size_t destsize,
                  pg_locale_t locale)
 {
-   size_t      result = 0;     /* keep compiler quiet */
-
-#ifdef USE_ICU
-   if (locale->provider == COLLPROVIDER_ICU)
-       result = pg_strnxfrm_prefix_icu(dest, src, -1, destsize, locale);
-   else
-#endif
-       PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-   return result;
+   return pg_strnxfrm_prefix(dest, destsize, src, -1, locale);
 }
 
 /*
  * pg_strnxfrm_prefix
  *
  * Transforms 'src' to a byte sequence stored in 'dest' such that ordinary
- * memcmp() on the byte sequence is equivalent to pg_strcoll() on
+ * memcmp() on the byte sequence is equivalent to pg_strncoll() on
  * untransformed strings. The result is not nul-terminated.
  *
- * The provided 'src' must be nul-terminated.
+ * The input string must be encoded in the database encoding. If the input
+ * string is NUL-terminated, its length may be specified as -1.
  *
  * Not all providers support pg_strnxfrm_prefix() safely. The caller should
  * check pg_strxfrm_prefix_enabled() first, otherwise this function may return
@@ -2465,20 +2416,16 @@ pg_strxfrm_prefix(char *dest, const char *src, size_t destsize,
  * If destsize is not large enough to hold the resulting byte sequence, stores
  * only the first destsize bytes in 'dest'. Returns the number of bytes
  * actually copied to 'dest'.
- *
- * This function may need to nul-terminate the argument for libc functions;
- * so if the caller already has a nul-terminated string, it should call
- * pg_strxfrm_prefix() instead.
  */
 size_t
 pg_strnxfrm_prefix(char *dest, size_t destsize, const char *src,
-                  size_t srclen, pg_locale_t locale)
+                  ssize_t srclen, pg_locale_t locale)
 {
    size_t      result = 0;     /* keep compiler quiet */
 
 #ifdef USE_ICU
    if (locale->provider == COLLPROVIDER_ICU)
-       result = pg_strnxfrm_prefix_icu(dest, src, -1, destsize, locale);
+       result = strnxfrm_prefix_icu(dest, destsize, src, -1, locale);
    else
 #endif
        PGLOCALE_SUPPORT_ERROR(locale->provider);
@@ -2661,6 +2608,8 @@ init_icu_converter(void)
 
 /*
  * Find length, in UChars, of given string if converted to UChar string.
+ *
+ * A length of -1 indicates that the input string is NUL-terminated.
  */
 static size_t
 uchar_length(UConverter *converter, const char *str, int32_t len)
@@ -2678,6 +2627,8 @@ uchar_length(UConverter *converter, const char *str, int32_t len)
 /*
  * Convert the given source string into a UChar string, stored in dest, and
  * return the length (in UChars).
+ *
+ * A srclen of -1 indicates that the input string is NUL-terminated.
  */
 static int32_t
 uchar_convert(UConverter *converter, UChar *dest, int32_t destlen,
index c2d95411e0a53eaf75852ec939ac8df61cc2f159..3b443df801434ddec79069d35449e1e553e55e60 100644 (file)
@@ -109,18 +109,18 @@ extern pg_locale_t pg_newlocale_from_collation(Oid collid);
 
 extern char *get_collation_actual_version(char collprovider, const char *collcollate);
 extern int pg_strcoll(const char *arg1, const char *arg2, pg_locale_t locale);
-extern int pg_strncoll(const char *arg1, size_t len1,
-                       const char *arg2, size_t len2, pg_locale_t locale);
+extern int pg_strncoll(const char *arg1, ssize_t len1,
+                       const char *arg2, ssize_t len2, pg_locale_t locale);
 extern bool pg_strxfrm_enabled(pg_locale_t locale);
 extern size_t pg_strxfrm(char *dest, const char *src, size_t destsize,
                         pg_locale_t locale);
 extern size_t pg_strnxfrm(char *dest, size_t destsize, const char *src,
-                         size_t srclen, pg_locale_t locale);
+                         ssize_t srclen, pg_locale_t locale);
 extern bool pg_strxfrm_prefix_enabled(pg_locale_t locale);
 extern size_t pg_strxfrm_prefix(char *dest, const char *src, size_t destsize,
                                pg_locale_t locale);
 extern size_t pg_strnxfrm_prefix(char *dest, size_t destsize, const char *src,
-                                size_t srclen, pg_locale_t locale);
+                                ssize_t srclen, pg_locale_t locale);
 
 extern int builtin_locale_encoding(const char *locale);
 extern const char *builtin_validate_locale(int encoding, const char *locale);