Skip to content

Commit eea840a

Browse files
paulbauriegelfrascuchonpre-commit-ci[bot]
authored
Add keycloak SSO (#5711)
# Add keycloak SSO Based on discussion in #5691 Points that need some feedback: - A lot of configurations are set via env variables now. Not sure if that's ideal, error messages if something is not set correctly can be rather cryptic with social auth lib - I added the Keycloak logo to the Oauth button id the provider is keycloak, generally the same could also be done for the HF logo not having a separate button - Is the documentation to set-up a keycloak server sufficient? **Type of change** - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** Local build & Keycloak installation as described in the documentation. **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - TODO I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Paco Aranda <[email protected]> Co-authored-by: Francisco Aranda <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paco Aranda <[email protected]>
1 parent 9d69501 commit eea840a

File tree

11 files changed

+637
-9
lines changed

11 files changed

+637
-9
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
<template>
2+
<!--https://github.com/keycloak/keycloak-misc/blob/main/logo/icon.svg-->
3+
<svg
4+
width="256"
5+
height="256"
6+
viewBox="0 0 44.216 39.861"
7+
fill="none"
8+
xmlns="http://www.w3.org/2000/svg"
9+
>
10+
<path
11+
d="m88.61 138.456 5.716-9.865 23.018-.004 5.686 9.965.007 19.932-5.691 9.957-23.012.008-5.782-9.965z"
12+
style="
13+
display: inline;
14+
fill: #4d4d4d;
15+
fill-opacity: 1;
16+
stroke-width: 0.264583;
17+
"
18+
transform="translate(-82.815 -128.588)"
19+
/>
20+
<path
21+
d="M88.552 158.481h10.375l-5.699-10.041 4.634-9.982-9.252-.002-5.795 10.065"
22+
style="
23+
fill: #ededed;
24+
fill-opacity: 1;
25+
fill-rule: nonzero;
26+
stroke: none;
27+
stroke-width: 0.330729;
28+
"
29+
transform="translate(-82.815 -128.588)"
30+
/>
31+
<path
32+
d="M102.073 158.481h7.582l6.706-9.773-6.589-10.156h-8.921l-5.373 9.814z"
33+
style="
34+
fill: #e0e0e0;
35+
fill-opacity: 1;
36+
fill-rule: nonzero;
37+
stroke: none;
38+
stroke-width: 0.330729;
39+
"
40+
transform="translate(-82.815 -128.588)"
41+
/>
42+
<path
43+
d="m82.815 148.52 5.738 9.964h10.374l-5.636-9.93z"
44+
style="
45+
fill: #acacac;
46+
fill-opacity: 1;
47+
fill-rule: nonzero;
48+
stroke: none;
49+
stroke-width: 0.330729;
50+
"
51+
transform="translate(-82.815 -128.588)"
52+
/>
53+
<path
54+
d="m95.589 148.522 6.484 9.963h7.582l6.601-9.959z"
55+
style="
56+
fill: #9e9e9e;
57+
fill-opacity: 1;
58+
fill-rule: nonzero;
59+
stroke: none;
60+
stroke-width: 0.330729;
61+
"
62+
transform="translate(-82.815 -128.588)"
63+
/>
64+
<path
65+
d="m98.157 148.529-1.958.569-1.877-.572 7.667-13.288 1.918 3.316"
66+
style="
67+
fill: #00b8e3;
68+
fill-opacity: 1;
69+
fill-rule: nonzero;
70+
stroke: none;
71+
stroke-width: 0.330729;
72+
"
73+
transform="translate(-82.815 -128.588)"
74+
/>
75+
<path
76+
d="m103.9 158.482-1.909 3.332-5.093-5.487-2.58-7.797v-.004h3.838"
77+
style="
78+
fill: #33c6e9;
79+
fill-opacity: 1;
80+
fill-rule: nonzero;
81+
stroke: none;
82+
stroke-width: 0.330729;
83+
"
84+
transform="translate(-82.815 -128.588)"
85+
/>
86+
<path
87+
d="M94.322 148.526h-.003v.003l-1.918 3.322-1.925-3.307 1.952-3.386 5.728-9.92h3.834"
88+
style="
89+
fill: #008aaa;
90+
fill-opacity: 1;
91+
fill-rule: nonzero;
92+
stroke: none;
93+
stroke-width: 0.330729;
94+
"
95+
transform="translate(-82.815 -128.588)"
96+
/>
97+
<path
98+
d="M115.42 158.481h11.611l-.007-19.93h-11.605z"
99+
style="
100+
fill: #d4d4d4;
101+
fill-opacity: 1;
102+
fill-rule: nonzero;
103+
stroke: none;
104+
stroke-width: 0.330729;
105+
"
106+
transform="translate(-82.815 -128.588)"
107+
/>
108+
<path
109+
d="M115.42 148.554v9.93h11.59v-9.93z"
110+
style="
111+
fill: #919191;
112+
fill-opacity: 1;
113+
fill-rule: nonzero;
114+
stroke: none;
115+
stroke-width: 0.330729;
116+
"
117+
transform="translate(-82.815 -128.588)"
118+
/>
119+
<path
120+
d="M101.992 161.817h-3.836l-5.755-9.966 1.918-3.321z"
121+
style="
122+
fill: #00b8e3;
123+
fill-opacity: 1;
124+
fill-rule: nonzero;
125+
stroke: none;
126+
stroke-width: 0.330729;
127+
"
128+
transform="translate(-82.815 -128.588)"
129+
/>
130+
<path
131+
d="m117.333 148.526-7.669 13.289c-.705-1.036-1.913-3.331-1.913-3.331l5.753-9.959z"
132+
style="
133+
fill: #008aaa;
134+
fill-opacity: 1;
135+
fill-rule: nonzero;
136+
stroke: none;
137+
stroke-width: 0.330729;
138+
"
139+
transform="translate(-82.815 -128.588)"
140+
/>
141+
<path
142+
d="m113.495 161.815-3.831-.001 7.67-13.288 1.917-3.317 1.921 3.34m-3.839-.023h-3.828l-5.755-9.973 1.905-3.314 4.658 5.922z"
143+
style="
144+
fill: #00b8e3;
145+
fill-opacity: 1;
146+
fill-rule: nonzero;
147+
stroke: none;
148+
stroke-width: 0.330729;
149+
"
150+
transform="translate(-82.815 -128.588)"
151+
/>
152+
<path
153+
d="M119.25 145.205v.003l-1.917 3.318-7.677-13.286 3.841.002z"
154+
style="
155+
fill: #33c6e9;
156+
fill-opacity: 1;
157+
fill-rule: nonzero;
158+
stroke: none;
159+
stroke-width: 0.330729;
160+
"
161+
transform="translate(-82.815 -128.588)"
162+
/>
163+
</svg>
164+
</template>

argilla-frontend/components/features/login/components/OAuthLoginButton.vue

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
<template>
22
<BaseButton class="sign-in-button" @click="$emit('click')">
3+
<KeycloakLogo v-if="provider === 'keycloak'" />
34
{{ signinText }}
45
</BaseButton>
56
</template>
7+
68
<script>
9+
import KeycloakLogo from "./KeycloakLogo.vue";
10+
711
export default {
812
name: "OAuthLoginButton",
13+
components: {
14+
KeycloakLogo,
15+
},
916
props: {
1017
provider: {
1118
type: String,

argilla-frontend/translation/de.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ export default {
120120
button: {
121121
ignore_and_continue: "Ignorieren und fortfahren",
122122
login: "Anmelden",
123-
signin_with_provider: "Anmeldung bei {provider} starten",
123+
signin_with_provider: "Mit {provider} anmelden",
124124
"hf-login": "Mit Hugging Face anmelden",
125125
sign_in_with_username: "Mit Benutzername anmelden",
126126
cancel: "Abbrechen",

argilla-server/src/argilla_server/api/handlers/v1/oauth2.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from argilla_server.contexts import accounts
2121
from argilla_server.database import get_async_db
2222
from argilla_server.errors.future import NotFoundError
23-
from argilla_server.models import User
23+
from argilla_server.models import Workspace, WorkspaceUser
2424
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
2525
from argilla_server.security.authentication.userinfo import UserInfo
2626
from argilla_server.security.settings import settings
@@ -61,14 +61,51 @@ async def get_access_token(
6161
if not userinfo.username:
6262
raise RuntimeError("OAuth error: Missing username")
6363

64-
user = await User.get_by(db, username=userinfo.username)
65-
if user is None:
66-
user = await accounts.create_user_with_random_password(
64+
default_available_workspaces = [workspace.name for workspace in settings.oauth.allowed_workspaces]
65+
available_workspaces = userinfo.available_workspaces or default_available_workspaces
66+
67+
oauth_user = await accounts.get_user_by_username(db, username=userinfo.username)
68+
69+
if oauth_user is None:
70+
for workspace_name in available_workspaces:
71+
if await Workspace.get_by(db, name=workspace_name) is None:
72+
await Workspace.create(db, name=workspace_name, autocommit=False)
73+
74+
oauth_user = await accounts.create_user_with_random_password(
6775
db,
6876
username=userinfo.username,
6977
first_name=userinfo.first_name,
7078
role=userinfo.role,
71-
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
79+
workspaces=available_workspaces,
7280
)
81+
elif provider.sync_user:
82+
oauth_role = oauth_user.role
83+
oauth_workspaces = oauth_user.workspaces or []
84+
85+
# Sync user role
86+
if oauth_role != userinfo.role:
87+
await accounts.update_user(db, user=oauth_user, user_attrs={"role": userinfo.role})
88+
# Sync removed workspaces
89+
for workspace in oauth_workspaces:
90+
if workspace.name not in available_workspaces:
91+
ws_user = await WorkspaceUser.get_by(db, workspace_id=workspace.id, user_id=oauth_user.id)
92+
await ws_user.delete(db, autocommit=False)
93+
# Sync added workspaces
94+
for workspace_name in available_workspaces:
95+
if workspace_name in [ws.name for ws in oauth_workspaces]:
96+
continue
97+
98+
workspace = await Workspace.get_by(db, name=workspace_name)
99+
if not workspace:
100+
workspace = await Workspace.create(db, name=workspace_name, autocommit=False)
101+
102+
if not await WorkspaceUser.get_by(db, workspace_id=workspace.id, user_id=oauth_user.id):
103+
await WorkspaceUser.create(
104+
db,
105+
workspace_id=workspace.id,
106+
user_id=oauth_user.id,
107+
autocommit=False,
108+
)
109+
await db.commit()
73110

74-
return Token(access_token=accounts.generate_user_token(user))
111+
return Token(access_token=accounts.generate_user_token(oauth_user))

argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# limitations under the License.
1414

1515
import os
16-
from typing import Type, Dict, Any
16+
from typing import Type, Dict, Any, Optional, List
1717

1818
from social_core.backends.oauth import BaseOAuth2
1919
from social_core.backends.open_id_connect import OpenIdConnectAuth
2020
from social_core.backends.utils import load_backends
2121
from social_core.strategy import BaseStrategy
2222

2323
from argilla_server.errors.future import NotFoundError
24+
from argilla_server.models import UserRole
2425

2526

2627
class Strategy(BaseStrategy):
@@ -48,6 +49,61 @@ class HuggingfaceOpenId(OpenIdConnectAuth):
4849
DEFAULT_SCOPE = ["openid", "profile"]
4950

5051

52+
class KeycloakOpenId(OpenIdConnectAuth):
53+
"""Huggingface OpenID Connect authentication backend."""
54+
55+
name = "keycloak"
56+
57+
def oidc_endpoint(self) -> str:
58+
value = super().oidc_endpoint()
59+
60+
if value is None:
61+
from social_core.utils import setting_name
62+
63+
name = setting_name("OIDC_ENDPOINT")
64+
raise ValueError(
65+
"oidc_endpoint needs to be set in the Keycloak configuration. "
66+
f"Please set the {name} environment variable."
67+
)
68+
69+
return value
70+
71+
def get_user_details(self, response: Dict[str, Any]) -> Dict[str, Any]:
72+
user = super().get_user_details(response)
73+
74+
if role := self._extract_role(response):
75+
user["role"] = role
76+
77+
if available_workspaces := self._extract_available_workspaces(response):
78+
user["available_workspaces"] = available_workspaces
79+
80+
return user
81+
82+
def _extract_role(self, response: Dict[str, Any]) -> Optional[str]:
83+
roles = self._read_realm_roles(response)
84+
role_to_value = {UserRole.owner: 3, UserRole.admin: 2, UserRole.annotator: 1}
85+
role_list = [role.split(":")[1] for role in roles if role.startswith("argilla_role:")]
86+
if role_list:
87+
max_role = max(role_list, key=lambda s: role_to_value.get(s, 0))
88+
return max_role
89+
90+
def _extract_available_workspaces(self, response: Dict[str, Any]) -> List[str]:
91+
roles = self._read_realm_roles(response)
92+
93+
workspaces = []
94+
for role in roles:
95+
if role.startswith("argilla_workspace:"):
96+
workspace = role.split(":")[1]
97+
workspaces.append(workspace)
98+
99+
return workspaces
100+
101+
@classmethod
102+
def _read_realm_roles(cls, response) -> List[str]:
103+
realm_access = response.get("realm_access") or {}
104+
return realm_access.get("roles") or []
105+
106+
51107
_SUPPORTED_BACKENDS = {}
52108

53109

@@ -56,6 +112,7 @@ def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseO
56112

57113
backends = [
58114
"argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId",
115+
"argilla_server.security.authentication.oauth2._backends.KeycloakOpenId",
59116
"social_core.backends.github.GithubOAuth2",
60117
"social_core.backends.google.GoogleOAuth2",
61118
]

argilla-server/src/argilla_server/security/authentication/oauth2/provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
client_secret: str = None,
4949
scope: Optional[List[str]] = None,
5050
redirect_uri: str = None,
51+
sync_user: bool = False,
5152
) -> None:
5253
self.name = backend_class.name
5354
self._backend = backend_class(strategy=self.backend_strategy)
@@ -74,6 +75,7 @@ def __init__(
7475
self.scope = self.scope.split(" ")
7576

7677
self.redirect_uri = redirect_uri or f"/oauth/{self.name}/callback"
78+
self.sync_user = sync_user
7779

7880
@classmethod
7981
def from_dict(cls, provider: dict, backend_class: Type[BaseOAuth2]) -> "OAuth2ClientProvider":

argilla-server/src/argilla_server/security/authentication/userinfo.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,19 @@ def username(self) -> str:
3636
def first_name(self) -> str:
3737
return self.get("first_name") or self.username
3838

39+
@property
40+
def last_name(self) -> Optional[str]:
41+
return self.get("last_name") or None
42+
3943
@property
4044
def role(self) -> UserRole:
4145
role = self.get("role") or self._parse_role_from_environment()
4246
return UserRole(role)
4347

48+
@property
49+
def available_workspaces(self) -> Optional[list]:
50+
return self.get("available_workspaces")
51+
4452
def _parse_role_from_environment(self) -> Optional[UserRole]:
4553
"""This is a temporal solution, and it will be replaced by a proper Sign up process"""
4654
if self["username"] == os.getenv("USERNAME"):

0 commit comments

Comments
 (0)