Commit d40e328

Merge pull request #272 from data-engineering-collective/not_in_predicate
Add `not in` predicate operation
2 parents 3938f4c + 01a7790

File tree: 6 files changed, +114 -15 lines changed

- CHANGES.rst
- plateau/serialization/_generic.py
- plateau/serialization/_parquet.py
- tests/io_components/test_read.py
- tests/serialization/test_filter.py
- tests/serialization/test_parquet.py

CHANGES.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -5,6 +5,7 @@ Changelog
 Plateau 4.6.2 (2025-08-XX)
 ==========================
 
+* Add support for `not in` predicate operation.
 * Add further validation for predicates to raise errors if operators are misused with non-scalar values
 
 
```
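For orientation: plateau predicates are lists of conjunctions of `(column, operator, value)` tuples, and `not in`, like `in`, takes a tuple or list of values. A minimal sketch of the new operator through `filter_array_like` (the function changed in `_generic.py` below); that it returns the boolean mask is an assumption based on the surrounding tests:

```python
import pandas as pd

from plateau.serialization._generic import filter_array_like

values = pd.Series([0, 4, 13, 29])

# Select rows whose value is NOT a member of the list: the complement of "in".
mask = filter_array_like(values, "not in", [4, 29])
print(values[mask].tolist())  # expected: [0, 13]
```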

plateau/serialization/_generic.py

Lines changed: 22 additions & 12 deletions
```diff
@@ -206,12 +206,13 @@ def check_predicates(predicates: PredicatesType) -> None:
                     f"with null value and operator '{op}'. Only operators supporting null values "
                     "are '==', '!=', 'in' and 'is distinct from'."
                 )
-            if op == "in" and pd.api.types.is_scalar(val):
-                raise ValueError(
-                    f"Invalid predicates in clause {clause_idx} in conjunction {conjunction_idx} "
-                    f"with operator '{op}' must be used with a tuple or list, got {type(val)} instead."
-                )
-            if op != "in" and is_list_like(val):
+            if op in ("in", "not in"):
+                if pd.api.types.is_scalar(val):
+                    raise ValueError(
+                        f"Invalid predicates in clause {clause_idx} in conjunction {conjunction_idx} "
+                        f"with operator '{op}' must be used with a tuple or list, got {type(val)} instead."
+                    )
+            elif is_list_like(val):
                 raise ValueError(
                     f"Invalid predicates in clause {clause_idx} in conjunction {conjunction_idx} "
                     f"with operator '{op}' must be used with a scalar type, got {type(val)} instead."
```
```diff
@@ -515,7 +516,8 @@ def filter_array_like(
         np.logical_and(array_like < value, mask, out=out)
     elif op == ">":
         np.logical_and(array_like > value, mask, out=out)
-    elif op == "in":
+    elif op in ("in", "not in"):
+        inclusive = op == "in"
         value = np.asarray(value)
         nullmask = pd.isnull(value)
         if value.dtype.kind in ("U", "S", "O"):
@@ -548,11 +550,19 @@ def filter_array_like(
         if any(nullmask):
             matching_idx |= pd.isnull(array_like)
 
-        np.logical_and(
-            matching_idx,
-            mask,
-            out=out,
-        )
+        if inclusive:
+            np.logical_and(
+                matching_idx,
+                mask,
+                out=out,
+            )
+        else:
+            np.logical_and(
+                ~matching_idx,
+                mask,
+                out=out,
+            )
+
     else:
         raise NotImplementedError("op not supported")
 
```
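Because `not in` is implemented as the complement of the membership mask, null handling flips with it: `matching_idx` absorbs nulls for `in` whenever the value list contains a null, so `~matching_idx` excludes them for `not in`. A sketch of that semantics in plain numpy/pandas (illustrative, not the library call itself):

```python
import numpy as np
import pandas as pd

array_like = pd.Series([0.0, 4.0, np.nan])

# Membership mask as built above: isin() plus an explicit null match,
# because the value list [4.0, nan] contains a null.
matching_idx = array_like.isin([4.0]).to_numpy()
matching_idx |= pd.isnull(array_like).to_numpy()

print(matching_idx)   # "in"     -> [False  True  True]
print(~matching_idx)  # "not in" -> [ True False False]
```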

plateau/serialization/_parquet.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -582,6 +582,24 @@ def _predicate_accepts(predicate, row_meta, arrow_schema, parquet_reader):
             elif min_value <= x <= max_value:
                 return True
         return False
+    elif op == "not in":
+        # We can only exclude a row group if we know that every element in the
+        # row group is listed in the predicate values. The only situations in
+        # which the full content of the row group is known are
+        # min_value == max_value
+        # or null_count == len
+
+        if min_value == max_value:
+            for v in val:
+                if pd.isnull(v):
+                    if parquet_statistics.null_count > 0:
+                        continue
+                elif v == min_value:
+                    continue
+                break
+            else:
+                return False
+        return True
     else:
         raise NotImplementedError("op not supported")
```
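In words: a row group can only be skipped for `not in` when the statistics pin down its entire content, i.e. `min_value == max_value` (all non-null rows hold one value) and each element the group contains is covered by `val`. A standalone sketch of that decision rule, inverted so that `True` means "skip" (hypothetical helper mirroring the branch above, not a library API):

```python
import pandas as pd


def can_skip_row_group_for_not_in(min_value, max_value, null_count, val) -> bool:
    """Illustrative mirror of the committed 'not in' pruning branch."""
    if min_value != max_value:
        return False  # group content not fully known; the group must be read
    for v in val:
        if pd.isnull(v):
            if null_count > 0:
                continue  # null in val, and the group does contain nulls
        elif v == min_value:
            continue  # v equals the group's single non-null value
        break  # this v proves nothing; exclusion cannot be shown
    else:
        return True  # every v is accounted for by the group's content
    return False
```

For example, statistics `min=0, max=0, null_count=1` let the rule skip for `val=[0, float("nan")]` but not for `val=[0, 7]`, matching the `test_predicate_accept_notin_excludes` cases further down.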

tests/io_components/test_read.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -36,6 +36,14 @@ def test_dispatch_metapartitions(dataset, store_session):
             [[("mycol", "in", "scalar")]],
             "operator 'in' must be used with a tuple or list",
         ),
+        (
+            [[("mycol", "not in", None)]],
+            "Invalid predicates: Clause 0 in conjunction 0 with null value and operator 'not in'.",
+        ),
+        (
+            [[("mycol", "not in", "scalar")]],
+            "operator 'not in' must be used with a tuple or list",
+        ),
         ([[("mycol", "<", [17, 12])]], "operator '<' must be used with a scalar type"),
     ],
 )
```

tests/serialization/test_filter.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -108,14 +108,14 @@ def test_filter_array_like_categoricals(op, expected, cat_type):
         pytest.param([True], True, marks=pytest.mark.xfail(reason="see gh-193")),
     ],
 )
-@pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">", ">=", "in"])
+@pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">", ">=", "in", "not in"])
 def test_raise_on_type(value, filter_value, op):
     array_like = pd.Series([value])
     with pytest.raises(TypeError, match="Unexpected type for predicate:"):
         filter_array_like(array_like, op, filter_value, strict_date_types=True)
 
 
-@pytest.mark.parametrize("op", ["==", "!=", ">=", "<=", ">", "<", "in"])
+@pytest.mark.parametrize("op", ["==", "!=", ">=", "<=", ">", "<", "in", "not in"])
 @pytest.mark.parametrize(
     "data,value",
     [
@@ -170,7 +170,7 @@ def test_filter_df_from_predicates(op, data, value):
     if isinstance(df["A"].dtype, pd.CategoricalDtype):
         df["A"] = df["A"].astype(df["A"].cat.as_ordered().dtype)
 
-    if op == "in":
+    if op in ["in", "not in"]:
         value = [value]
 
     predicates = [[("A", op, value)]]
@@ -181,6 +181,8 @@ def test_filter_df_from_predicates(op, data, value):
         value = pd.Series(value, dtype=df["A"].dtype).iloc[0]
     if op == "in":
         expected = df[df["A"].isin([value])]
+    elif op == "not in":
+        expected = df[~df["A"].isin([value])]
     else:
         expected = eval(f"df[df['A'] {op} value]")
     pdt.assert_frame_equal(actual, expected, check_categorical=False)
```
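The test oracle leans on pandas: for a single filter value, `not in` is simply the complement of `Series.isin`. A self-contained illustration:

```python
import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3, 4]})

# "A in [2, 4]" versus its complement "A not in [2, 4]".
assert df[df["A"].isin([2, 4])]["A"].tolist() == [2, 4]
assert df[~df["A"].isin([2, 4])]["A"].tolist() == [1, 3]
```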

tests/serialization/test_parquet.py

Lines changed: 60 additions & 0 deletions
```diff
@@ -413,6 +413,66 @@ def test_predicate_accept_in(store, predicate_value, expected):
     )
 
 
+@pytest.mark.parametrize(
+    ["predicate_value", "expected"],
+    [
+        ([0, 4, 1], True),
+        ([-2, 44], True),
+        ([-3, 0], True),
+        ([-1, 10**4], True),
+        ([2, 3], True),
+        ([-1, 20], True),
+        ([-30, -5, 50, 10], True),
+        ([-30, -5, 50, np.nan], True),
+        ([], True),
+    ],
+)
+def test_predicate_accept_notin(store, predicate_value, expected):
+    df = pd.DataFrame({"A": [0, 4, 13, 29]})  # min = 0, max = 29
+    predicate = ("A", "not in", predicate_value)
+    serialiser = ParquetSerializer(chunk_size=None)
+    key = serialiser.store(store, "prefix", df)
+
+    parquet_file = ParquetFile(store.open(key))
+    row_meta = parquet_file.metadata.row_group(0)
+    arrow_schema = parquet_file.schema.to_arrow_schema()
+    parquet_reader = parquet_file.reader
+    assert (
+        _predicate_accepts(
+            predicate,
+            row_meta=row_meta,
+            arrow_schema=arrow_schema,
+            parquet_reader=parquet_reader,
+        )
+        == expected
+    )
+
+
+@pytest.mark.parametrize(
+    ["predicate_value", "test_data"],
+    [
+        ([0], [0, 0]),
+        ([0, np.nan], [0, 0, np.nan]),
+    ],
+)
+def test_predicate_accept_notin_excludes(store, predicate_value, test_data):
+    df = pd.DataFrame({"A": test_data})  # min == max == 0
+    predicate = ("A", "not in", predicate_value)
+    serialiser = ParquetSerializer(chunk_size=None)
+    key = serialiser.store(store, "prefix", df)
+
+    parquet_file = ParquetFile(store.open(key))
+    row_meta = parquet_file.metadata.row_group(0)
+    arrow_schema = parquet_file.schema.to_arrow_schema()
+    parquet_reader = parquet_file.reader
+    assert not _predicate_accepts(
+        predicate,
+        row_meta=row_meta,
+        arrow_schema=arrow_schema,
+        parquet_reader=parquet_reader,
+    )
+
+
 def test_read_categorical(store):
     df = pd.DataFrame({"col": ["a"]}).astype({"col": "category"})
```
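These tests feed `_predicate_accepts` real row-group metadata. To inspect the same statistics yourself, pyarrow exposes them directly (illustrative snippet; `example.parquet` is a placeholder path):

```python
import pandas as pd
import pyarrow.parquet as pq

# Write a tiny file, then read back the per-column row-group statistics
# that _predicate_accepts consults for its pruning decision.
pd.DataFrame({"A": [0, 4, 13, 29]}).to_parquet("example.parquet")

stats = pq.ParquetFile("example.parquet").metadata.row_group(0).column(0).statistics
print(stats.min, stats.max, stats.null_count)  # 0 29 0
```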
