Skip to content

Commit 88f646c

Browse files
Rename LargeList.dtype to LargeList.feature (#7106)
1 parent 3813ce8 commit 88f646c

File tree

4 files changed

+28
-28
lines changed

src/datasets/features/features.py

Lines changed: 18 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1175,11 +1175,11 @@ class LargeList:
11751175
It is backed by `pyarrow.LargeListType`, which is like `pyarrow.ListType` but with 64-bit rather than 32-bit offsets.
11761176
11771177
Args:
1178-
dtype ([`FeatureType`]):
1178+
feature ([`FeatureType`]):
11791179
Child feature data type of each item within the large list.
11801180
"""
11811181

1182-
dtype: Any
1182+
feature: Any
11831183
id: Optional[str] = None
11841184
# Automatically constructed
11851185
pa_type: ClassVar[Any] = None
@@ -1218,8 +1218,6 @@ def _check_non_null_non_empty_recursive(obj, schema: Optional[FeatureType] = Non
12181218
pass
12191219
elif isinstance(schema, (list, tuple)):
12201220
schema = schema[0]
1221-
elif isinstance(schema, LargeList):
1222-
schema = schema.dtype
12231221
else:
12241222
schema = schema.feature
12251223
return _check_non_null_non_empty_recursive(obj[0], schema)
@@ -1252,7 +1250,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType:
12521250
value_type = get_nested_type(schema[0])
12531251
return pa.list_(value_type)
12541252
elif isinstance(schema, LargeList):
1255-
value_type = get_nested_type(schema.dtype)
1253+
value_type = get_nested_type(schema.feature)
12561254
return pa.large_list(value_type)
12571255
elif isinstance(schema, Sequence):
12581256
value_type = get_nested_type(schema.feature)
@@ -1303,7 +1301,7 @@ def encode_nested_example(schema, obj, level=0):
13031301
return None
13041302
else:
13051303
if len(obj) > 0:
1306-
sub_schema = schema.dtype
1304+
sub_schema = schema.feature
13071305
for first_elmt in obj:
13081306
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
13091307
break
@@ -1384,7 +1382,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
13841382
if obj is None:
13851383
return None
13861384
else:
1387-
sub_schema = schema.dtype
1385+
sub_schema = schema.feature
13881386
if len(obj) > 0:
13891387
for first_elmt in obj:
13901388
if _check_non_null_non_empty_recursive(first_elmt, sub_schema):
@@ -1463,8 +1461,8 @@ def generate_from_dict(obj: Any):
14631461
raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")
14641462

14651463
if class_type == LargeList:
1466-
dtype = obj.pop("dtype")
1467-
return LargeList(generate_from_dict(dtype), **obj)
1464+
feature = obj.pop("feature")
1465+
return LargeList(feature=generate_from_dict(feature), **obj)
14681466
if class_type == Sequence:
14691467
feature = obj.pop("feature")
14701468
return Sequence(feature=generate_from_dict(feature), **obj)
@@ -1493,8 +1491,8 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType:
14931491
return [feature]
14941492
return Sequence(feature=feature)
14951493
elif isinstance(pa_type, pa.LargeListType):
1496-
dtype = generate_from_arrow_type(pa_type.value_type)
1497-
return LargeList(dtype)
1494+
feature = generate_from_arrow_type(pa_type.value_type)
1495+
return LargeList(feature=feature)
14981496
elif isinstance(pa_type, _ArrayXDExtensionType):
14991497
array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims]
15001498
return array_feature(shape=pa_type.shape, dtype=pa_type.value_type)
@@ -1601,7 +1599,7 @@ def _visit(feature: FeatureType, func: Callable[[FeatureType], Optional[FeatureT
16011599
elif isinstance(feature, (list, tuple)):
16021600
out = func([_visit(feature[0], func)])
16031601
elif isinstance(feature, LargeList):
1604-
out = func(LargeList(_visit(feature.dtype, func)))
1602+
out = func(LargeList(_visit(feature.feature, func)))
16051603
elif isinstance(feature, Sequence):
16061604
out = func(Sequence(_visit(feature.feature, func), length=feature.length))
16071605
else:
@@ -1624,7 +1622,7 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False
16241622
elif isinstance(feature, (list, tuple)):
16251623
return require_decoding(feature[0])
16261624
elif isinstance(feature, LargeList):
1627-
return require_decoding(feature.dtype)
1625+
return require_decoding(feature.feature)
16281626
elif isinstance(feature, Sequence):
16291627
return require_decoding(feature.feature)
16301628
else:
@@ -1644,7 +1642,7 @@ def require_storage_cast(feature: FeatureType) -> bool:
16441642
elif isinstance(feature, (list, tuple)):
16451643
return require_storage_cast(feature[0])
16461644
elif isinstance(feature, LargeList):
1647-
return require_storage_cast(feature.dtype)
1645+
return require_storage_cast(feature.feature)
16481646
elif isinstance(feature, Sequence):
16491647
return require_storage_cast(feature.feature)
16501648
else:
@@ -1664,7 +1662,7 @@ def require_storage_embed(feature: FeatureType) -> bool:
16641662
elif isinstance(feature, (list, tuple)):
16651663
return require_storage_cast(feature[0])
16661664
elif isinstance(feature, LargeList):
1667-
return require_storage_cast(feature.dtype)
1665+
return require_storage_cast(feature.feature)
16681666
elif isinstance(feature, Sequence):
16691667
return require_storage_cast(feature.feature)
16701668
else:
@@ -1876,8 +1874,8 @@ def to_yaml_inner(obj: Union[dict, list]) -> dict:
18761874
if isinstance(obj, dict):
18771875
_type = obj.pop("_type", None)
18781876
if _type == "LargeList":
1879-
value_type = obj.pop("dtype")
1880-
return simplify({"large_list": to_yaml_inner(value_type), **obj})
1877+
_feature = obj.pop("feature")
1878+
return simplify({"large_list": to_yaml_inner(_feature), **obj})
18811879
elif _type == "Sequence":
18821880
_feature = obj.pop("feature")
18831881
return simplify({"sequence": to_yaml_inner(_feature), **obj})
@@ -1947,8 +1945,8 @@ def from_yaml_inner(obj: Union[dict, list]) -> Union[dict, list]:
19471945
return {}
19481946
_type = next(iter(obj))
19491947
if _type == "large_list":
1950-
_dtype = unsimplify(obj).pop(_type)
1951-
return {"dtype": from_yaml_inner(_dtype), **obj, "_type": "LargeList"}
1948+
_feature = unsimplify(obj).pop(_type)
1949+
return {"feature": from_yaml_inner(_feature), **obj, "_type": "LargeList"}
19521950
if _type == "sequence":
19531951
_feature = unsimplify(obj).pop(_type)
19541952
return {"feature": from_yaml_inner(_feature), **obj, "_type": "Sequence"}
@@ -2180,7 +2178,7 @@ def recursive_reorder(source, target, stack=""):
21802178
elif isinstance(source, LargeList):
21812179
if not isinstance(target, LargeList):
21822180
raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position)
2183-
return LargeList(recursive_reorder(source.dtype, target.dtype, stack))
2181+
return LargeList(recursive_reorder(source.feature, target.feature, stack))
21842182
else:
21852183
return source
21862184

src/datasets/table.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2017,7 +2017,7 @@ def cast_array_to_feature(
20172017
array_offsets = _combine_list_array_offsets_with_mask(array)
20182018
return pa.ListArray.from_arrays(array_offsets, casted_array_values)
20192019
elif isinstance(feature, LargeList):
2020-
casted_array_values = _c(array.values, feature.dtype)
2020+
casted_array_values = _c(array.values, feature.feature)
20212021
if pa.types.is_large_list(array.type) and casted_array_values.type == array.values.type:
20222022
# Both array and feature have equal large_list type and values (within the list) type
20232023
return array
@@ -2075,7 +2075,9 @@ def cast_array_to_feature(
20752075
return pa.ListArray.from_arrays(array_offsets, _c(array.values, feature[0]), mask=array.is_null())
20762076
elif isinstance(feature, LargeList):
20772077
array_offsets = (np.arange(len(array) + 1) + array.offset) * array.type.list_size
2078-
return pa.LargeListArray.from_arrays(array_offsets, _c(array.values, feature.dtype), mask=array.is_null())
2078+
return pa.LargeListArray.from_arrays(
2079+
array_offsets, _c(array.values, feature.feature), mask=array.is_null()
2080+
)
20792081
elif isinstance(feature, Sequence):
20802082
if feature.length > -1:
20812083
if feature.length == array.type.list_size:
@@ -2155,7 +2157,7 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"):
21552157
# feature must be LargeList(subfeature)
21562158
# Merge offsets with the null bitmap to avoid the "Null bitmap with offsets slice not supported" ArrowNotImplementedError
21572159
array_offsets = _combine_list_array_offsets_with_mask(array)
2158-
return pa.LargeListArray.from_arrays(array_offsets, _e(array.values, feature.dtype))
2160+
return pa.LargeListArray.from_arrays(array_offsets, _e(array.values, feature.feature))
21592161
elif pa.types.is_fixed_size_list(array.type):
21602162
# feature must be Sequence(subfeature)
21612163
if isinstance(feature, Sequence) and feature.length > -1:

tests/features/test_features.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -726,7 +726,7 @@ def test_features_flatten_with_list_types(features_dict, expected_features_dict)
726726
{"col": [Value("int32")]},
727727
),
728728
(
729-
{"col": {"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}},
729+
{"col": {"feature": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"}},
730730
{"col": LargeList(Value("int32"))},
731731
),
732732
(
@@ -738,7 +738,7 @@ def test_features_flatten_with_list_types(features_dict, expected_features_dict)
738738
{"col": [{"sub_col": Value("int32")}]},
739739
),
740740
(
741-
{"col": {"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}},
741+
{"col": {"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"}},
742742
{"col": LargeList({"sub_col": Value("int32")})},
743743
),
744744
(
@@ -760,7 +760,7 @@ def test_features_from_dict_with_list_types(deserialized_features_dict, expected
760760
[Value("int32")],
761761
),
762762
(
763-
{"dtype": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"},
763+
{"feature": {"dtype": "int32", "_type": "Value"}, "_type": "LargeList"},
764764
LargeList(Value("int32")),
765765
),
766766
(
@@ -772,7 +772,7 @@ def test_features_from_dict_with_list_types(deserialized_features_dict, expected
772772
[{"sub_col": Value("int32")}],
773773
),
774774
(
775-
{"dtype": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"},
775+
{"feature": {"sub_col": {"dtype": "int32", "_type": "Value"}}, "_type": "LargeList"},
776776
LargeList({"sub_col": Value("int32")}),
777777
),
778778
(

tests/test_info.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ def test_dataset_info_from_dict_with_large_list():
170170
dataset_info_dict = {
171171
"citation": "",
172172
"description": "",
173-
"features": {"col_1": {"dtype": {"dtype": "int64", "_type": "Value"}, "_type": "LargeList"}},
173+
"features": {"col_1": {"feature": {"dtype": "int64", "_type": "Value"}, "_type": "LargeList"}},
174174
"homepage": "",
175175
"license": "",
176176
}

0 commit comments

Comments
 (0)