Skip to content
4 changes: 2 additions & 2 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
def _has_access_token_been_initialized(self) -> bool:
return self._access_token is not None

def set_token_expiry_date(self, value: Union[str, int]) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
self._token_expiry_date = value

def get_assertion_name(self) -> str:
return self.assertion_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
headers = self.get_refresh_request_headers()
return headers if headers else None

def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
"""
Returns the refresh token and its expiration datetime

Expand All @@ -148,6 +148,14 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
# PRIVATE METHODS
# ----------------

def _default_token_expiry_date(self) -> AirbyteDateTime:
"""
Returns the default token expiry date
"""
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
default_token_expiry_duration_hours = 1 # 1 hour
return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

def _wrap_refresh_token_exception(
self, exception: requests.exceptions.RequestException
) -> bool:
Expand Down Expand Up @@ -257,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) ->

def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
"""
Return the expiration datetime of the refresh token
Parse a string or integer token expiration date into a datetime object

:return: expiration datetime
"""
if not value and not self.token_has_expired():
# No expiry token was provided but the previous one is not expired so it's fine
return self.get_token_expiry_date()

if self.token_expiry_is_time_of_expiration:
if not self.token_expiry_date_format:
raise ValueError(
Expand Down Expand Up @@ -308,17 +312,30 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
"""
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
"""
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.

Args:
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

Returns:
str: The extracted token_expiry_date.
The extracted token_expiry_date or None if not found.
"""
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
expires_in = self._find_and_get_value_from_response(
response_data, self.get_expires_in_name()
)
if expires_in is not None:
return self._parse_token_expiration_date(expires_in)

# expires_in is None
existing_expiry_date = self.get_token_expiry_date()
if existing_expiry_date and not self.token_has_expired():
return existing_expiry_date

return self._default_token_expiry_date()

def _find_and_get_value_from_response(
self,
Expand All @@ -344,7 +361,7 @@ def _find_and_get_value_from_response(
"""
if current_depth > max_depth:
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
raise ResponseKeysMaxRecurtionReached(
internal_message=message, message=message, failure_type=FailureType.config_error
)
Expand Down Expand Up @@ -441,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
"""Expiration date of the access token"""

@abstractmethod
def set_token_expiry_date(self, value: Union[str, int]) -> None:
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
"""Setter for access token expiration date"""

@abstractmethod
Expand Down
31 changes: 4 additions & 27 deletions airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def get_grant_type(self) -> str:
def get_token_expiry_date(self) -> AirbyteDateTime:
return self._token_expiry_date

def set_token_expiry_date(self, value: Union[str, int]) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
self._token_expiry_date = value

@property
def token_expiry_is_time_of_expiration(self) -> bool:
Expand Down Expand Up @@ -316,26 +316,6 @@ def token_has_expired(self) -> bool:
"""Returns True if the token is expired"""
return ab_datetime_now() > self.get_token_expiry_date()

@staticmethod
def get_new_token_expiry_date(
access_token_expires_in: str,
token_expiry_date_format: str | None = None,
) -> AirbyteDateTime:
"""
Calculate the new token expiry date based on the provided expiration duration or format.

Args:
access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.

Returns:
AirbyteDateTime: The calculated expiry date of the access token.
"""
if token_expiry_date_format:
return ab_datetime_parse(access_token_expires_in)
else:
return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))

def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.
The new refresh token is persisted with the set_refresh_token function
Expand All @@ -346,16 +326,13 @@ def get_access_token(self) -> str:
new_access_token, access_token_expires_in, new_refresh_token = (
self.refresh_access_token()
)
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
access_token_expires_in, self._token_expiry_date_format
)
self.access_token = new_access_token
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(new_token_expiry_date)
self.set_token_expiry_date(access_token_expires_in)
self._emit_control_message()
return self.access_token

def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
"""
Refreshes the access token by making a handled request and extracting the necessary token information.

Expand Down
73 changes: 66 additions & 7 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def test_error_on_refresh_token_grant_without_refresh_token(self):
grant_type="refresh_token",
)

@freezegun.freeze_time("2022-01-01")
def test_refresh_access_token(self, mocker):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
Expand All @@ -225,13 +226,15 @@ def test_refresh_access_token(self, mocker):
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
)
mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True)
token = oauth.refresh_access_token()
access_token, token_expiry_date = oauth.refresh_access_token()

