diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 450a78d6c..23843480f 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -177,16 +177,18 @@ def refreshed_headers() -> Dict[str, str]: return refreshed_headers -class AzureCliTokenSource(Refreshable): - """ Obtain the token granted by `az login` CLI command """ +class CliTokenSource(Refreshable): - def __init__(self, resource: str): + def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str): super().__init__() - self.resource = resource + self._cmd = cmd + self._token_type_field = token_type_field + self._access_token_field = access_token_field + self._expiry_field = expiry_field @staticmethod def _parse_expiry(expiry: str) -> datetime: - for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"): + for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S.%f%z"): try: return datetime.strptime(expiry, fmt) except ValueError as e: @@ -196,18 +198,28 @@ def _parse_expiry(expiry: str) -> datetime: def refresh(self) -> Token: try: - cmd = ["az", "account", "get-access-token", "--resource", self.resource, "--output", "json"] - out = subprocess.check_output(cmd, stderr=subprocess.STDOUT) + out = subprocess.check_output(self._cmd, stderr=subprocess.STDOUT) it = json.loads(out.decode()) - expires_on = self._parse_expiry(it["expiresOn"]) - return Token(access_token=it["accessToken"], - refresh_token=it.get('refreshToken', None), - token_type=it["tokenType"], + expires_on = self._parse_expiry(it[self._expiry_field]) + return Token(access_token=it[self._access_token_field], + token_type=it[self._token_type_field], expiry=expires_on) except ValueError as e: raise ValueError(f"cannot unmarshal CLI result: {e}") except subprocess.CalledProcessError as e: - raise IOError(f'cannot get access token: {e.output.decode()}') from e + message = e.output.decode().strip() + raise IOError(f'cannot get access token: {message}') from e + + +class AzureCliTokenSource(CliTokenSource): + """ Obtain the token granted by `az login` CLI command """ + + def __init__(self, resource: str): + cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] + super().__init__(cmd=cmd, + token_type_field='tokenType', + access_token_field='accessToken', + expiry_field='expiresOn') @credentials_provider('azure-cli', ['is_azure']) @@ -231,6 +243,45 @@ def inner() -> Dict[str, str]: return inner +class BricksCliTokenSource(CliTokenSource): + """ Obtain the token granted by `bricks auth login` CLI command """ + + def __init__(self, cfg: 'Config'): + cli_path = cfg.bricks_cli_path + if not cli_path: + cli_path = 'bricks' + cmd = [cli_path, 'auth', 'token', '--host', cfg.host] + if cfg.is_account_client: + cmd += ['--account-id', cfg.account_id] + super().__init__(cmd=cmd, + token_type_field='token_type', + access_token_field='access_token', + expiry_field='expiry') + + +@credentials_provider('bricks-cli', ['host', 'is_aws']) +def bricks_cli(cfg: 'Config') -> Optional[HeaderFactory]: + token_source = BricksCliTokenSource(cfg) + try: + token_source.token() + except FileNotFoundError: + logger.debug(f'Most likely Bricks CLI is not installed.') + return None + except IOError as e: + if 'databricks OAuth is not' in str(e): + logger.debug(f'OAuth not configured or not available: {e}') + return None + raise e + + logger.info("Using Bricks CLI authentication") + + def inner() -> Dict[str, str]: + token = token_source.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return inner + + class DefaultCredentials: """ Select the first applicable credential provider from the chain """ @@ -243,7 +294,7 @@ def auth_type(self) -> str: def __call__(self, cfg: 'Config') -> HeaderFactory: auth_providers = [ pat_auth, basic_auth, oauth_service_principal, azure_service_principal, azure_cli, - external_browser + external_browser, bricks_cli ] for provider in auth_providers: auth_type = provider.auth_type() diff --git a/tests/test_auth.py b/tests/test_auth.py index 450b4def9..4484a5082 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -180,7 +180,7 @@ def test_config_azure_cli_host(monkeypatch): @raises( - "default auth: azure-cli: cannot get access token: This is just a failing script.\n. Config: azure_workspace_resource_id=/sub/rg/ws" + "default auth: azure-cli: cannot get access token: This is just a failing script. Config: azure_workspace_resource_id=/sub/rg/ws" ) def test_config_azure_cli_host_fail(monkeypatch): monkeypatch.setenv('FAIL', 'yes')