|
16 | 16 | # under the License. |
17 | 17 | import ctypes |
18 | 18 | import datetime |
| 19 | +import itertools |
19 | 20 | import os |
20 | 21 | import re |
21 | 22 | import threading |
@@ -59,9 +60,7 @@ def ctx(): |
59 | 60 |
|
60 | 61 |
|
61 | 62 | @pytest.fixture |
62 | | -def df(): |
63 | | - ctx = SessionContext() |
64 | | - |
| 63 | +def df(ctx): |
65 | 64 | # create a RecordBatch and a new DataFrame from it |
66 | 65 | batch = pa.RecordBatch.from_arrays( |
67 | 66 | [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])], |
@@ -1831,29 +1830,52 @@ def test_write_csv(ctx, df, tmp_path, path_to_str): |
1831 | 1830 | assert result == expected |
1832 | 1831 |
|
1833 | 1832 |
|
| 1833 | +sort_by_cases = [ |
| 1834 | + (None, [1, 2, 3], "unsorted"), |
| 1835 | + (column("c"), [2, 1, 3], "single_column_expr"), |
| 1836 | + (column("a").sort(ascending=False), [3, 2, 1], "single_sort_expr"), |
| 1837 | + ([column("c"), column("b")], [2, 1, 3], "list_col_expr"), |
| 1838 | + ( |
| 1839 | + [column("c").sort(ascending=False), column("b").sort(ascending=False)], |
| 1840 | + [3, 1, 2], |
| 1841 | + "list_sort_expr", |
| 1842 | + ), |
| 1843 | +] |
| 1844 | + |
| 1845 | +formats = ["csv", "json", "parquet", "table"] |
| 1846 | + |
| 1847 | + |
1834 | 1848 | @pytest.mark.parametrize( |
1835 | | - ("sort_by", "expected_a"), |
| 1849 | + ("format", "sort_by", "expected_a"), |
1836 | 1850 | [ |
1837 | | - pytest.param(None, [1, 2, 3], id="unsorted"), |
1838 | | - pytest.param(column("c"), [2, 1, 3], id="single_column_expr"), |
1839 | | - pytest.param( |
1840 | | - column("a").sort(ascending=False), [3, 2, 1], id="single_sort_expr" |
1841 | | - ), |
1842 | | - pytest.param([column("c"), column("b")], [2, 1, 3], id="list_col_expr"), |
1843 | | - pytest.param( |
1844 | | - [column("c").sort(ascending=False), column("b").sort(ascending=False)], |
1845 | | - [3, 1, 2], |
1846 | | - id="list_sort_expr", |
1847 | | - ), |
| 1851 | + pytest.param(format, sort_by, expected_a, id=f"{format}_{test_id}") |
| 1852 | + for format, (sort_by, expected_a, test_id) in itertools.product( |
| 1853 | + formats, sort_by_cases |
| 1854 | + ) |
1848 | 1855 | ], |
1849 | 1856 | ) |
1850 | | -def test_write_csv_with_options(ctx, df, tmp_path, sort_by, expected_a) -> None: |
| 1857 | +def test_write_files_with_options( |
| 1858 | + ctx, df, tmp_path, format, sort_by, expected_a |
| 1859 | +) -> None: |
1851 | 1860 | write_options = DataFrameWriteOptions(sort_by=sort_by) |
1852 | | - df.write_csv(tmp_path, with_header=True, write_options=write_options) |
1853 | 1861 |
|
1854 | | - ctx.register_csv("csv", tmp_path) |
1855 | | - result = ctx.table("csv").to_pydict()["a"] |
1856 | | - ctx.table("csv").show() |
| 1862 | + if format == "csv": |
| 1863 | + df.write_csv(tmp_path, with_header=True, write_options=write_options) |
| 1864 | + ctx.register_csv("test_table", tmp_path) |
| 1865 | + elif format == "json": |
| 1866 | + df.write_json(tmp_path, write_options=write_options) |
| 1867 | + ctx.register_json("test_table", tmp_path) |
| 1868 | + elif format == "parquet": |
| 1869 | + df.write_parquet(tmp_path, write_options=write_options) |
| 1870 | + ctx.register_parquet("test_table", tmp_path) |
| 1871 | + elif format == "table": |
| 1872 | + batch = pa.RecordBatch.from_arrays([[], [], []], schema=df.schema()) |
| 1873 | + ctx.register_record_batches("test_table", [[batch]]) |
| 1874 | + ctx.table("test_table").show() |
| 1875 | + df.write_table("test_table", write_options=write_options) |
| 1876 | + |
| 1877 | + result = ctx.table("test_table").to_pydict()["a"] |
| 1878 | + ctx.table("test_table").show() |
1857 | 1879 |
|
1858 | 1880 | assert result == expected_a |
1859 | 1881 |
|
|
0 commit comments