Skip to content

Commit 5893d4d

Browse files
authored
[Fix] Fix DatabricksConfig.copy when authenticated with OAuth (#723)
## Changes <!-- Summary of your changes that are easy to understand --> `DatabricksCliTokenSource().token()` itself can't be copied. So, Deep Copy can't be performed for Config. Added the wrapper function which can be copied. So, Deep copy can be performed. ## Tests <!-- How is this tested? Please see the checklist below and also describe any other relevant tests --> - [ ] `make test` run locally - [ ] `make fmt` applied - [ ] relevant integration tests applied
1 parent fb30ed9 commit 5893d4d

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

databricks/sdk/credentials_provider.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,10 @@ def inner() -> Dict[str, str]:
607607
token = token_source.token()
608608
return {'Authorization': f'{token.token_type} {token.access_token}'}
609609

610-
return OAuthCredentialsProvider(inner, token_source.token)
610+
def token() -> Token:
611+
return token_source.token()
612+
613+
return OAuthCredentialsProvider(inner, token)
611614

612615

613616
class MetadataServiceTokenSource(Refreshable):

tests/test_config.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import os
2+
import pathlib
23
import platform
4+
import random
5+
import string
6+
from datetime import datetime
37

48
import pytest
59

610
from databricks.sdk import useragent
711
from databricks.sdk.config import Config, with_product, with_user_agent_extra
12+
from databricks.sdk.credentials_provider import Token
813
from databricks.sdk.version import __version__
914

1015
from .conftest import noop_credentials, set_az_path
@@ -79,6 +84,40 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
7984
useragent._reset_extra(original_extra)
8085

8186

87+
def test_config_deep_copy(monkeypatch, mocker, tmp_path):
88+
mocker.patch('databricks.sdk.credentials_provider.CliTokenSource.refresh',
89+
return_value=Token(access_token='token',
90+
token_type='Bearer',
91+
expiry=datetime(2023, 5, 22, 0, 0, 0)))
92+
93+
write_large_dummy_executable(tmp_path)
94+
monkeypatch.setenv('PATH', tmp_path.as_posix())
95+
96+
config = Config(host="https://abc123.azuredatabricks.net", auth_type="databricks-cli")
97+
config_copy = config.deep_copy()
98+
assert config_copy.host == config.host
99+
100+
101+
def write_large_dummy_executable(path: pathlib.Path):
102+
cli = path.joinpath('databricks')
103+
104+
# Generate a long random string to inflate the file size.
105+
random_string = ''.join(random.choice(string.ascii_letters) for i in range(1024 * 1024))
106+
cli.write_text("""#!/bin/sh
107+
cat <<EOF
108+
{
109+
"access_token": "...",
110+
"token_type": "Bearer",
111+
"expiry": "2023-05-22T00:00:00.000000+00:00"
112+
}
113+
EOF
114+
exit 0
115+
""" + random_string)
116+
cli.chmod(0o755)
117+
assert cli.stat().st_size >= (1024 * 1024)
118+
return cli
119+
120+
82121
def test_load_azure_tenant_id_404(requests_mock, monkeypatch):
83122
set_az_path(monkeypatch)
84123
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404)

0 commit comments

Comments
 (0)