Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ Works for both AWS and Azure. Not supported for GCP at the moment.

```python
from databricks.sdk.oauth import OAuthClient

oauth_client = OAuthClient(host='<workspace-url>',
client_id='<oauth client ID>',
redirect_url=f'http://host.domain/callback',
Expand All @@ -380,29 +381,31 @@ APP_NAME = 'flask-demo'
app = Flask(APP_NAME)
app.secret_key = secrets.token_urlsafe(32)


@app.route('/callback')
def callback():
from databricks.sdk.oauth import Consent
consent = Consent.from_dict(oauth_client, session['consent'])
session['creds'] = consent.exchange_callback_parameters(request.args).as_dict()
return redirect(url_for('index'))
from databricks.sdk.oauth import Consent
consent = Consent.from_dict(oauth_client, session['consent'])
session['creds'] = consent.exchange_callback_parameters(request.args).as_dict()
return redirect(url_for('index'))


@app.route('/')
def index():
if 'creds' not in session:
consent = oauth_client.initiate_consent()
session['consent'] = consent.as_dict()
return redirect(consent.auth_url)
if 'creds' not in session:
consent = oauth_client.initiate_consent()
session['consent'] = consent.as_dict()
return redirect(consent.auth_url)

from databricks.sdk import WorkspaceClient
from databricks.sdk.oauth import RefreshableCredentials
from databricks.sdk import WorkspaceClient
from databricks.sdk.oauth import SessionCredentials

credentials_provider = RefreshableCredentials.from_dict(oauth_client, session['creds'])
workspace_client = WorkspaceClient(host=oauth_client.host,
product=APP_NAME,
credentials_provider=credentials_provider)
credentials_provider = SessionCredentials.from_dict(oauth_client, session['creds'])
workspace_client = WorkspaceClient(host=oauth_client.host,
product=APP_NAME,
credentials_provider=credentials_provider)

return render_template_string('...', w=workspace_client)
return render_template_string('...', w=workspace_client)
```

### SSO for local scripts on development machines
Expand Down
20 changes: 10 additions & 10 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def do_GET(self):
self.wfile.write(b'You can close this tab.')


class RefreshableCredentials(Refreshable):
class SessionCredentials(Refreshable):

def __init__(self, client: 'OAuthClient', token: Token):
self._client = client
Expand All @@ -187,8 +187,8 @@ def as_dict(self) -> dict:
return {'token': self._token.as_dict()}

@staticmethod
def from_dict(client: 'OAuthClient', raw: dict) -> 'RefreshableCredentials':
return RefreshableCredentials(client=client, token=Token.from_dict(raw['token']))
def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials':
return SessionCredentials(client=client, token=Token.from_dict(raw['token']))

def auth_type(self):
"""Implementing CredentialsProvider protocol"""
Expand Down Expand Up @@ -237,7 +237,7 @@ def as_dict(self) -> dict:
def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent':
return Consent(client, raw['state'], raw['verifier'])

def launch_external_browser(self) -> RefreshableCredentials:
def launch_external_browser(self) -> SessionCredentials:
redirect_url = urllib.parse.urlparse(self._client.redirect_url)
if redirect_url.hostname not in ('localhost', '127.0.0.1'):
raise ValueError(f'cannot listen on {redirect_url.hostname}')
Expand All @@ -254,14 +254,14 @@ def launch_external_browser(self) -> RefreshableCredentials:
query = feedback.pop()
return self.exchange_callback_parameters(query)

def exchange_callback_parameters(self, query: Dict[str, str]) -> RefreshableCredentials:
def exchange_callback_parameters(self, query: Dict[str, str]) -> SessionCredentials:
if 'error' in query:
raise ValueError('{error}: {error_description}'.format(**query))
if 'code' not in query or 'state' not in query:
raise ValueError('No code returned in callback')
return self.exchange(query['code'], query['state'])

def exchange(self, code: str, state: str) -> RefreshableCredentials:
def exchange(self, code: str, state: str) -> SessionCredentials:
if self._state != state:
raise ValueError('state mismatch')
params = {
Expand All @@ -279,7 +279,7 @@ def exchange(self, code: str, state: str) -> RefreshableCredentials:
params=params,
headers=headers,
use_params=True)
return RefreshableCredentials(self._client, token)
return SessionCredentials(self._client, token)
except ValueError as e:
if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
# Retry in cases of 'Single-Page Application' client-type with
Expand Down Expand Up @@ -420,7 +420,7 @@ def filename(self) -> str:
hash.update(chunk.encode('utf-8'))
return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json"))

def load(self) -> Optional[RefreshableCredentials]:
def load(self) -> Optional[SessionCredentials]:
"""
Load credentials from cache file. Return None if the cache file does not exist or is invalid.
"""
Expand All @@ -430,11 +430,11 @@ def load(self) -> Optional[RefreshableCredentials]:
try:
with open(self.filename, 'r') as f:
raw = json.load(f)
return RefreshableCredentials.from_dict(self.client, raw)
return SessionCredentials.from_dict(self.client, raw)
except Exception:
return None

def save(self, credentials: RefreshableCredentials) -> None:
def save(self, credentials: SessionCredentials) -> None:
"""
Save credentials to cache file.
"""
Expand Down
4 changes: 2 additions & 2 deletions examples/flask_app_with_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def index():
return redirect(consent.auth_url)

from databricks.sdk import WorkspaceClient
from databricks.sdk.oauth import RefreshableCredentials
from databricks.sdk.oauth import SessionCredentials

credentials_provider = RefreshableCredentials.from_dict(oauth_client, session["creds"])
credentials_provider = SessionCredentials.from_dict(oauth_client, session["creds"])
workspace_client = WorkspaceClient(host=oauth_client.host,
product=APP_NAME,
credentials_provider=credentials_provider,
Expand Down