 
 import numpy as np
 import pandas as pd
+import pandas.testing as pdt
 import pytest
 
 from plateau.api.dataset import read_dataset_as_ddf
+from plateau.core._compat import PANDAS_3
 from plateau.core.dataset import DatasetMetadata
 from plateau.core.naming import DEFAULT_METADATA_VERSION
 from plateau.core.testing import TIME_TO_FREEZE_ISO
-from plateau.io.eager import read_dataset_as_metapartitions, store_dataframes_as_dataset
+from plateau.io.eager import (
+    read_dataset_as_metapartitions,
+    read_table,
+    store_dataframes_as_dataset,
+)
 from plateau.io.iter import read_dataset_as_dataframes__iterator
 
 
@@ -665,3 +671,198 @@ def test_update_of_dataset_with_non_default_table_name( |
     )
     df_expected = pd.concat([df_create, df_update]).reset_index(drop=True)
     pd.testing.assert_frame_equal(df_read, df_expected)
+
+
+def _dtype_from_storage_nan_value(storage_backend, na_value):
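+    # Map a (storage backend, NA sentinel) pair to the string dtype spelling of
+    # the running pandas version, e.g. ("pyarrow", np.nan) -> "string[pyarrow_numpy]"
+    # on pandas 2.x. Before pandas 3 there is no python-storage string dtype
+    # that uses np.nan, hence the None return for that combination.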
+    if PANDAS_3:
+        dtype = pd.StringDtype(storage=storage_backend, na_value=na_value)
+    else:
+        if storage_backend == "pyarrow" and na_value is pd.NA:
+            dtype = "string[pyarrow]"
+        elif storage_backend == "pyarrow" and na_value is np.nan:
+            dtype = "string[pyarrow_numpy]"
+        elif storage_backend == "python" and na_value is np.nan:
+            return None
+        elif storage_backend == "python" and na_value is pd.NA:
+            dtype = "string"
+        else:
+            raise ValueError(
+                f"Unsupported combination of storage backend {storage_backend!r} "
+                f"and na_value {na_value!r}"
+            )
+    return dtype
+
+
+@pytest.mark.parametrize("storage_backend", ["pyarrow", "python"])
+@pytest.mark.parametrize("na_value", [np.nan, pd.NA])
+def test_update_after_empty_partition_string_dtypes(
+    store_factory, bound_update_dataset, storage_backend, na_value, backend_identifier
+):
+    with pd.option_context("future.infer_string", True):
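+        # Keep the NA sentinel we are *not* writing with; it is used further
+        # down to check that updating with it fails schema validation.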
+        other_nan_value = {np.nan, pd.NA}
+        other_nan_value.remove(na_value)
+        other_nan_value = other_nan_value.pop()
+        dtype = _dtype_from_storage_nan_value(storage_backend, na_value)
+        if dtype is None:
+            pytest.skip("no python-storage string dtype with np.nan before pandas 3")
+        df = pd.DataFrame({"str": pd.Series(["a", "b", None], dtype=dtype)})
+
+        dataset_uuid = "dataset_uuid"
+        bound_update_dataset(
+            [df.iloc[0:0]],  # empty partition
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        # Schema verification should not fail
+        bound_update_dataset(
+            [df],
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        if na_value is pd.NA:
+            expected_dtype = _dtype_from_storage_nan_value("python", pd.NA)
+        else:
+            expected_dtype = _dtype_from_storage_nan_value("pyarrow", np.nan)
+        # We have to cast to the expected dtype since pyarrow only reads the
+        # above two data types back in: it ignores the storage backend the data
+        # was written with and defaults to python storage for pd.NA and to
+        # pyarrow storage for np.nan.
+        df["str"] = df["str"].astype(expected_dtype)
+
+        pdt.assert_frame_equal(read_table(dataset_uuid, store_factory()), df)
+        if backend_identifier == "dask.dataframe":
+            # FIXME: dask.dataframe triggers the schema validation error but
+            # somehow the exception is not properly forwarded and the test
+            # always fails
+            return
+        for storage in ["pyarrow", "python"]:
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["c", "d"],
+                        dtype=_dtype_from_storage_nan_value(storage, other_nan_value),
+                    )
+                }
+            )
+            # This should be a ValueError, but dask sometimes raises a
+            # different exception type, so we only match on the error message.
+            with pytest.raises(Exception, match="Schemas.*are not compatible.*"):
+                bound_update_dataset(
+                    [df],
+                    store=store_factory,
+                    dataset_uuid=dataset_uuid,
+                )
+
+
+@pytest.mark.parametrize("storage_backend", ["pyarrow", "python"])
+@pytest.mark.parametrize("na_value", [np.nan, pd.NA])
+def test_update_after_empty_partition_string_dtypes_categoricals(
+    store_factory, bound_update_dataset, storage_backend, na_value
+):
+    with pd.option_context("future.infer_string", True):
+        other_nan_value = {np.nan, pd.NA}
+        other_nan_value.remove(na_value)
+        other_nan_value = other_nan_value.pop()
+        dtype = _dtype_from_storage_nan_value(storage_backend, na_value)
+        if dtype is None:
+            pytest.skip("no python-storage string dtype with np.nan before pandas 3")
+        df = pd.DataFrame(
+            {"str": pd.Series(["a", "b", None], dtype=dtype).astype("category")}
+        )
+
+        dataset_uuid = "dataset_uuid"
+        bound_update_dataset(
+            [df.iloc[0:0]],  # empty partition
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        # Schema verification should not fail
+        bound_update_dataset(
+            [df],
+            store=store_factory,
+            dataset_uuid=dataset_uuid,
+        )
+        expected_dtype = _dtype_from_storage_nan_value("pyarrow", np.nan)
+        # We have to cast to the expected dtype since pyarrow only reads
+        # categoricals back in with the pyarrow_numpy data type.
+        df["str"] = df["str"].astype(expected_dtype)
+
+        pdt.assert_frame_equal(read_table(dataset_uuid, store_factory()), df)
+        for storage in ["pyarrow", "python"]:
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["c", "d"],
+                        dtype=_dtype_from_storage_nan_value(storage, other_nan_value),
+                    ).astype("category")
+                }
+            )
+            bound_update_dataset(
+                [df],
+                store=store_factory,
+                dataset_uuid=dataset_uuid,
+            )
+        after_update = read_table(dataset_uuid, store_factory())
+
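+        # Before pandas 3, the round trip of these categorical partitions comes
+        # back as plain object dtype rather than a string dtype.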
+        if not PANDAS_3:
+            expected_dtype = "object"
+
+        expected_after_update = pd.DataFrame(
+            {"str": pd.Series(["a", "b", None, "c", "d", "c", "d"], dtype=expected_dtype)}
+        )
+        pdt.assert_frame_equal(after_update, expected_after_update)
+
+        # Storage of categorical dtypes will only happen with np.nan. If we try
+        # the other na_value, we'll get a validation error.
+
+        for storage in ["pyarrow", "python"]:
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["e", "f", None],
+                        dtype=_dtype_from_storage_nan_value(storage, pd.NA),
+                    )
+                }
+            )
+            with pytest.raises(ValueError, match="Schemas.*are not compatible.*"):
+                bound_update_dataset(
+                    [df],
+                    store=store_factory,
+                    dataset_uuid=dataset_uuid,
+                )
+
+        # With np.nan, on the other hand, the update passes schema validation.
+        skipped = False
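+        # `skipped` records that the python-storage/np.nan combination does not
+        # exist on this pandas version; the expected values are trimmed below.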
+        for storage in ["pyarrow", "python"]:
+            dtype = _dtype_from_storage_nan_value(storage, np.nan)
+            if dtype is None:
+                skipped = True
+                continue
+            df = pd.DataFrame(
+                {
+                    "str": pd.Series(
+                        ["e", "f", None],
+                        dtype=dtype,
+                    )
+                }
+            )
+            bound_update_dataset(
+                [df],
+                store=store_factory,
+                dataset_uuid=dataset_uuid,
+            )
+
+        after_update_as_cats = read_table(
+            dataset_uuid, store_factory(), categoricals=["str"]
+        )
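+        # Passing categoricals=["str"] makes read_table return the column as a
+        # categorical, matching the expected frame built below.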
+        values = ["a", "b", None, "c", "d", "c", "d", "e", "f", None, "e", "f", None]
+        if skipped:
+            values = values[:-3]
+        expected = pd.DataFrame(
+            {
+                "str": pd.Series(
+                    values,
+                    dtype=expected_dtype,
+                ).astype("category")
+            }
+        )
+        pdt.assert_frame_equal(after_update_as_cats, expected)