Fix incorrect logic for hashed IN / NOT IN with non-strict operators

ExecEvalHashedScalarArrayOp(), when using a strict equality function,
performs a short-circuit when looking up NULL values.  When the function
is non-strict, the code incorrectly looked up the hash table for a
zero-valued Datum, which could have resulted in an accidental true
return if the hash table contained zero valued Datum, or could result
in a crash for non-byval types.

Here we fix this by adding an extra step when we build the hash table to
check what the result of a NULL lookup would be.  This requires looping
over the array and checking what the non-hashed version of the code
would do.  We cache the results of that in the expression so that we can
reuse the result any time we're asked to search for a NULL value.

It's important to note that non-strict equality functions are free to
treat any NULL value as equal to any non-NULL value.  For example,
someone may wish to design a type that treats an empty string and NULL
as equal.

All built-in types have strict equality functions, so this could affect
custom / user-defined types.

Author: Chengpeng Yan <chengpeng_yan@outlook.com>
Author: David Rowley <dgrowleyml@gmail.com>
Reviewed-by: ChangAo Chen <cca5507@qq.com>
Discussion: https://postgr.es/m/A16187AE-2359-4265-9F5E-71D015EC2B2D@outlook.com
Backpatch-through: 14
This commit is contained in:
David Rowley 2026-04-24 14:04:31 +12:00
parent 07e96aeff9
commit a2a0060d5d
4 changed files with 365 additions and 63 deletions

View file

