Commit 52538e5

Merge pull request #263 from data-engineering-collective/schema_compat_str_dtype
2 parents f40dc10 + 7847b04 commit 52538e5

File tree

10 files changed: +318 −26 lines

CHANGES.rst

Lines changed: 5 additions & 2 deletions

@@ -2,9 +2,12 @@
 Changelog
 =========

-Next release
-============
+Plateau 4.6.0 (2025-08-12)
+==========================

+* Schema normalization for pandas 3.x `str` dtype. String fields are considered
+  compatible if they use the same NA value; the pandas storage backend is
+  ignored.
 * Support for pyarrow 21.0.0
 * Drop support for pyarrow 15.0.2, 16.1.0 and 17.0.0
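
To make the new rule concrete, a minimal sketch (assuming pandas 3, where `pd.StringDtype` accepts both `storage` and `na_value`, as the tests in this commit do): the first two dtypes differ only in storage backend and are treated as compatible, while the third uses a different NA value and still fails schema validation.

    import numpy as np
    import pandas as pd

    # Same NA value (pd.NA), different storage backends -> compatible under the new rule
    a = pd.StringDtype(storage="python", na_value=pd.NA)
    b = pd.StringDtype(storage="pyarrow", na_value=pd.NA)

    # Different NA value (np.nan) -> still rejected by schema validation
    c = pd.StringDtype(storage="pyarrow", na_value=np.nan)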

plateau/core/_compat.py

Lines changed: 12 additions & 4 deletions

@@ -1,14 +1,22 @@
 import pandas as pd
+import pyarrow as pa
 import simplejson
 from packaging.version import parse as parse_version
+from pandas.errors import OptionError

 PANDAS_3 = parse_version(pd.__version__).major >= 3

+ARROW_GE_20 = parse_version(pa.__version__).major >= 20

-def pandas_infer_string():
-    return (
-        pd.get_option("future.infer_string") or parse_version(pd.__version__).major >= 3
-    )
+
+def pandas_infer_string() -> bool:
+    if parse_version(pd.__version__).major >= 3:
+        # In pandas 3, infer_string is always True
+        return True
+    try:
+        return pd.get_option("future.infer_string")
+    except OptionError:
+        return False


 def load_json(buf, **kwargs):
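
A minimal usage sketch of the reworked helper (assumes plateau is importable; the behaviour follows the branches above):

    import pandas as pd
    from plateau.core._compat import pandas_infer_string

    if pandas_infer_string():
        # pandas 3.x, or pandas 2.x with the future.infer_string option enabled
        print("string columns load as the new str dtype")
    else:
        # pandas 2.x default, or a pandas version without the option at all
        print("string columns load as object dtype")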

plateau/core/common_metadata.py

Lines changed: 53 additions & 2 deletions

@@ -78,6 +78,33 @@ def _schema_compat(self):
         if index_level_ix >= 0:
             schema = schema.remove(index_level_ix)

+        for cmd in pandas_metadata["columns"]:
+            name = cmd.get("name")
+            if name is None:
+                continue
+
+            field_name = cmd["field_name"]
+            field_idx = schema.get_field_index(field_name)
+            if field_idx < 0:
+                continue
+            field = schema[field_idx]
+            if (
+                pa.types.is_string(field.type)
+                and cmd["pandas_type"] == "unicode"
+                and cmd["numpy_type"] == "object"
+            ):
+                schema = schema.remove(field_idx)
+                new_field = pa.field(
+                    field.name,
+                    pa.large_string(),
+                    field.nullable,
+                    field.metadata,
+                )
+                schema = schema.insert(field_idx, new_field)
+                cmd["pandas_type"] = "object"
+                cmd["numpy_type"] = "str"
+                cmd["metadata"] = None
+
         schema = schema.remove_metadata()
         md = {b"pandas": _dict_to_binary(pandas_metadata)}
         schema = schema.with_metadata(md)
@@ -319,8 +346,32 @@ def normalize_type(
         )
         return pa.list_(t_pa2), f"list[{t_pd2}]", "object", None
     elif pa.types.is_dictionary(t_pa):
-        # downcast to dictionary content, `t_pd` is useless in that case
-        return normalize_type(t_pa.value_type, t_np, t_np, None)
+        return normalize_type(t_pa.value_type, t_pd, t_np, None)
+    elif pa.types.is_string(t_pa) or pa.types.is_large_string(t_pa):
+        # Pyarrow only supports reading back
+        #
+        # pyarrow + np.nan
+        #     pa.large_string(), "object", "str", None
+        # or
+        # python + pd.NA
+        #     pa.string(), "unicode", "string", None
+        #
+        # unintuitively, the numpy type identifier `t_np` corresponds
+        # to the pandas dtypes `str` and `string`
+
+        # pandas also supports mixed types but those are rare and must be
+        # constructed explicitly
+        if t_pd == "categorical":
+            # We lose the information of the nullable type since the t_np type
+            # is set to the dtype of the codes but not the categories.
+            return pa.large_string(), "object", "str", None
+        elif t_np == "str":
+            return pa.large_string(), "object", "str", None
+        elif t_np == "string":
+            return pa.string(), "unicode", "string", None
+        else:
+            # This should be the ordinary object dtype
+            return t_pa, t_pd, t_np, metadata
     else:
        return t_pa, t_pd, t_np, metadata
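
Based on the branches above, a hedged sketch of the two canonical string forms `normalize_type` now settles on (the expected tuples are taken from the return statements in the diff; the non-canonical inputs are hypothetical and only illustrate the normalization):

    import pyarrow as pa
    from plateau.core.common_metadata import normalize_type

    # numpy identifier "str" (pandas `str` dtype, np.nan as missing marker)
    assert normalize_type(pa.string(), "object", "str", None) == (
        pa.large_string(),
        "object",
        "str",
        None,
    )

    # numpy identifier "string" (pandas `string` extension dtype, pd.NA as missing marker)
    assert normalize_type(pa.large_string(), "unicode", "string", None) == (
        pa.string(),
        "unicode",
        "string",
        None,
    )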

plateau/core/index.py

Lines changed: 10 additions & 2 deletions

@@ -11,6 +11,7 @@

 import plateau.core._time
 from plateau.core import naming
+from plateau.core._compat import ARROW_GE_20
 from plateau.core._mixins import CopyMixin
 from plateau.core.common_metadata import normalize_type
 from plateau.core.docs import default_docs
@@ -138,7 +139,7 @@ def observed_values(
     ) -> np.ndarray:
         """Return an array of all observed values."""
         keys = np.array(list(self.index_dct.keys()))
-        labeled_array = pa.array(keys, type=self.dtype)
+        labeled_array = _safe_paarray(keys, self.dtype)

         _coerce = {"coerce_temporal_nanoseconds": coerce_temporal_nanoseconds}
         return np.array(
@@ -918,10 +919,17 @@ def _index_dct_to_table(index_dct: IndexDictType, column: str, dtype: pa.DataTyp
         # the np.array dtype will be double which arrow cannot convert to the target type, so use an empty list instead
         labeled_array = pa.array([], type=dtype)
     else:
-        labeled_array = pa.array(keys, type=dtype)
+        labeled_array = _safe_paarray(keys, dtype)

     partition_array = pa.array(list(index_dct.values()), type=pa.list_(pa.string()))

     return pa.Table.from_arrays(
         [labeled_array, partition_array], names=[column, _PARTITION_COLUMN_NAME]
     )
+
+
+def _safe_paarray(arr: np.ndarray, dtype: pa.DataType) -> pa.Array:
+    if dtype is not None and pa.types.is_large_string(dtype) and not ARROW_GE_20:
+        return pa.array(iter(arr), type=dtype)
+    else:
+        return pa.array(arr, type=dtype)
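
A short usage sketch of the new helper, mirroring how `observed_values` and `_index_dct_to_table` call it above (`_safe_paarray` is a private helper, imported here only for illustration):

    import numpy as np
    import pyarrow as pa
    from plateau.core.index import _safe_paarray

    keys = np.array(["2024-01-01", "2024-01-02"])
    # On pyarrow < 20 the keys are fed in as an iterator so the conversion to
    # large_string works; on newer pyarrow the array is converted directly.
    labeled_array = _safe_paarray(keys, pa.large_string())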

plateau/io/testing/read.py

Lines changed: 9 additions & 3 deletions

@@ -36,7 +36,7 @@
 import pytest
 from minimalkv import get_store_from_url

-from plateau.core._compat import pandas_infer_string
+from plateau.core._compat import PANDAS_3, pandas_infer_string
 from plateau.io.eager import store_dataframes_as_dataset
 from plateau.io.iter import store_dataframes_as_dataset__iter
 from plateau.io_components.metapartition import SINGLE_TABLE, MetaPartition
@@ -648,8 +648,14 @@ def test_binary_column_metadata(store_factory, bound_load_dataframes):
     assert set(df.columns.map(type)) == {str}


-def test_extensiondtype_roundtrip(store_factory, bound_load_dataframes):
-    df = pd.DataFrame({"str": pd.Series(["a", "b"], dtype="string")})
+def test_string_type_roundtrip(store_factory, bound_load_dataframes):
+    # Note: we're not actually roundtripping the string type since the loading
+    # type depends on the pandas version. Keeping the loading type aligned with
+    # what is typically initialized by pandas by default is likely the best
+    # option
+    df = pd.DataFrame(
+        {"str": pd.Series(["a", "b"], dtype="str" if PANDAS_3 else "string")}
+    )

     store_dataframes_as_dataset(
         dfs=[df], store=store_factory, dataset_uuid="dataset_uuid"

plateau/io/testing/update.py

Lines changed: 202 additions & 1 deletion

@@ -6,13 +6,19 @@

 import numpy as np
 import pandas as pd
+import pandas.testing as pdt
 import pytest

 from plateau.api.dataset import read_dataset_as_ddf
+from plateau.core._compat import PANDAS_3
 from plateau.core.dataset import DatasetMetadata
 from plateau.core.naming import DEFAULT_METADATA_VERSION
 from plateau.core.testing import TIME_TO_FREEZE_ISO
-from plateau.io.eager import read_dataset_as_metapartitions, store_dataframes_as_dataset
+from plateau.io.eager import (
+    read_dataset_as_metapartitions,
+    read_table,
+    store_dataframes_as_dataset,
+)
 from plateau.io.iter import read_dataset_as_dataframes__iterator


@@ -665,3 +671,198 @@ def test_update_of_dataset_with_non_default_table_name(
     )
     df_expected = pd.concat([df_create, df_update]).reset_index(drop=True)
     pd.testing.assert_frame_equal(df_read, df_expected)
+
+
+def _dtype_from_storage_nan_value(storage_backend, na_value):
+    if PANDAS_3:
+        dtype = pd.StringDtype(storage=storage_backend, na_value=na_value)
+    else:
+        if storage_backend == "pyarrow" and na_value is pd.NA:
+            dtype = "string[pyarrow]"
+        elif storage_backend == "pyarrow" and na_value is np.nan:
+            dtype = "string[pyarrow_numpy]"
+        elif storage_backend == "python" and na_value is np.nan:
+            return None
+        elif storage_backend == "python" and na_value is pd.NA:
+            dtype = "string"
+        else:
+            raise ValueError(f"Unsupported storage backend: {storage_backend}")
+    return dtype
+
+
+@pytest.mark.parametrize("storage_backend", ["pyarrow", "python"])
+@pytest.mark.parametrize("na_value", [np.nan, pd.NA])
+def test_update_after_empty_partition_string_dtypes(
+    store_factory, bound_update_dataset, storage_backend, na_value, backend_identifier
+):
+    import pandas as pd
+
+    with pd.option_context("future.infer_string", True):
+        other_nan_value = {np.nan, pd.NA}
+        other_nan_value.remove(na_value)
+        other_nan_value = other_nan_value.pop()
+        dtype = _dtype_from_storage_nan_value(storage_backend, na_value)
+        if dtype is None:
+            pytest.skip()
+        df = pd.DataFrame({"str": pd.Series(["a", "b", None], dtype=dtype)})
+
+        dataset_uuid = "dataset_uuid"
+        bound_update_dataset(
+            [df.iloc[0:0]],  # empty partition
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        # Schema verification should not fail
+        bound_update_dataset(
+            [df],
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        if na_value is pd.NA:
+            expected_dtype = _dtype_from_storage_nan_value("python", pd.NA)
+        else:
+            expected_dtype = _dtype_from_storage_nan_value("pyarrow", np.nan)
+        # We have to cast to the expected dtype since pyarrow is only reading
+        # the above two data types in. They are ignoring the written storage
+        # backend and are defaulting to python for pd.NA and to pyarrow for
+        # np.nan
+        df["str"] = df["str"].astype(expected_dtype)
+
+        pdt.assert_frame_equal(read_table(dataset_uuid, store_factory()), df)
+        if backend_identifier == "dask.dataframe":
+            # FIXME: dask.dataframe triggers the schema validation error but somehow
+            # the exception is not properly forwarded and the test always fails
+            return
+        for storage in ["pyarrow", "python"]:
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["c", "d"],
+                        dtype=_dtype_from_storage_nan_value(storage, other_nan_value),
+                    )
+                }
+            )
+            # Should be a ValueError but dask sometimes raises a different exception
+            # type
+            with pytest.raises(ValueError, match="Schemas.*are not compatible.*"):
+                bound_update_dataset(
+                    [df],
+                    store=store_factory,
+                    dataset_uuid=dataset_uuid,
+                )
+
+
+@pytest.mark.parametrize("storage_backend", ["pyarrow", "python"])
+@pytest.mark.parametrize("na_value", [np.nan, pd.NA])
+def test_update_after_empty_partition_string_dtypes_categoricals(
+    store_factory, bound_update_dataset, storage_backend, na_value
+):
+    import pandas as pd
+
+    with pd.option_context("future.infer_string", True):
+        other_nan_value = {np.nan, pd.NA}
+        other_nan_value.remove(na_value)
+        other_nan_value = other_nan_value.pop()
+        dtype = _dtype_from_storage_nan_value(storage_backend, na_value)
+        if dtype is None:
+            pytest.skip()
+        df = pd.DataFrame(
+            {"str": pd.Series(["a", "b", None], dtype=dtype).astype("category")}
+        )
+
+        dataset_uuid = "dataset_uuid"
+        bound_update_dataset(
+            [df.iloc[0:0]],  # empty partition
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        # Schema verification should not fail
+        bound_update_dataset(
+            [df],
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        expected_dtype = _dtype_from_storage_nan_value("pyarrow", np.nan)
+        # We have to cast to the expected dtype since pyarrow is only reading
+        # categoricals with the pyarrow_numpy data type.
+        df["str"] = df["str"].astype(expected_dtype)
+
+        pdt.assert_frame_equal(read_table(dataset_uuid, store_factory()), df)
+        for storage in ["pyarrow", "python"]:
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["c", "d"],
+                        dtype=_dtype_from_storage_nan_value(storage, other_nan_value),
+                    ).astype("category")
+                }
+            )
+            bound_update_dataset(
+                [df],
+                store=store_factory,
+                dataset_uuid=dataset_uuid,
+            )
+        after_update = read_table(dataset_uuid, store_factory())
+
+        if not PANDAS_3:
+            expected_dtype = "object"
+
+        expected_after_update = pd.DataFrame(
+            {"str": pd.Series(["a", "b", None, "c", "d", "c", "d"], dtype=expected_dtype)}
+        )
+        pdt.assert_frame_equal(after_update, expected_after_update)

+        # Storage of categorical dtypes will only happen with np.nan. If we try the other na_value we'll get a validation error.
+
+        for storage in ["pyarrow", "python"]:
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["e", "f", None],
+                        dtype=_dtype_from_storage_nan_value(storage, pd.NA),
+                    )
+                }
+            )
+            with pytest.raises(ValueError, match="Schemas.*are not compatible.*"):
+                bound_update_dataset(
+                    [df],
+                    store=store_factory,
+                    dataset_uuid=dataset_uuid,
+                )
+
+        # With np.nan the update works fine
+        skipped = False
+        for storage in ["pyarrow", "python"]:
+            dtype = _dtype_from_storage_nan_value(storage, np.nan)
+            if dtype is None:
+                skipped = True
+                continue
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["e", "f", None],
+                        dtype=dtype,
+                    )
+                }
+            )
+            bound_update_dataset(
+                [df],
+                store=store_factory,
+                dataset_uuid=dataset_uuid,
+            )
+
+        after_update_as_cats = read_table(
+            dataset_uuid, store_factory(), categoricals=["str"]
+        )
+        values = ["a", "b", None, "c", "d", "c", "d", "e", "f", None, "e", "f", None]
+        if skipped:
+            values = values[:-3]
+        expected = pd.DataFrame(
+            {
+                "str": pd.Series(
+                    values,
+                    dtype=expected_dtype,
+                ).astype("category")
+            }
+        )
+        pdt.assert_frame_equal(after_update_as_cats, expected)
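
For readers skimming the tests, the (storage, NA value) combinations that `_dtype_from_storage_nan_value` falls back to on pandas < 3 can be summarized in sketch form (on pandas 3 the helper simply builds `pd.StringDtype(storage=..., na_value=...)` directly; the dictionary below is illustrative, not part of the test code):

    # pandas < 3 fallback dtypes used by _dtype_from_storage_nan_value above
    LEGACY_STRING_DTYPES = {
        ("pyarrow", "pd.NA"): "string[pyarrow]",
        ("pyarrow", "np.nan"): "string[pyarrow_numpy]",
        ("python", "pd.NA"): "string",
        ("python", "np.nan"): None,  # no equivalent before pandas 3 -> tests skip this combo
    }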
