Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions google/cloud/aiplatform/_streaming_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,19 @@ def tensor_to_value(tensor_pb: aiplatform_types.Tensor) -> Any:
list_of_fields = tensor_pb.ListFields()
if not list_of_fields:
return None
descriptor, value = tensor_pb.ListFields()[0]
if descriptor.name == "list_val":
descriptor, value = list_of_fields[0]
name = descriptor.name
if name == "list_val":
# Use list comprehension instead of generator expression for completeness
return [tensor_to_value(x) for x in value]
elif descriptor.name == "struct_val":
elif name == "struct_val":
# Use dict comprehension directly on value.items() for maximal efficiency
return {k: tensor_to_value(v) for k, v in value.items()}
# Check for Sequence after above cases, then avoid repeating ListFields()
if not isinstance(value, Sequence):
raise TypeError(f"Unexpected non-list tensor value {value}")
if len(value) == 1:
return value[0]
else:
return value
# Only construct the list if len(value) != 1; returned value is unchanged
return value[0] if len(value) == 1 else value


def predict_stream_of_tensor_lists_from_single_tensor_list(
Expand Down