@ -170,6 +170,14 @@ static Datum ExecJustAssignOuterVarVirt(ExprState *state, ExprContext *econtext,
static Datum ExecJustAssignScanVarVirt(ExprState *state, ExprContext *econtext, bool *isnull);
/* execution helper functions */
static pg_attribute_always_inline void ExecEvalArrayCompareInternal(FunctionCallInfo fcinfo,
ArrayType *arr,
int16 typlen,
bool typbyval,
char typalign,
bool useOr,
Datum *result,
bool *resultnull);
static pg_attribute_always_inline void ExecAggPlainTransByVal(AggState *aggstate,
AggStatePerTrans pertrans,
AggStatePerGroup pergroup,
@ -3363,12 +3371,6 @@ ExecEvalScalarArrayOp(ExprState *state, ExprEvalStep *op)
int nitems;
Datum result;
bool resultnull;
int16 typlen;
bool typbyval;
char typalign;
char *s;
bits8 *bitmap;
int bitmask;
/*
* If the array is NULL then we return NULL --- it's not very meaningful
@ -3417,13 +3419,42 @@ ExecEvalScalarArrayOp(ExprState *state, ExprEvalStep *op)
op->d.scalararrayop.element_type = ARR_ELEMTYPE(arr);
}
typlen = op->d.scalararrayop.typlen;
typbyval = op->d.scalararrayop.typbyval;
typalign = op->d.scalararrayop.typalign;
ExecEvalArrayCompareInternal(fcinfo,
arr,
op->d.scalararrayop.typlen,
op->d.scalararrayop.typbyval,
op->d.scalararrayop.typalign,
useOr,
&result,
&resultnull);
*op->resvalue = result;
*op->resnull = resultnull;
}
/*
* Shared helper for ExecEvalScalarArrayOp() and the NULL-LHS fallback for
* non-strict ExecEvalHashedScalarArrayOp().
*
* Callers must handle the strict LHS-is-NULL; return NULL fast path prior to
* calling this.
*/
static pg_attribute_always_inline void
ExecEvalArrayCompareInternal(FunctionCallInfo fcinfo, ArrayType *arr,
int16 typlen, bool typbyval, char typalign,
bool useOr, Datum *result, bool *resultnull)
{
int nitems;
char *s;
bits8 *bitmap;
int bitmask;
bool strictfunc = fcinfo->flinfo->fn_strict;
nitems = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr));
/* Initialize result appropriately depending on useOr */
result = BoolGetDatum(!useOr);
resultnull = false;
*result = BoolGetDatum(!useOr);
*resultnull = false;
/* Loop over the array elements */
s = (char *) ARR_DATA_PTR(arr);
@ -3459,18 +3490,18 @@ ExecEvalScalarArrayOp(ExprState *state, ExprEvalStep *op)
else
{
fcinfo->isnull = false;
thisresult = op->d.scalararrayop.fn_addr(fcinfo);
thisresult = fcinfo->flinfo->fn_addr(fcinfo);
}
/* Combine results per OR or AND semantics */
if (fcinfo->isnull)
resultnull = true;
*resultnull = true;
else if (useOr)
{
if (DatumGetBool(thisresult))
{
result = BoolGetDatum(true);
resultnull = false;
*result = BoolGetDatum(true);
*resultnull = false;
break; /* needn't look at any more elements */
}
}
@ -3478,8 +3509,8 @@ ExecEvalScalarArrayOp(ExprState *state, ExprEvalStep *op)
{
if (!DatumGetBool(thisresult))
{
result = BoolGetDatum(false);
resultnull = false;
*result = BoolGetDatum(false);
*resultnull = false;
break; /* needn't look at any more elements */
}
}
@ -3495,9 +3526,6 @@ ExecEvalScalarArrayOp(ExprState *state, ExprEvalStep *op)
}
}
}
*op->resvalue = result;
*op->resnull = resultnull;
}
/*
@ -3576,7 +3604,7 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco
* If the scalar is NULL, and the function is strict, return NULL; no
* point in executing the search.
*/
if (fcinfo->args[0].isnull && strictfunc)
if (scalar_isnull && strictfunc)
{
*op->resnull = true;
return;
@ -3674,8 +3702,51 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco
* non-strict functions with a null lhs value if no match is found.
*/
op->d.hashedscalararrayop.has_nulls = has_nulls;
/*
* When we have a non-strict equality function, check and cache the
* result from looking up a NULL. Non-strict functions are free to
* treat a NULL as equal to any other value, e.g. a 0 or an empty
* string. Here we perform a linear search over the array and cache
* the outcome so that we can use that result any time we receive a
* NULL.
*/
if (!strictfunc)
{
bool null_lhs_result;
fcinfo->args[0].value = (Datum) 0;
fcinfo->args[0].isnull = true;
ExecEvalArrayCompareInternal(fcinfo, arr, typlen, typbyval,
typalign, true, &result,
&resultnull);
null_lhs_result = DatumGetBool(result);
/* invert non-NULL results for NOT IN */
if (!resultnull && !inclause)
null_lhs_result = !null_lhs_result;
op->d.hashedscalararrayop.null_lhs_isnull = resultnull;
op->d.hashedscalararrayop.null_lhs_result = null_lhs_result;
}
}
/*
* When looking up an SQL NULL value with non-strict functions, we defer
* to the value we cached when building the hash table.
*/
if (scalar_isnull)
{
Assert(!strictfunc);
*op->resnull = op->d.hashedscalararrayop.null_lhs_isnull;
*op->resvalue = BoolGetDatum(op->d.hashedscalararrayop.null_lhs_result);
return;
}
/* Check the hash to see if we have a match. */
hashfound = NULL != saophash_lookup(elements_tab->hashtab, scalar);

View file

