Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 64 additions & 13 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,18 @@ def refreshed_headers() -> Dict[str, str]:
return refreshed_headers


class AzureCliTokenSource(Refreshable):
""" Obtain the token granted by `az login` CLI command """
class CliTokenSource(Refreshable):

def __init__(self, resource: str):
def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str):
super().__init__()
self.resource = resource
self._cmd = cmd
self._token_type_field = token_type_field
self._access_token_field = access_token_field
self._expiry_field = expiry_field

@staticmethod
def _parse_expiry(expiry: str) -> datetime:
for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"):
for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S.%f%z"):
try:
return datetime.strptime(expiry, fmt)
except ValueError as e:
Expand All @@ -196,18 +198,28 @@ def _parse_expiry(expiry: str) -> datetime:

def refresh(self) -> Token:
try:
cmd = ["az", "account", "get-access-token", "--resource", self.resource, "--output", "json"]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
out = subprocess.check_output(self._cmd, stderr=subprocess.STDOUT)
it = json.loads(out.decode())
expires_on = self._parse_expiry(it["expiresOn"])
return Token(access_token=it["accessToken"],
refresh_token=it.get('refreshToken', None),
token_type=it["tokenType"],
expires_on = self._parse_expiry(it[self._expiry_field])
return Token(access_token=it[self._access_token_field],
token_type=it[self._token_type_field],
expiry=expires_on)
except ValueError as e:
raise ValueError(f"cannot unmarshal CLI result: {e}")
except subprocess.CalledProcessError as e:
raise IOError(f'cannot get access token: {e.output.decode()}') from e
message = e.output.decode().strip()
raise IOError(f'cannot get access token: {message}') from e


class AzureCliTokenSource(CliTokenSource):
""" Obtain the token granted by `az login` CLI command """

def __init__(self, resource: str):
cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
super().__init__(cmd=cmd,
token_type_field='tokenType',
access_token_field='accessToken',
expiry_field='expiresOn')


@credentials_provider('azure-cli', ['is_azure'])
Expand All @@ -231,6 +243,45 @@ def inner() -> Dict[str, str]:
return inner


class BricksCliTokenSource(CliTokenSource):
""" Obtain the token granted by `bricks auth login` CLI command """

def __init__(self, cfg: 'Config'):
cli_path = cfg.bricks_cli_path
if not cli_path:
cli_path = 'bricks'
cmd = [cli_path, 'auth', 'token', '--host', cfg.host]
if cfg.is_account_client:
cmd += ['--account-id', cfg.account_id]
super().__init__(cmd=cmd,
token_type_field='token_type',
access_token_field='access_token',
expiry_field='expiry')


@credentials_provider('bricks-cli', ['host', 'is_aws'])
def bricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
token_source = BricksCliTokenSource(cfg)
try:
token_source.token()
except FileNotFoundError:
logger.debug(f'Most likely Bricks CLI is not installed.')
return None
except IOError as e:
if 'databricks OAuth is not' in str(e):
logger.debug(f'OAuth not configured or not available: {e}')
return None
raise e

logger.info("Using Bricks CLI authentication")

def inner() -> Dict[str, str]:
token = token_source.token()
return {'Authorization': f'{token.token_type} {token.access_token}'}

return inner


class DefaultCredentials:
""" Select the first applicable credential provider from the chain """

Expand All @@ -243,7 +294,7 @@ def auth_type(self) -> str:
def __call__(self, cfg: 'Config') -> HeaderFactory:
auth_providers = [
pat_auth, basic_auth, oauth_service_principal, azure_service_principal, azure_cli,
external_browser
external_browser, bricks_cli
]
for provider in auth_providers:
auth_type = provider.auth_type()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_config_azure_cli_host(monkeypatch):


@raises(
"default auth: azure-cli: cannot get access token: This is just a failing script.\n. Config: azure_workspace_resource_id=/sub/rg/ws"
"default auth: azure-cli: cannot get access token: This is just a failing script. Config: azure_workspace_resource_id=/sub/rg/ws"
)
def test_config_azure_cli_host_fail(monkeypatch):
monkeypatch.setenv('FAIL', 'yes')
Expand Down