From ea57a99ff753bfdd928542b17a79527893f2136d Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Tue, 28 Mar 2023 20:27:23 +0200
Subject: [PATCH] Improve long-running operations

---
 .codegen/service.py.tmpl             | 14 +++--
 README.md                            |  4 ++
 databricks/sdk/service/_internal.py  | 22 ++++---
 databricks/sdk/service/clusters.py   | 20 +++++--
 databricks/sdk/service/commands.py   | 32 ++++++++--
 databricks/sdk/service/deployment.py | 13 +++-
 databricks/sdk/service/endpoints.py  | 17 ++++--
 databricks/sdk/service/jobs.py       | 31 ++++++----
 databricks/sdk/service/pipelines.py  | 30 ++++++----
 databricks/sdk/service/sql.py        | 30 ++++++++--
 examples/last_job_runs.py            | 39 ++++++++++++
 examples/starting_job_and_waiting.py | 89 ++++++++++++++++++++++++++++
 tests/integration/test_jobs.py       | 66 ++++++++++++++++++++-
 13 files changed, 352 insertions(+), 55 deletions(-)
 create mode 100755 examples/last_job_runs.py
 create mode 100755 examples/starting_job_and_waiting.py

diff --git a/.codegen/service.py.tmpl b/.codegen/service.py.tmpl
index 96a1b1bae..dfba212d7 100644
--- a/.codegen/service.py.tmpl
+++ b/.codegen/service.py.tmpl
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Dict, List, Any, Iterator, Type
+from typing import Dict, List, Any, Iterator, Type, Callable
 import time
 import random
 import logging
@@ -52,7 +52,7 @@ class {{.PascalName}}{{if eq "List" .PascalName}}Request{{end}}:{{if .Description}}
 {{- define "as_request_type" -}}
 {{- if not .Entity }}None # ERROR: No Type
 {{- else if .Entity.ArrayValue }}[{{if .Entity.ArrayValue.IsObject}}v.as_dict(){{else}}v{{end}} for v in self.{{.SnakeName}}]
-{{- else if .Entity.IsObject }}self.{{.SnakeName}}.as_dict()
+{{- else if or .Entity.IsObject .Entity.IsExternal }}self.{{.SnakeName}}.as_dict()
 {{- else if .Entity.Enum }}self.{{.SnakeName}}.value
 {{- else}}self.{{.SnakeName}}{{- end -}}
 {{- end -}}
@@ -89,7 +89,8 @@ class {{.Name}}API:{{if .Description}}
     def __init__(self, api_client):
         self._api = api_client
 {{range .Waits}}
-    def {{.SnakeName}}(self{{range .Binding}}, {{.PollField.SnakeName}}: {{template "type-nq" .PollField.Entity}}{{end}}, timeout=timedelta(minutes={{.Timeout}})) -> {{.Poll.Response.PascalName}}:
+    def {{.SnakeName}}(self{{range .Binding}}, {{.PollField.SnakeName}}: {{template "type-nq" .PollField.Entity}}{{end}},
+                       timeout=timedelta(minutes={{.Timeout}}), callback: Callable[[{{.Poll.Response.PascalName}}], None] = None) -> {{.Poll.Response.PascalName}}:
         deadline = time.time() + timeout.total_seconds()
         target_states = ({{range .Success}}{{.Entity.PascalName}}.{{.Content}}, {{end}}){{if .Failure}}
         failure_states = ({{range .Failure}}{{.Entity.PascalName}}.{{.Content}}, {{end}}){{end}}
@@ -109,6 +110,8 @@ class {{.Name}}API:{{if .Description}}
 {{- end}}
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
 {{if .Failure -}}
             if status in failure_states:
                 msg = f'failed to reach {{range $i, $e := .Success}}{{if $i}} or {{end}}{{$e.Content}}{{end}}, got {status}: {status_message}'
@@ -166,8 +169,9 @@ class {{.Name}}API:{{if .Description}}
 {{define "method-call-retried" -}}
         {{if .Response}}op_response = {{end}}{{template "method-do" .}}
-        return Wait(self.{{.Wait.SnakeName}}, {{range $i, $b := .Wait.Binding}}{{if $i}}, {{end}}
-            {{.PollField.SnakeName}}={{if .IsResponseBind}}op_response['{{.Bind.Name}}']{{else}}request.{{.Bind.SnakeName}}{{end}}
+        return Wait(self.{{.Wait.SnakeName}}
+            {{if .Response}}, response = {{.Response.PascalName}}.from_dict(op_response){{end}}
+            {{range .Wait.Binding}}, {{.PollField.SnakeName}}={{if .IsResponseBind}}op_response['{{.Bind.Name}}']{{else}}request.{{.Bind.SnakeName}}{{end}}
         {{- end}})
 {{- end}}
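For orientation, the template above renders a polling helper of roughly the following shape into every generated service module (an illustrative sketch, not actual generated output; `get` and the state literals are stand-ins):

    import time
    from datetime import timedelta
    from typing import Callable

    def wait_example_running(get: Callable, timeout=timedelta(minutes=20),
                             callback: Callable = None):
        # poll until a target state, surfacing intermediate states via callback
        deadline = time.time() + timeout.total_seconds()
        attempt = 1
        while time.time() < deadline:
            poll = get()
            if poll.state in ('RUNNING', ):    # target_states
                return poll
            if callback:
                callback(poll)                 # new: hook for non-terminal polls
            if poll.state in ('ERROR', ):      # failure_states
                raise RuntimeError(f'failed, got {poll.state}')
            time.sleep(min(attempt, 10))       # linear backoff, capped at 10s
            attempt += 1
        raise TimeoutError(f'timed out after {timeout}')

Note that the callback fires only after the target-state check, so it never sees the terminal state itself.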
diff --git a/README.md b/README.md
index 57d40fc8b..4f2012e35 100644
--- a/README.md
+++ b/README.md
@@ -220,6 +220,8 @@ info = w.clusters.create_and_wait(cluster_name='Created cluster',
 logging.info(f'Created: {info}')
 ```
 
+Please look at `examples/starting_job_and_waiting.py` for more advanced usage.
+
 ## Paginated responses
 
 On the platform side the Databricks APIs have different ways to deal with pagination:
@@ -240,6 +242,8 @@ for repo in w.repos.list():
     logging.info(f'Found repo: {repo.path}')
 ```
 
+Please look at `examples/last_job_runs.py` for more advanced usage.
+
 ## Single-Sign-On (SSO) with OAuth
 
 ### Authorization Code flow with PKCE

diff --git a/databricks/sdk/service/_internal.py b/databricks/sdk/service/_internal.py
index ed1f660c7..5cc8d52c9 100644
--- a/databricks/sdk/service/_internal.py
+++ b/databricks/sdk/service/_internal.py
@@ -26,12 +26,22 @@ def _enum(d: Dict[str, any], field: str, cls: Type) -> any:
 
 class Wait(Generic[ReturnType]):
 
-    def __init__(self, waiter: Callable, **kwargs) -> None:
+    def __init__(self, waiter: Callable, response: any = None, **kwargs) -> None:
+        self.response = response
+
         self._waiter = waiter
-        self.arguments = kwargs
+        self._bind = kwargs
+
+    def __getattr__(self, key) -> any:
+        return self._bind[key]
+
+    def bind(self) -> dict:
+        return self._bind
 
-    def result(self, timeout: datetime.timedelta = None) -> ReturnType:
-        kwargs = self.arguments.copy()
-        if timeout:
-            kwargs['timeout'] = timeout
-        return self._waiter(**kwargs)
+    def result(self,
+               timeout: datetime.timedelta = None,
+               callback: Callable[[ReturnType], None] = None) -> ReturnType:
+        kwargs = self._bind.copy()
+        if timeout:
+            kwargs['timeout'] = timeout
+        return self._waiter(callback=callback, **kwargs)
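With this shape, a `Wait` both carries the parsed response of the initiating call and proxies the poll bindings through `__getattr__`, so identifiers are available before any blocking happens. A usage sketch (assumes a `WorkspaceClient` named `w`; the `run_id` access mirrors the examples added later in this patch):

    # waiter.run_id resolves through Wait.__getattr__ into the bound kwargs;
    # waiter.response is the parsed response of the initiating API call
    waiter = w.jobs.submit(run_name='sketch', tasks=[...])
    print(waiter.run_id)      # known immediately, no polling needed
    run = waiter.result()     # blocks until a terminal state or timeout

Keeping `result()` guarded with `if timeout:` preserves the waiter's own default timeout when the caller passes none.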
diff --git a/databricks/sdk/service/clusters.py b/databricks/sdk/service/clusters.py
index d9da10074..5710cc937 100755
--- a/databricks/sdk/service/clusters.py
+++ b/databricks/sdk/service/clusters.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Dict, Iterator, List
+from typing import Callable, Dict, Iterator, List
 
 from ..errors import OperationFailed
 from ._internal import Wait, _enum, _from_dict, _repeated
@@ -1628,7 +1628,10 @@ class ClustersAPI:
     def __init__(self, api_client):
         self._api = api_client
 
-    def wait_get_cluster_running(self, cluster_id: str, timeout=timedelta(minutes=20)) -> ClusterInfo:
+    def wait_get_cluster_running(self,
+                                 cluster_id: str,
+                                 timeout=timedelta(minutes=20),
+                                 callback: Callable[[ClusterInfo], None] = None) -> ClusterInfo:
         deadline = time.time() + timeout.total_seconds()
         target_states = (State.RUNNING, )
         failure_states = (State.ERROR, State.TERMINATED, )
@@ -1640,6 +1643,8 @@ def wait_get_cluster_running(self, cluster_id: str, timeout=timedelta(minutes=20
             status_message = poll.state_message
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach RUNNING, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -1653,7 +1658,10 @@ def wait_get_cluster_running(self, cluster_id: str, timeout=timedelta(minutes=20
             attempt += 1
         raise TimeoutError(f'timed out after {timeout}: {status_message}')
 
-    def wait_get_cluster_terminated(self, cluster_id: str, timeout=timedelta(minutes=20)) -> ClusterInfo:
+    def wait_get_cluster_terminated(self,
+                                    cluster_id: str,
+                                    timeout=timedelta(minutes=20),
+                                    callback: Callable[[ClusterInfo], None] = None) -> ClusterInfo:
         deadline = time.time() + timeout.total_seconds()
         target_states = (State.TERMINATED, )
         failure_states = (State.ERROR, )
@@ -1665,6 +1673,8 @@ def wait_get_cluster_terminated(self, cluster_id: str, timeout=timedelta(minutes
             status_message = poll.state_message
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach TERMINATED, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -1755,7 +1765,9 @@ def create(self,
                                   workload_type=workload_type)
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/2.0/clusters/create', body=body)
-        return Wait(self.wait_get_cluster_running, cluster_id=op_response['cluster_id'])
+        return Wait(self.wait_get_cluster_running,
+                    response=CreateClusterResponse.from_dict(op_response),
+                    cluster_id=op_response['cluster_id'])
 
     def create_and_wait(self,
                         spark_version: str,
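A sketch of the regenerated clusters API in use, combining the parsed `response` with a progress callback (the cluster name is made up; `select_spark_version` and `select_node_type` are the same helpers used in the examples later in this patch):

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.clusters import ClusterInfo

    w = WorkspaceClient()

    def on_poll(info: ClusterInfo):
        # invoked on every poll while the cluster is not yet RUNNING
        print(f'cluster {info.cluster_id} is {info.state}')

    waiter = w.clusters.create(
        cluster_name='wait-demo',
        spark_version=w.clusters.select_spark_version(long_term_support=True),
        node_type_id=w.clusters.select_node_type(local_disk=True),
        num_workers=1)
    print(waiter.response.cluster_id)       # parsed CreateClusterResponse
    info = waiter.result(callback=on_poll)  # ClusterInfo once RUNNING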
diff --git a/databricks/sdk/service/commands.py b/databricks/sdk/service/commands.py
index 12eb719cb..b81b045e5 100755
--- a/databricks/sdk/service/commands.py
+++ b/databricks/sdk/service/commands.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List
 
 from ..errors import OperationFailed
 from ._internal import Wait, _enum, _from_dict
@@ -239,8 +239,12 @@ def __init__(self, api_client):
         self._api = api_client
 
     def wait_command_status_command_execution_cancelled(
-            self, cluster_id: str, command_id: str, context_id: str,
-            timeout=timedelta(minutes=20)) -> CommandStatusResponse:
+            self,
+            cluster_id: str,
+            command_id: str,
+            context_id: str,
+            timeout=timedelta(minutes=20),
+            callback: Callable[[CommandStatusResponse], None] = None) -> CommandStatusResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (CommandStatus.Cancelled, )
         failure_states = (CommandStatus.Error, )
@@ -254,6 +258,8 @@ def wait_command_status_command_execution_cancelled(
                 status_message = poll.results.cause
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach Cancelled, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -268,8 +274,12 @@ def wait_command_status_command_execution_cancelled(
         raise TimeoutError(f'timed out after {timeout}: {status_message}')
 
     def wait_command_status_command_execution_finished_or_error(
-            self, cluster_id: str, command_id: str, context_id: str,
-            timeout=timedelta(minutes=20)) -> CommandStatusResponse:
+            self,
+            cluster_id: str,
+            command_id: str,
+            context_id: str,
+            timeout=timedelta(minutes=20),
+            callback: Callable[[CommandStatusResponse], None] = None) -> CommandStatusResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (CommandStatus.Finished, CommandStatus.Error, )
         failure_states = (CommandStatus.Cancelled, CommandStatus.Cancelling, )
@@ -281,6 +291,8 @@ def wait_command_status_command_execution_finished_or_error(
             status_message = f'current status: {status}'
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach Finished or Error, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -295,7 +307,11 @@ def wait_command_status_command_execution_finished_or_error(
         raise TimeoutError(f'timed out after {timeout}: {status_message}')
 
     def wait_context_status_command_execution_running(
-            self, cluster_id: str, context_id: str, timeout=timedelta(minutes=20)) -> ContextStatusResponse:
+            self,
+            cluster_id: str,
+            context_id: str,
+            timeout=timedelta(minutes=20),
+            callback: Callable[[ContextStatusResponse], None] = None) -> ContextStatusResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (ContextStatus.Running, )
         failure_states = (ContextStatus.Error, )
@@ -307,6 +323,8 @@ def wait_context_status_command_execution_running(
             status_message = f'current status: {status}'
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach Running, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -402,6 +420,7 @@ def create(self,
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/1.2/contexts/create', body=body)
         return Wait(self.wait_context_status_command_execution_running,
+                    response=Created.from_dict(op_response),
                     cluster_id=request.cluster_id,
                     context_id=op_response['id'])
@@ -443,6 +462,7 @@ def execute(self,
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/1.2/commands/execute', body=body)
         return Wait(self.wait_command_status_command_execution_finished_or_error,
+                    response=Created.from_dict(op_response),
                     cluster_id=request.cluster_id,
                     command_id=op_response['id'],
                     context_id=request.context_id)
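The two waiters chain naturally to run a snippet on a cluster. A sketch, assuming a running cluster id in `cluster_id` and a lowercase `Language.python` member (the enum naming follows the API literals, as with `CommandStatus.Cancelled` above):

    import databricks.sdk.service.commands as cmd

    # context creation and command execution each return a Wait
    ctx = w.command_execution.create(cluster_id=cluster_id,
                                     language=cmd.Language.python).result()
    status = w.command_execution.execute(cluster_id=cluster_id,
                                         context_id=ctx.id,
                                         language=cmd.Language.python,
                                         command='print(1 + 1)').result()
    print(status.results.data)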
diff --git a/databricks/sdk/service/deployment.py b/databricks/sdk/service/deployment.py
index f550ca8b5..1ea183eb5 100755
--- a/databricks/sdk/service/deployment.py
+++ b/databricks/sdk/service/deployment.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Dict, Iterator, List
+from typing import Callable, Dict, Iterator, List
 
 from ..errors import OperationFailed
 from ._internal import Wait, _enum, _from_dict, _repeated
@@ -1561,7 +1561,10 @@ class WorkspacesAPI:
     def __init__(self, api_client):
         self._api = api_client
 
-    def wait_get_workspace_running(self, workspace_id: int, timeout=timedelta(minutes=20)) -> Workspace:
+    def wait_get_workspace_running(self,
+                                   workspace_id: int,
+                                   timeout=timedelta(minutes=20),
+                                   callback: Callable[[Workspace], None] = None) -> Workspace:
         deadline = time.time() + timeout.total_seconds()
         target_states = (WorkspaceStatus.RUNNING, )
         failure_states = (WorkspaceStatus.BANNED, WorkspaceStatus.FAILED, )
@@ -1573,6 +1576,8 @@ def wait_get_workspace_running(self, workspace_id: int, timeout=timedelta(minute
             status_message = poll.workspace_status_message
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach RUNNING, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -1630,7 +1635,9 @@ def create(self,
                                   workspace_name=workspace_name)
         body = request.as_dict()
         op_response = self._api.do('POST', f'/api/2.0/accounts/{self._api.account_id}/workspaces', body=body)
-        return Wait(self.wait_get_workspace_running, workspace_id=op_response['workspace_id'])
+        return Wait(self.wait_get_workspace_running,
+                    response=Workspace.from_dict(op_response),
+                    workspace_id=op_response['workspace_id'])
 
     def create_and_wait(
         self,

diff --git a/databricks/sdk/service/endpoints.py b/databricks/sdk/service/endpoints.py
index 6499a7a98..71a748974 100755
--- a/databricks/sdk/service/endpoints.py
+++ b/databricks/sdk/service/endpoints.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List
 
 from ..errors import OperationFailed
 from ._internal import Wait, _enum, _from_dict, _repeated
@@ -488,7 +488,10 @@ def __init__(self, api_client):
         self._api = api_client
 
     def wait_get_serving_endpoint_not_updating(
-            self, name: str, timeout=timedelta(minutes=20)) -> ServingEndpointDetailed:
+            self,
+            name: str,
+            timeout=timedelta(minutes=20),
+            callback: Callable[[ServingEndpointDetailed], None] = None) -> ServingEndpointDetailed:
         deadline = time.time() + timeout.total_seconds()
         target_states = (EndpointStateConfigUpdate.NOT_UPDATING, )
         failure_states = (EndpointStateConfigUpdate.UPDATE_FAILED, )
@@ -500,6 +503,8 @@ def wait_get_serving_endpoint_not_updating(
             status_message = f'current status: {status}'
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach NOT_UPDATING, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -534,7 +539,9 @@ def create(self, name: str, config: EndpointCoreConfigInput, **kwargs) -> Wait[S
         request = CreateServingEndpoint(config=config, name=name)
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/2.0/serving-endpoints', body=body)
-        return Wait(self.wait_get_serving_endpoint_not_updating, name=op_response['name'])
+        return Wait(self.wait_get_serving_endpoint_not_updating,
+                    response=ServingEndpointDetailed.from_dict(op_response),
+                    name=op_response['name'])
 
     def create_and_wait(
         self, name: str, config: EndpointCoreConfigInput,
@@ -618,7 +625,9 @@ def update_config(self,
                                    traffic_config=traffic_config)
         body = request.as_dict()
         op_response = self._api.do('PUT', f'/api/2.0/serving-endpoints/{request.name}/config', body=body)
-        return Wait(self.wait_get_serving_endpoint_not_updating, name=op_response['name'])
+        return Wait(self.wait_get_serving_endpoint_not_updating,
+                    response=ServingEndpointDetailed.from_dict(op_response),
+                    name=op_response['name'])
 
     def update_config_and_wait(
         self,
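A sketch of the serving-endpoints flow, now that `update_config` also surfaces the parsed `ServingEndpointDetailed` (endpoint and model identifiers are hypothetical, and `ServedModelInput` is assumed to live in the same module):

    from databricks.sdk.service.endpoints import (EndpointCoreConfigInput,
                                                  ServedModelInput)

    # blocks until the endpoint leaves the IN_PROGRESS config update
    endpoint = w.serving_endpoints.update_config(
        name='my-endpoint',
        config=EndpointCoreConfigInput(served_models=[
            ServedModelInput(model_name='my-model',
                             model_version='1',
                             workload_size='Small',
                             scale_to_zero_enabled=True)
        ])).result()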
diff --git a/databricks/sdk/service/jobs.py b/databricks/sdk/service/jobs.py
index 7cd0b2ca7..b85bbde9c 100755
--- a/databricks/sdk/service/jobs.py
+++ b/databricks/sdk/service/jobs.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Dict, Iterator, List
+from typing import Callable, Dict, Iterator, List
 
 from ..errors import OperationFailed
 from ._internal import Wait, _enum, _from_dict, _repeated
@@ -183,7 +183,7 @@ def as_dict(self) -> dict:
         body = {}
         if self.existing_cluster_id: body['existing_cluster_id'] = self.existing_cluster_id
         if self.libraries: body['libraries'] = [v for v in self.libraries]
-        if self.new_cluster: body['new_cluster'] = self.new_cluster
+        if self.new_cluster: body['new_cluster'] = self.new_cluster.as_dict()
         return body
 
     @classmethod
@@ -555,7 +555,7 @@ class JobCluster:
     def as_dict(self) -> dict:
         body = {}
         if self.job_cluster_key: body['job_cluster_key'] = self.job_cluster_key
-        if self.new_cluster: body['new_cluster'] = self.new_cluster
+        if self.new_cluster: body['new_cluster'] = self.new_cluster.as_dict()
         return body
 
     @classmethod
@@ -679,7 +679,7 @@ def as_dict(self) -> dict:
         if self.libraries: body['libraries'] = [v for v in self.libraries]
         if self.max_retries: body['max_retries'] = self.max_retries
         if self.min_retry_interval_millis: body['min_retry_interval_millis'] = self.min_retry_interval_millis
-        if self.new_cluster: body['new_cluster'] = self.new_cluster
+        if self.new_cluster: body['new_cluster'] = self.new_cluster.as_dict()
         if self.notebook_task: body['notebook_task'] = self.notebook_task.as_dict()
         if self.pipeline_task: body['pipeline_task'] = self.pipeline_task.as_dict()
         if self.python_wheel_task: body['python_wheel_task'] = self.python_wheel_task.as_dict()
@@ -1336,7 +1336,7 @@ def as_dict(self) -> dict:
         if self.depends_on: body['depends_on'] = [v.as_dict() for v in self.depends_on]
         if self.existing_cluster_id: body['existing_cluster_id'] = self.existing_cluster_id
         if self.libraries: body['libraries'] = [v for v in self.libraries]
-        if self.new_cluster: body['new_cluster'] = self.new_cluster
+        if self.new_cluster: body['new_cluster'] = self.new_cluster.as_dict()
         if self.notebook_task: body['notebook_task'] = self.notebook_task.as_dict()
         if self.pipeline_task: body['pipeline_task'] = self.pipeline_task.as_dict()
         if self.python_wheel_task: body['python_wheel_task'] = self.python_wheel_task.as_dict()
@@ -1403,7 +1403,7 @@ def as_dict(self) -> dict:
         if self.existing_cluster_id: body['existing_cluster_id'] = self.existing_cluster_id
         if self.git_source: body['git_source'] = self.git_source.as_dict()
         if self.libraries: body['libraries'] = [v for v in self.libraries]
-        if self.new_cluster: body['new_cluster'] = self.new_cluster
+        if self.new_cluster: body['new_cluster'] = self.new_cluster.as_dict()
         if self.notebook_task: body['notebook_task'] = self.notebook_task.as_dict()
         if self.pipeline_task: body['pipeline_task'] = self.pipeline_task.as_dict()
         if self.python_wheel_task: body['python_wheel_task'] = self.python_wheel_task.as_dict()
@@ -1975,7 +1975,10 @@ class JobsAPI:
     def __init__(self, api_client):
         self._api = api_client
 
-    def wait_get_run_job_terminated_or_skipped(self, run_id: int, timeout=timedelta(minutes=20)) -> Run:
+    def wait_get_run_job_terminated_or_skipped(self,
+                                               run_id: int,
+                                               timeout=timedelta(minutes=20),
+                                               callback: Callable[[Run], None] = None) -> Run:
         deadline = time.time() + timeout.total_seconds()
         target_states = (RunLifeCycleState.TERMINATED, RunLifeCycleState.SKIPPED, )
         failure_states = (RunLifeCycleState.INTERNAL_ERROR, )
@@ -1989,6 +1992,8 @@ def wait_get_run_job_terminated_or_skipped(self, run_id: int, timeout=timedelta(
             status_message = poll.state.state_message
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach TERMINATED or SKIPPED, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -2277,7 +2282,9 @@ def repair_run(self,
                                sql_params=sql_params)
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/2.1/jobs/runs/repair', body=body)
-        return Wait(self.wait_get_run_job_terminated_or_skipped, run_id=request.run_id)
+        return Wait(self.wait_get_run_job_terminated_or_skipped,
+                    response=RepairRunResponse.from_dict(op_response),
+                    run_id=request.run_id)
 
     def repair_run_and_wait(self,
                             run_id: int,
@@ -2348,7 +2355,9 @@ def run_now(self,
                          sql_params=sql_params)
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/2.1/jobs/run-now', body=body)
-        return Wait(self.wait_get_run_job_terminated_or_skipped, run_id=op_response['run_id'])
+        return Wait(self.wait_get_run_job_terminated_or_skipped,
+                    response=RunNowResponse.from_dict(op_response),
+                    run_id=op_response['run_id'])
 
     def run_now_and_wait(self,
                          job_id: int,
@@ -2400,7 +2409,9 @@ def submit(self,
                         webhook_notifications=webhook_notifications)
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/2.1/jobs/runs/submit', body=body)
-        return Wait(self.wait_get_run_job_terminated_or_skipped, run_id=op_response['run_id'])
+        return Wait(self.wait_get_run_job_terminated_or_skipped,
+                    response=SubmitRunResponse.from_dict(op_response),
+                    run_id=op_response['run_id'])
 
     def submit_and_wait(
         self,
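Because `run_now` and `submit` now parse the operation response, the run id is usable before any polling happens. A sketch, assuming an existing `job_id`:

    waiter = w.jobs.run_now(job_id=job_id)
    print(f'started run_id={waiter.response.run_id}')  # parsed RunNowResponse
    run = waiter.result()                              # blocks until TERMINATED or SKIPPED
    print(run.state.result_state)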
diff --git a/databricks/sdk/service/pipelines.py b/databricks/sdk/service/pipelines.py
index 5e7736625..c3b076b5f 100755
--- a/databricks/sdk/service/pipelines.py
+++ b/databricks/sdk/service/pipelines.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Any, Dict, Iterator, List
+from typing import Any, Callable, Dict, Iterator, List
 
 from ..errors import OperationFailed
 from ._internal import Wait, _enum, _from_dict, _repeated
@@ -521,14 +521,14 @@ def as_dict(self) -> dict:
         body = {}
         if self.apply_policy_default_values:
             body['apply_policy_default_values'] = self.apply_policy_default_values
-        if self.autoscale: body['autoscale'] = self.autoscale
-        if self.aws_attributes: body['aws_attributes'] = self.aws_attributes
-        if self.azure_attributes: body['azure_attributes'] = self.azure_attributes
-        if self.cluster_log_conf: body['cluster_log_conf'] = self.cluster_log_conf
+        if self.autoscale: body['autoscale'] = self.autoscale.as_dict()
+        if self.aws_attributes: body['aws_attributes'] = self.aws_attributes.as_dict()
+        if self.azure_attributes: body['azure_attributes'] = self.azure_attributes.as_dict()
+        if self.cluster_log_conf: body['cluster_log_conf'] = self.cluster_log_conf.as_dict()
         if self.custom_tags: body['custom_tags'] = self.custom_tags
         if self.driver_instance_pool_id: body['driver_instance_pool_id'] = self.driver_instance_pool_id
         if self.driver_node_type_id: body['driver_node_type_id'] = self.driver_node_type_id
-        if self.gcp_attributes: body['gcp_attributes'] = self.gcp_attributes
+        if self.gcp_attributes: body['gcp_attributes'] = self.gcp_attributes.as_dict()
         if self.instance_pool_id: body['instance_pool_id'] = self.instance_pool_id
         if self.label: body['label'] = self.label
         if self.node_type_id: body['node_type_id'] = self.node_type_id
@@ -608,7 +608,7 @@ class PipelineLibrary:
     def as_dict(self) -> dict:
         body = {}
         if self.jar: body['jar'] = self.jar
-        if self.maven: body['maven'] = self.maven
+        if self.maven: body['maven'] = self.maven.as_dict()
         if self.notebook: body['notebook'] = self.notebook.as_dict()
         if self.whl: body['whl'] = self.whl
         return body
@@ -983,7 +983,10 @@ class PipelinesAPI:
     def __init__(self, api_client):
         self._api = api_client
 
-    def wait_get_pipeline_idle(self, pipeline_id: str, timeout=timedelta(minutes=20)) -> GetPipelineResponse:
+    def wait_get_pipeline_idle(self,
+                               pipeline_id: str,
+                               timeout=timedelta(minutes=20),
+                               callback: Callable[[GetPipelineResponse], None] = None) -> GetPipelineResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (PipelineState.IDLE, )
         failure_states = (PipelineState.FAILED, )
@@ -995,6 +998,8 @@ def wait_get_pipeline_idle(self, pipeline_id: str, timeout=timedelta(minutes=20)
             status_message = poll.cause
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach IDLE, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -1008,8 +1013,11 @@ def wait_get_pipeline_idle(self, pipeline_id: str, timeout=timedelta(minutes=20)
             attempt += 1
         raise TimeoutError(f'timed out after {timeout}: {status_message}')
 
-    def wait_get_pipeline_running(self, pipeline_id: str,
-                                  timeout=timedelta(minutes=20)) -> GetPipelineResponse:
+    def wait_get_pipeline_running(
+            self,
+            pipeline_id: str,
+            timeout=timedelta(minutes=20),
+            callback: Callable[[GetPipelineResponse], None] = None) -> GetPipelineResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (PipelineState.RUNNING, )
         failure_states = (PipelineState.FAILED, )
@@ -1021,6 +1029,8 @@ def wait_get_pipeline_running(self, pipeline_id: str,
             status_message = poll.cause
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach RUNNING, got {status}: {status_message}'
                 raise OperationFailed(msg)
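The standalone waiters also work when polling happens somewhere other than the caller that started the operation. A sketch, assuming a known `pipeline_id`:

    from datetime import timedelta

    def log_state(p):
        # receives the full GetPipelineResponse on every intermediate poll
        print(f'pipeline {p.pipeline_id}: {p.state}')

    pipeline = w.pipelines.wait_get_pipeline_running(pipeline_id=pipeline_id,
                                                     timeout=timedelta(minutes=30),
                                                     callback=log_state)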
diff --git a/databricks/sdk/service/sql.py b/databricks/sdk/service/sql.py
index 8207d7fc9..9703f3a0b 100755
--- a/databricks/sdk/service/sql.py
+++ b/databricks/sdk/service/sql.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import Any, Dict, Iterator, List
+from typing import Any, Callable, Dict, Iterator, List
 
 from ..errors import OperationFailed
 from ._internal import Wait, _enum, _from_dict, _repeated
@@ -3067,7 +3067,11 @@ class WarehousesAPI:
     def __init__(self, api_client):
         self._api = api_client
 
-    def wait_get_warehouse_deleted(self, id: str, timeout=timedelta(minutes=20)) -> GetWarehouseResponse:
+    def wait_get_warehouse_deleted(self,
+                                   id: str,
+                                   timeout=timedelta(minutes=20),
+                                   callback: Callable[[GetWarehouseResponse],
+                                                      None] = None) -> GetWarehouseResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (State.DELETED, )
         status_message = 'polling...'
@@ -3080,6 +3084,8 @@ def wait_get_warehouse_deleted(self, id: str, timeout=timedelta(minutes=20)) ->
             status_message = poll.health.summary
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             prefix = f"id={id}"
             sleep = attempt
             if sleep > 10:
@@ -3090,7 +3096,11 @@ def wait_get_warehouse_deleted(self, id: str, timeout=timedelta(minutes=20)) ->
             attempt += 1
         raise TimeoutError(f'timed out after {timeout}: {status_message}')
 
-    def wait_get_warehouse_running(self, id: str, timeout=timedelta(minutes=20)) -> GetWarehouseResponse:
+    def wait_get_warehouse_running(self,
+                                   id: str,
+                                   timeout=timedelta(minutes=20),
+                                   callback: Callable[[GetWarehouseResponse],
+                                                      None] = None) -> GetWarehouseResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (State.RUNNING, )
         failure_states = (State.STOPPED, State.DELETED, )
@@ -3104,6 +3114,8 @@ def wait_get_warehouse_running(self, id: str, timeout=timedelta(minutes=20)) ->
             status_message = poll.health.summary
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             if status in failure_states:
                 msg = f'failed to reach RUNNING, got {status}: {status_message}'
                 raise OperationFailed(msg)
@@ -3117,7 +3129,11 @@ def wait_get_warehouse_running(self, id: str, timeout=timedelta(minutes=20)) ->
             attempt += 1
         raise TimeoutError(f'timed out after {timeout}: {status_message}')
 
-    def wait_get_warehouse_stopped(self, id: str, timeout=timedelta(minutes=20)) -> GetWarehouseResponse:
+    def wait_get_warehouse_stopped(self,
+                                   id: str,
+                                   timeout=timedelta(minutes=20),
+                                   callback: Callable[[GetWarehouseResponse],
+                                                      None] = None) -> GetWarehouseResponse:
         deadline = time.time() + timeout.total_seconds()
         target_states = (State.STOPPED, )
         status_message = 'polling...'
@@ -3130,6 +3146,8 @@ def wait_get_warehouse_stopped(self, id: str, timeout=timedelta(minutes=20)) ->
             status_message = poll.health.summary
             if status in target_states:
                 return poll
+            if callback:
+                callback(poll)
             prefix = f"id={id}"
             sleep = attempt
             if sleep > 10:
@@ -3176,7 +3194,9 @@ def create(self,
                                   warehouse_type=warehouse_type)
         body = request.as_dict()
         op_response = self._api.do('POST', '/api/2.0/sql/warehouses', body=body)
-        return Wait(self.wait_get_warehouse_running, id=op_response['id'])
+        return Wait(self.wait_get_warehouse_running,
+                    response=CreateWarehouseResponse.from_dict(op_response),
+                    id=op_response['id'])
 
     def create_and_wait(
         self,
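A sketch of the warehouse creation path with a progress callback (the warehouse name is made up; the `cluster_size` value follows the SQL warehouse API conventions):

    def on_poll(wh):
        print(f'warehouse {wh.id} is {wh.state}')

    wh = w.warehouses.create(name='sdk-sketch',
                             cluster_size='2X-Small',
                             max_num_clusters=1).result(callback=on_poll)
    print(wh.state)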
diff --git a/examples/last_job_runs.py b/examples/last_job_runs.py
new file mode 100755
index 000000000..2b96ebfa7
--- /dev/null
+++ b/examples/last_job_runs.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+import logging
+import sys
+from collections import defaultdict
+from datetime import datetime, timezone
+from databricks.sdk import WorkspaceClient
+
+
+if __name__ == '__main__':
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                        format='%(asctime)s [%(name)s][%(levelname)s] %(message)s')
+
+    latest_state = {}
+    all_jobs = {}
+    durations = defaultdict(list)
+
+    w = WorkspaceClient()
+    for job in w.jobs.list():
+        all_jobs[job.job_id] = job
+        for run in w.jobs.list_runs(job_id=job.job_id, expand_tasks=False):
+            durations[job.job_id].append(run.run_duration)
+            if job.job_id not in latest_state:
+                latest_state[job.job_id] = run
+                continue
+            if run.end_time < latest_state[job.job_id].end_time:
+                continue
+            latest_state[job.job_id] = run
+
+    summary = []
+    for job_id, run in latest_state.items():
+        summary.append({
+            'job_name': all_jobs[job_id].settings.name,
+            'last_status': run.state.result_state,
+            'last_finished': datetime.fromtimestamp(run.end_time / 1000, timezone.utc),
+            'average_duration': sum(durations[job_id]) / len(durations[job_id])
+        })
+
+    for line in sorted(summary, key=lambda s: s['last_finished'], reverse=True):
+        logging.info(f'Latest: {line}')
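Note that `end_time` is milliseconds since epoch (hence the division by 1000), and `run_duration` is also reported in milliseconds, so `average_duration` above is a raw millisecond figure. If a human-readable value is preferred, a small variation (sketch):

    from datetime import timedelta

    millis = durations[job_id]
    average = timedelta(milliseconds=sum(millis) / len(millis))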
diff --git a/examples/starting_job_and_waiting.py b/examples/starting_job_and_waiting.py
new file mode 100755
index 000000000..1d4d17504
--- /dev/null
+++ b/examples/starting_job_and_waiting.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+""" Detailed demonstration of long-running operations
+
+This example goes over the advanced usage of long-running operations like:
+
+ - w.clusters.create
+ - w.clusters.delete
+ - w.clusters.edit
+ - w.clusters.resize
+ - w.clusters.restart
+ - w.clusters.start
+ - w.command_execution.cancel
+ - w.command_execution.create
+ - w.command_execution.execute
+ - a.workspaces.create
+ - a.workspaces.update
+ - w.serving_endpoints.create
+ - w.serving_endpoints.update_config
+ - w.jobs.cancel_run
+ - w.jobs.repair_run
+ - w.jobs.run_now
+ - w.jobs.submit
+ - w.pipelines.reset
+ - w.pipelines.stop
+ - w.warehouses.create
+ - w.warehouses.delete
+ - w.warehouses.edit
+ - w.warehouses.start
+ - w.warehouses.stop
+
+In this example, you'll learn how to block the main thread until an operation reaches a terminal state or times out.
+You'll also learn how to add a custom callback for intermediate state updates.
+
+You can change `logging.INFO` to `logging.DEBUG` to see HTTP traffic performed by the SDK under the hood.
+"""
+
+import datetime
+import logging
+import sys
+import time
+
+from databricks.sdk import WorkspaceClient
+import databricks.sdk.service.jobs as j
+
+if __name__ == '__main__':
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                        format='%(asctime)s [%(name)s][%(levelname)s] %(message)s')
+
+    w = WorkspaceClient()
+
+    # create a dummy file on DBFS that just sleeps for 10 seconds
+    py_on_dbfs = f'/home/{w.current_user.me().user_name}/sample.py'
+    with w.dbfs.open(py_on_dbfs, write=True, overwrite=True) as f:
+        f.write(b'import time; time.sleep(10); print("Hello, World!")')
+
+    # trigger one-time-run job and get waiter object
+    waiter = w.jobs.submit(run_name=f'py-sdk-run-{time.time()}', tasks=[
+        j.RunSubmitTaskSettings(
+            task_key='hello_world',
+            new_cluster=j.BaseClusterInfo(
+                spark_version=w.clusters.select_spark_version(long_term_support=True),
+                node_type_id=w.clusters.select_node_type(local_disk=True),
+                num_workers=1
+            ),
+            spark_python_task=j.SparkPythonTask(
+                python_file=f'dbfs:{py_on_dbfs}'
+            ),
+        )
+    ])
+
+    logging.info(f'starting to poll: {waiter.run_id}')
+
+    # a callback that receives the polled entity between state updates
+    def print_status(run: j.Run):
+        statuses = [f'{t.task_key}: {t.state.life_cycle_state}' for t in run.tasks]
+        logging.info(f'workflow intermediate status: {", ".join(statuses)}')
+
+    # If you want to perform polling in a separate thread, process, or service,
+    # you can use w.jobs.wait_get_run_job_terminated_or_skipped(
+    #   run_id=waiter.run_id,
+    #   timeout=datetime.timedelta(minutes=15),
+    #   callback=print_status) to achieve the same results.
+    #
+    # The Waiter interface allows for `w.jobs.submit(..).result()` simplicity in
+    # scenarios where you need to block the calling thread until the job finishes.
    run = waiter.result(timeout=datetime.timedelta(minutes=15),
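The comment block near the end of the example can be made concrete: the same polling can run in a process that only knows the run id, by calling the standalone waiter directly (sketch reusing `waiter` and `print_status` from the example above):

    import datetime

    run = w.jobs.wait_get_run_job_terminated_or_skipped(
        run_id=waiter.run_id,
        timeout=datetime.timedelta(minutes=15),
        callback=print_status)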
+    run = waiter.result(timeout=datetime.timedelta(minutes=15),
+                        callback=print_status)
+
+    logging.info(f'job finished: {run.run_page_url}')
\ No newline at end of file

diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py
index f168bd6b3..c9504d444 100644
--- a/tests/integration/test_jobs.py
+++ b/tests/integration/test_jobs.py
@@ -1,3 +1,4 @@
+import datetime
 import logging
 
 
@@ -6,4 +7,67 @@ def test_jobs(w):
     for job in w.jobs.list():
         logging.info(f'Looking at {job.settings.name}')
         found += 1
-    assert found > 0
\ No newline at end of file
+    assert found > 0
+
+
+def test_submitting_jobs(w, random, env_or_skip):
+    import databricks.sdk.service.jobs as j
+
+    py_on_dbfs = f'/home/{w.current_user.me().user_name}/sample.py'
+    with w.dbfs.open(py_on_dbfs, write=True, overwrite=True) as f:
+        f.write(b'import time; time.sleep(10); print("Hello, World!")')
+
+    waiter = w.jobs.submit(run_name=f'py-sdk-{random(8)}',
+                           tasks=[
+                               j.RunSubmitTaskSettings(
+                                   task_key='pi',
+                                   new_cluster=j.BaseClusterInfo(
+                                       spark_version=w.clusters.select_spark_version(long_term_support=True),
+                                       # node_type_id=w.clusters.select_node_type(local_disk=True),
+                                       instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'),
+                                       num_workers=1),
+                                   spark_python_task=j.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'),
+                               )
+                           ])
+
+    logging.info(f'starting to poll: {waiter.run_id}')
+
+    def print_status(run: j.Run):
+        statuses = [f'{t.task_key}: {t.state.life_cycle_state}' for t in run.tasks]
+        logging.info(f'workflow intermediate status: {", ".join(statuses)}')
+
+    run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=print_status)
+
+    logging.info(f'job finished: {run.run_page_url}')
+
+
+def test_last_job_runs(w):
+    from collections import defaultdict
+    from datetime import datetime, timezone
+
+    latest_state = {}
+    all_jobs = {}
+    durations = defaultdict(list)
+
+    for job in w.jobs.list():
+        all_jobs[job.job_id] = job
+        for run in w.jobs.list_runs(job_id=job.job_id, expand_tasks=False):
+            durations[job.job_id].append(run.run_duration)
+            if job.job_id not in latest_state:
+                latest_state[job.job_id] = run
+                continue
+            if run.end_time < latest_state[job.job_id].end_time:
+                continue
+            latest_state[job.job_id] = run
+
+    summary = []
+    for job_id, run in latest_state.items():
+        summary.append({
+            'job_name': all_jobs[job_id].settings.name,
+            'last_status': run.state.result_state,
+            'last_finished': datetime.fromtimestamp(run.end_time / 1000, timezone.utc),
+            'average_duration': sum(durations[job_id]) / len(durations[job_id])
+        })
+
+    for line in sorted(summary, key=lambda s: s['last_finished'], reverse=True):
+        logging.info(f'Latest: {line}')