@ -581,6 +581,10 @@ typedef struct ExprEvalStep
{
bool has_nulls;
bool inclause; /* true for IN and false for NOT IN */
bool null_lhs_result; /* for non-strict lookups, we
* cache what looking up NULL
* returns. */
bool null_lhs_isnull;
struct ScalarArrayOpExprHashTable *elements_tab;
FmgrInfo *finfo; /* function's lookup data */
FunctionCallInfo fcinfo_data; /* arguments etc */

View file

@ -382,42 +382,177 @@ default for type myint using hash as
operator 1 = (myint, myint),
function 1 myinthash(myint);
create table inttest (a myint);
insert into inttest values(1::myint),(null);
-- try an array with enough elements to cause hashing
select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
a
---
1
(2 rows)
insert into inttest values (null), (0::myint), (1::myint);
-- Test EEOP_HASHED_SCALARARRAYOP against EEOP_SCALARARRAYOP. Ensure the
-- result of non-hashed vs hashed is the same.
select
a,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| |
0 | f | f
1 | t | t
(3 rows)
select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
a
---
(0 rows)
select
a,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| t | t
0 | |
1 | t | t
(3 rows)
select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
a
---
(0 rows)
select
a,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| |
0 | t | t
1 | f | f
(3 rows)
-- ensure the result matched with the non-hashed version. We simply remove
-- some array elements so that we don't reach the hashing threshold.
select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
a
---
1
(2 rows)
select
a,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| f | f
0 | |
1 | f | f
(3 rows)
select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
a
---
(0 rows)
-- Now make the equal function return false when given two NULLs
create or replace function myinteq(myint, myint) returns bool as $$
begin
if $1 is null and $2 is null then
return false;
else
return $1::int = $2::int;
end if;
end;
$$ language plpgsql immutable;
-- And try the same again to ensure EEOP_HASHED_SCALARARRAYOP does the same
-- thing as EEOP_SCALARARRAYOP.
select
a,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| |
0 | f | f
1 | t | t
(3 rows)
select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint, null);
a
---
(0 rows)
select
a,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| |
0 | |
1 | t | t
(3 rows)
select
a,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| |
0 | t | t
1 | f | f
(3 rows)
select
a,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| |
0 | |
1 | f | f
(3 rows)
-- Try again with an equality function that treats NULLs as equal to 0.
create or replace function myinteq(myint, myint) returns bool as $$
begin
if $1 is null and $2 is null then
return false;
else
return coalesce($1::int,0) = coalesce($2::int, 0);
end if;
end;
$$ language plpgsql immutable;
select
a,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed,
a in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed_zero,
a in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed_zero
from inttest;
a | not_hashed | hashed | not_hashed_zero | hashed_zero
---+------------+--------+-----------------+-------------
| f | f | t | t
0 | f | f | t | t
1 | t | t | t | t
(3 rows)
select
a,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| f | f
0 | t | t
1 | t | t
(3 rows)
select
a,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed,
a not in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed_zero,
a not in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed_zero
from inttest;
a | not_hashed | hashed | not_hashed_zero | hashed_zero
---+------------+--------+-----------------+-------------
| t | t | f | f
0 | t | t | f | f
1 | f | f | f | f
(3 rows)
select
a,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
a | not_hashed | hashed
---+------------+--------
| t | t
0 | f | f
1 | f | f
(3 rows)
rollback;

View file

@ -196,16 +196,108 @@ default for type myint using hash as
function 1 myinthash(myint);
create table inttest (a myint);
insert into inttest values(1::myint),(null);
insert into inttest values (null), (0::myint), (1::myint);
-- try an array with enough elements to cause hashing
select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
-- ensure the result matched with the non-hashed version. We simply remove
-- some array elements so that we don't reach the hashing threshold.
select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint, null);
-- Test EEOP_HASHED_SCALARARRAYOP against EEOP_SCALARARRAYOP. Ensure the
-- result of non-hashed vs hashed is the same.
select
a,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
select
a,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
select
a,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
select
a,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
-- Now make the equal function return false when given two NULLs
create or replace function myinteq(myint, myint) returns bool as $$
begin
if $1 is null and $2 is null then
return false;
else
return $1::int = $2::int;
end if;
end;
$$ language plpgsql immutable;
-- And try the same again to ensure EEOP_HASHED_SCALARARRAYOP does the same
-- thing as EEOP_SCALARARRAYOP.
select
a,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
select
a,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
select
a,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed
from inttest;
select
a,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
-- Try again with an equality function that treats NULLs as equal to 0.
create or replace function myinteq(myint, myint) returns bool as $$
begin
if $1 is null and $2 is null then
return false;
else
return coalesce($1::int,0) = coalesce($2::int, 0);
end if;
end;
$$ language plpgsql immutable;
select
a,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed,
a in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed_zero,
a in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed_zero
from inttest;
select
a,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
select
a,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as not_hashed,
a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint) as hashed,
a not in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed_zero,
a not in (0::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed_zero
from inttest;
select
a,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint) as not_hashed,
a not in (null::myint,1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint) as hashed
from inttest;
rollback;