assert ("access_token", 1000) == token
assert access_token == "access_token"
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)

filtered = filter_secrets("access_token")
assert filtered == "****"

@freezegun.freeze_time("2022-01-01")
def test_refresh_access_token_when_headers_provided(self, mocker):
expected_headers = {
"Authorization": "Bearer some_access_token",
Expand All @@ -256,9 +259,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker):
mocked_request = mocker.patch.object(
requests, "request", side_effect=mock_request, autospec=True
)
token = oauth.refresh_access_token()
access_token, token_expiry_date = oauth.refresh_access_token()

assert ("access_token", 1000) == token
assert access_token == "access_token"
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)

assert mocked_request.call_args.kwargs["headers"] == expected_headers

Expand Down Expand Up @@ -314,6 +318,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
assert isinstance(oauth._token_expiry_date, AirbyteDateTime)
assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date)

@freezegun.freeze_time("2022-01-01")
def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token(
self,
) -> None:
Expand All @@ -335,12 +340,65 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_
url="https://refresh_endpoint.com/",
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
),
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
HttpResponse(
body=json.dumps({"access_token": "new_access_token", "expires_in": 1000})
),
)
oauth.get_access_token()

assert oauth.access_token == "new_access_token"
assert oauth._token_expiry_date == expiry_date
assert oauth._token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)

@freezegun.freeze_time("2022-01-01")
@pytest.mark.parametrize(
"initial_expiry_date_delta, expected_new_expiry_date_delta, expected_access_token",
[
(timedelta(days=1), timedelta(days=1), "some_access_token"),
(timedelta(days=-1), timedelta(hours=1), "new_access_token"),
(None, timedelta(hours=1), "new_access_token"),
],
ids=[
"initial_expiry_date_in_future",
"initial_expiry_date_in_past",
"no_initial_expiry_date",
],
)
def test_no_expiry_date_provided_by_auth_server(
self,
initial_expiry_date_delta,
expected_new_expiry_date_delta,
expected_access_token,
) -> None:
initial_expiry_date = (
ab_datetime_now().add(initial_expiry_date_delta).isoformat()
if initial_expiry_date_delta
else None
)
expected_new_expiry_date = ab_datetime_now().add(expected_new_expiry_date_delta)
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="https://refresh_endpoint.com/",
client_id="some_client_id",
client_secret="some_client_secret",
token_expiry_date=initial_expiry_date,
access_token_value="some_access_token",
refresh_token="some_refresh_token",
config={},
parameters={},
grant_type="client",
)

with HttpMocker() as http_mocker:
http_mocker.post(
HttpRequest(
url="https://refresh_endpoint.com/",
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
),
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
)
oauth.get_access_token()

assert oauth.access_token == expected_access_token
assert oauth._token_expiry_date == expected_new_expiry_date

@pytest.mark.parametrize(
"expires_in_response, token_expiry_date_format",
Expand Down Expand Up @@ -443,6 +501,7 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next
assert "access_token" == token
assert oauth.get_token_expiry_date() == ab_datetime_parse(next_day)

@freezegun.freeze_time("2022-01-01")
def test_profile_assertion(self, mocker):
with HttpMocker() as http_mocker:
jwt = JwtAuthenticator(
Expand Down Expand Up @@ -477,7 +536,7 @@ def test_profile_assertion(self, mocker):

token = oauth.refresh_access_token()

assert ("access_token", 1000) == token
assert ("access_token", ab_datetime_now().add(timedelta(seconds=1000))) == token

filtered = filter_secrets("access_token")
assert filtered == "****"
Expand Down
Loading
Loading