diff --git a/README.md b/README.md index 2050a704c..dc58c89c2 100644 --- a/README.md +++ b/README.md @@ -368,6 +368,7 @@ Works for both AWS and Azure. Not supported for GCP at the moment. ```python from databricks.sdk.oauth import OAuthClient + oauth_client = OAuthClient(host='', client_id='', redirect_url=f'http://host.domain/callback', @@ -380,29 +381,31 @@ APP_NAME = 'flask-demo' app = Flask(APP_NAME) app.secret_key = secrets.token_urlsafe(32) + @app.route('/callback') def callback(): - from databricks.sdk.oauth import Consent - consent = Consent.from_dict(oauth_client, session['consent']) - session['creds'] = consent.exchange_callback_parameters(request.args).as_dict() - return redirect(url_for('index')) + from databricks.sdk.oauth import Consent + consent = Consent.from_dict(oauth_client, session['consent']) + session['creds'] = consent.exchange_callback_parameters(request.args).as_dict() + return redirect(url_for('index')) + @app.route('/') def index(): - if 'creds' not in session: - consent = oauth_client.initiate_consent() - session['consent'] = consent.as_dict() - return redirect(consent.auth_url) + if 'creds' not in session: + consent = oauth_client.initiate_consent() + session['consent'] = consent.as_dict() + return redirect(consent.auth_url) - from databricks.sdk import WorkspaceClient - from databricks.sdk.oauth import RefreshableCredentials + from databricks.sdk import WorkspaceClient + from databricks.sdk.oauth import SessionCredentials - credentials_provider = RefreshableCredentials.from_dict(oauth_client, session['creds']) - workspace_client = WorkspaceClient(host=oauth_client.host, - product=APP_NAME, - credentials_provider=credentials_provider) + credentials_provider = SessionCredentials.from_dict(oauth_client, session['creds']) + workspace_client = WorkspaceClient(host=oauth_client.host, + product=APP_NAME, + credentials_provider=credentials_provider) - return render_template_string('...', w=workspace_client) + return render_template_string('...', w=workspace_client) ``` ### SSO for local scripts on development machines diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index ae75b484d..70af9545b 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -177,7 +177,7 @@ def do_GET(self): self.wfile.write(b'You can close this tab.') -class RefreshableCredentials(Refreshable): +class SessionCredentials(Refreshable): def __init__(self, client: 'OAuthClient', token: Token): self._client = client @@ -187,8 +187,8 @@ def as_dict(self) -> dict: return {'token': self._token.as_dict()} @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'RefreshableCredentials': - return RefreshableCredentials(client=client, token=Token.from_dict(raw['token'])) + def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials': + return SessionCredentials(client=client, token=Token.from_dict(raw['token'])) def auth_type(self): """Implementing CredentialsProvider protocol""" @@ -237,7 +237,7 @@ def as_dict(self) -> dict: def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent': return Consent(client, raw['state'], raw['verifier']) - def launch_external_browser(self) -> RefreshableCredentials: + def launch_external_browser(self) -> SessionCredentials: redirect_url = urllib.parse.urlparse(self._client.redirect_url) if redirect_url.hostname not in ('localhost', '127.0.0.1'): raise ValueError(f'cannot listen on {redirect_url.hostname}') @@ -254,14 +254,14 @@ def launch_external_browser(self) -> RefreshableCredentials: query = feedback.pop() return self.exchange_callback_parameters(query) - def exchange_callback_parameters(self, query: Dict[str, str]) -> RefreshableCredentials: + def exchange_callback_parameters(self, query: Dict[str, str]) -> SessionCredentials: if 'error' in query: raise ValueError('{error}: {error_description}'.format(**query)) if 'code' not in query or 'state' not in query: raise ValueError('No code returned in callback') return self.exchange(query['code'], query['state']) - def exchange(self, code: str, state: str) -> RefreshableCredentials: + def exchange(self, code: str, state: str) -> SessionCredentials: if self._state != state: raise ValueError('state mismatch') params = { @@ -279,7 +279,7 @@ def exchange(self, code: str, state: str) -> RefreshableCredentials: params=params, headers=headers, use_params=True) - return RefreshableCredentials(self._client, token) + return SessionCredentials(self._client, token) except ValueError as e: if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e): # Retry in cases of 'Single-Page Application' client-type with @@ -420,7 +420,7 @@ def filename(self) -> str: hash.update(chunk.encode('utf-8')) return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json")) - def load(self) -> Optional[RefreshableCredentials]: + def load(self) -> Optional[SessionCredentials]: """ Load credentials from cache file. Return None if the cache file does not exist or is invalid. """ @@ -430,11 +430,11 @@ def load(self) -> Optional[RefreshableCredentials]: try: with open(self.filename, 'r') as f: raw = json.load(f) - return RefreshableCredentials.from_dict(self.client, raw) + return SessionCredentials.from_dict(self.client, raw) except Exception: return None - def save(self, credentials: RefreshableCredentials) -> None: + def save(self, credentials: SessionCredentials) -> None: """ Save credentials to cache file. """ diff --git a/examples/flask_app_with_oauth.py b/examples/flask_app_with_oauth.py index fd3847707..b03d2a406 100755 --- a/examples/flask_app_with_oauth.py +++ b/examples/flask_app_with_oauth.py @@ -78,9 +78,9 @@ def index(): return redirect(consent.auth_url) from databricks.sdk import WorkspaceClient - from databricks.sdk.oauth import RefreshableCredentials + from databricks.sdk.oauth import SessionCredentials - credentials_provider = RefreshableCredentials.from_dict(oauth_client, session["creds"]) + credentials_provider = SessionCredentials.from_dict(oauth_client, session["creds"]) workspace_client = WorkspaceClient(host=oauth_client.host, product=APP_NAME, credentials_provider=credentials_provider,