Skip to content

Commit 405b147

Browse files
pieternnfx
authored andcommitted
Persist token acquired through external-browser auth type (#110)
## Changes This cache is local to the Python SDK and keyed off the workspace host, the OAuth client ID, and the list of scopes to authorize for. The cache path is `~/.config/databricks-sdk-py/oauth`. Files saved to this directory are masked 0600. It does not persist refreshes that happen during a session. ## Tests Reproduce by running `./examples/local_browser_oauth.py` multiple times. - [x] `make test` run locally - [x] `make fmt` applied - [x] relevant integration tests applied
1 parent b41a821 commit 405b147

File tree

3 files changed

+85
-6
lines changed

3 files changed

+85
-6
lines changed

databricks/sdk/core.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from .azure import ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment
2323
from .oauth import (ClientCredentials, OAuthClient, OidcEndpoints, Refreshable,
24-
Token, TokenSource)
24+
Token, TokenCache, TokenSource)
2525
from .version import __version__
2626

2727
__all__ = ['Config', 'DatabricksError']
@@ -140,10 +140,20 @@ def external_browser(cfg: 'Config') -> Optional[HeaderFactory]:
140140
client_id=client_id,
141141
redirect_url='http://localhost:8020',
142142
client_secret=cfg.client_secret)
143-
consent = oauth_client.initiate_consent()
144-
if not consent:
145-
return None
146-
credentials = consent.launch_external_browser()
143+
144+
# Load cached credentials from disk if they exist.
145+
# Note that these are local to the Python SDK and not reused by other SDKs.
146+
token_cache = TokenCache(oauth_client)
147+
credentials = token_cache.load()
148+
if credentials:
149+
# Force a refresh in case the loaded credentials are expired.
150+
credentials.token()
151+
else:
152+
consent = oauth_client.initiate_consent()
153+
if not consent:
154+
return None
155+
credentials = consent.launch_external_browser()
156+
token_cache.save(credentials)
147157
return credentials(cfg)
148158

149159

databricks/sdk/oauth.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import base64
22
import functools
33
import hashlib
4+
import json
45
import logging
6+
import os
57
import secrets
68
import threading
79
import urllib.parse
@@ -10,7 +12,7 @@
1012
from dataclasses import dataclass
1113
from datetime import datetime, timedelta
1214
from http.server import BaseHTTPRequestHandler, HTTPServer
13-
from typing import Any, Dict, List
15+
from typing import Any, Dict, List, Optional
1416

1517
import requests
1618
import requests.auth
@@ -402,3 +404,41 @@ def refresh(self) -> Token:
402404
params,
403405
use_params=self.use_params,
404406
use_header=self.use_header)
407+
408+
409+
class TokenCache():
410+
BASE_PATH = "~/.config/databricks-sdk-py/oauth"
411+
412+
def __init__(self, client: OAuthClient) -> None:
413+
self.client = client
414+
415+
@property
416+
def filename(self) -> str:
417+
# Include host, client_id, and scopes in the cache filename to make it unique.
418+
hash = hashlib.sha256()
419+
for chunk in [self.client.host, self.client.client_id, ",".join(self.client._scopes), ]:
420+
hash.update(chunk.encode('utf-8'))
421+
return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json"))
422+
423+
def load(self) -> Optional[RefreshableCredentials]:
424+
"""
425+
Load credentials from cache file. Return None if the cache file does not exist or is invalid.
426+
"""
427+
if not os.path.exists(self.filename):
428+
return None
429+
430+
try:
431+
with open(self.filename, 'r') as f:
432+
raw = json.load(f)
433+
return RefreshableCredentials.from_dict(self.client, raw)
434+
except Exception:
435+
return None
436+
437+
def save(self, credentials: RefreshableCredentials) -> None:
438+
"""
439+
Save credentials to cache file.
440+
"""
441+
os.makedirs(os.path.dirname(self.filename), exist_ok=True)
442+
with open(self.filename, 'w') as f:
443+
json.dump(credentials.as_dict(), f)
444+
os.chmod(self.filename, 0o600)

tests/test_oauth.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from databricks.sdk.core import Config, OidcEndpoints
2+
from databricks.sdk.oauth import OAuthClient, TokenCache
3+
4+
5+
def test_token_cache_unique_filename_by_host(mocker):
6+
mocker.patch.object(Config, "oidc_endpoints",
7+
OidcEndpoints("http://localhost:1234", "http://localhost:1234"))
8+
common_args = dict(client_id="abc", redirect_url="http://localhost:8020")
9+
c1 = OAuthClient(host="http://localhost:", **common_args)
10+
c2 = OAuthClient(host="https://bar.cloud.databricks.com", **common_args)
11+
assert TokenCache(c1).filename != TokenCache(c2).filename
12+
13+
14+
def test_token_cache_unique_filename_by_client_id(mocker):
15+
mocker.patch.object(Config, "oidc_endpoints",
16+
OidcEndpoints("http://localhost:1234", "http://localhost:1234"))
17+
common_args = dict(host="http://localhost:", redirect_url="http://localhost:8020")
18+
c1 = OAuthClient(client_id="abc", **common_args)
19+
c2 = OAuthClient(client_id="def", **common_args)
20+
assert TokenCache(c1).filename != TokenCache(c2).filename
21+
22+
23+
def test_token_cache_unique_filename_by_scopes(mocker):
24+
mocker.patch.object(Config, "oidc_endpoints",
25+
OidcEndpoints("http://localhost:1234", "http://localhost:1234"))
26+
common_args = dict(host="http://localhost:", client_id="abc", redirect_url="http://localhost:8020")
27+
c1 = OAuthClient(scopes=["foo"], **common_args)
28+
c2 = OAuthClient(scopes=["bar"], **common_args)
29+
assert TokenCache(c1).filename != TokenCache(c2).filename

0 commit comments

Comments
 (0)