Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,14 @@ def fetch_DT(
message=
f'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive. {e}',
) from e
if isinstance(
e,
spark_errors.SparkConnectGrpcException,
) and 'is not usable' in str(e):
raise FaultyDataPrepCluster(
message=
f'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive. {e}',
) from e
if isinstance(e, grpc.RpcError) and e.code(
) == grpc.StatusCode.INTERNAL and 'Job aborted due to stage failure' in e.details(
):
Expand Down
153 changes: 62 additions & 91 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def test_format_tablename(self):
@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
def test_fetch_DT_grpc_error_handling(
def test_fetch_DT_catches_grpc_errors(
self,
mock_validate_cluster_info: MagicMock,
mock_fetch: MagicMock,
Expand All @@ -543,99 +543,70 @@ def test_fetch_DT_grpc_error_handling(
# Mock the validate_and_get_cluster_info to return test values
mock_validate_cluster_info.return_value = ('dbconnect', None, None)

# Create a grpc.RpcError with StatusCode.INTERNAL and specific details
grpc_error = grpc.RpcError()
grpc_error.code = lambda: grpc.StatusCode.INTERNAL
grpc_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'

# Configure the fetch function to raise the grpc.RpcError
mock_fetch.side_effect = grpc_error

# Test inputs
delta_table_name = 'test_table'
json_output_folder = '/tmp/to/jsonl'
http_path = None
cluster_id = None
use_serverless = False
DATABRICKS_HOST = 'https://test-host'
DATABRICKS_TOKEN = 'test-token'

# Act & Assert
with self.assertRaises(FaultyDataPrepCluster) as context:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
use_serverless=use_serverless,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
)

# Verify that the FaultyDataPrepCluster contains the expected message
self.assertIn(
'Faulty data prep cluster, please try swapping data prep cluster: ',
str(context.exception),
)
self.assertIn(
'Job aborted due to stage failure',
str(context.exception),
)

# Verify that fetch was called
mock_fetch.assert_called_once()

@patch('llmfoundry.command_utils.data_prep.convert_delta_to_json.fetch')
@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.validate_and_get_cluster_info',
)
def test_fetch_DT_catches_cluster_failed_to_start(
self,
mock_validate_cluster_info: MagicMock,
mock_fetch: MagicMock,
):
# Arrange
# Mock the validate_and_get_cluster_info to return test values
mock_validate_cluster_info.return_value = ('dbconnect', None, None)

# Create a SparkConnectGrpcException indicating that the cluster failed to start

grpc_error = SparkConnectGrpcException(
message='Cannot start cluster etc...',
)

# Configure the fetch function to raise the SparkConnectGrpcException
mock_fetch.side_effect = grpc_error

# Test inputs
delta_table_name = 'test_table'
json_output_folder = '/tmp/to/jsonl'
http_path = None
cluster_id = None
use_serverless = False
DATABRICKS_HOST = 'https://test-host'
DATABRICKS_TOKEN = 'test-token'

# Act & Assert
with self.assertRaises(FaultyDataPrepCluster) as context:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
use_serverless=use_serverless,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
)
grpc_lib_error = grpc.RpcError()
grpc_lib_error.code = lambda: grpc.StatusCode.INTERNAL
grpc_lib_error.details = lambda: 'Job aborted due to stage failure: Task failed due to an error.'

error_contexts = [
(
SparkConnectGrpcException('Cannot start cluster etc...'),
FaultyDataPrepCluster,
[
'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
],
),
(
SparkConnectGrpcException('cluster ... is not usable'),
FaultyDataPrepCluster,
[
'The data preparation cluster you provided is not usable. Please retry with a cluster that is healthy and alive.',
],
),
(
grpc_lib_error,
FaultyDataPrepCluster,
[
'Faulty data prep cluster, please try swapping data prep cluster: ',
'Job aborted due to stage failure',
],
),
]

for (
err_to_throw,
err_to_catch,
texts_to_check_in_error,
) in error_contexts:
# Configure the fetch function to raise the SparkConnectGrpcException
mock_fetch.side_effect = err_to_throw

# Test inputs
delta_table_name = 'test_table'
json_output_folder = '/tmp/to/jsonl'
http_path = None
cluster_id = None
use_serverless = False
DATABRICKS_HOST = 'https://test-host'
DATABRICKS_TOKEN = 'test-token'

# Act & Assert
with self.assertRaises(err_to_catch) as context:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
use_serverless=use_serverless,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
)

# Verify that the FaultyDataPrepCluster contains the expected message
self.assertIn(
'The data preparation cluster you provided is terminated. Please retry with a cluster that is healthy and alive.',
str(context.exception),
)
# Verify that the FaultyDataPrepCluster contains the expected message
for text in texts_to_check_in_error:
self.assertIn(text, str(context.exception))

# Verify that fetch was called
mock_fetch.assert_called_once()
mock_fetch.assert_called()

@patch(
'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows',
Expand Down
Loading