Skip to content

Commit 04b38a9

Browse files
committed
SNOW-2208073 Implement locking of prompting authenticators
1 parent 31b8ccf commit 04b38a9

21 files changed

+1549
-57
lines changed

assert_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"slices"
1212
"strings"
1313
"testing"
14+
"time"
1415
)
1516

1617
func assertNilE(t *testing.T, actual any, descriptions ...string) {
@@ -134,7 +135,7 @@ func errorOnNonEmpty(t *testing.T, errMsg string) {
134135
}
135136

136137
func formatErrorMessage(errMsg string) string {
137-
return fmt.Sprintf("%s. Thrown from %s", maskSecrets(errMsg), thrownFrom())
138+
return fmt.Sprintf("[%s] %s. Thrown from %s", time.Now().Format(time.RFC3339Nano), maskSecrets(errMsg), thrownFrom())
138139
}
139140

140141
func validateNil(actual any, descriptions ...string) string {

auth.go

Lines changed: 154 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
537537
}
538538
case AuthTypeOAuthAuthorizationCode:
539539
logger.WithContext(sc.ctx).Debug("OAuth authorization code")
540-
oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc)
541-
if err != nil {
542-
return nil, err
543-
}
544-
token, err := oauthClient.authenticateByOAuthAuthorizationCode()
540+
token, err := authenticateByAuthorizationCode(sc)
545541
if err != nil {
546542
return nil, err
547543
}
@@ -584,6 +580,62 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
584580
return jsonBody, nil
585581
}
586582

583+
type oauthLockKey struct {
584+
tokenRequestUrl string
585+
user string
586+
flowType string
587+
}
588+
589+
func newOAuthAuthorizationCodeLockKey(tokenRequestUrl, user string) *oauthLockKey {
590+
return &oauthLockKey{
591+
tokenRequestUrl: tokenRequestUrl,
592+
user: user,
593+
flowType: "authorization_code",
594+
}
595+
}
596+
597+
func newRefreshTokenLockKey(tokenRequestUrl, user string) *oauthLockKey {
598+
return &oauthLockKey{
599+
tokenRequestUrl: tokenRequestUrl,
600+
user: user,
601+
flowType: "refresh_token",
602+
}
603+
}
604+
605+
func (o *oauthLockKey) lockId() string {
606+
return o.tokenRequestUrl + "|" + o.user + "|" + o.flowType
607+
}
608+
609+
func authenticateByAuthorizationCode(sc *snowflakeConn) (string, error) {
610+
oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc)
611+
if err != nil {
612+
return "", err
613+
}
614+
if !isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
615+
return oauthClient.authenticateByOAuthAuthorizationCode()
616+
}
617+
618+
lockKey := newOAuthAuthorizationCodeLockKey(oauthClient.tokenURL(), sc.cfg.User)
619+
valueAwaiter := valueAwaitHolder.get(lockKey)
620+
defer valueAwaiter.resumeOne()
621+
token, err := awaitValue(valueAwaiter, func() (string, error) {
622+
return credentialsStorage.getCredential(newOAuthAccessTokenSpec(oauthClient.tokenURL(), sc.cfg.User)), nil
623+
}, func(s string, err error) bool {
624+
return s != ""
625+
}, func() string {
626+
return ""
627+
})
628+
if err != nil || token != "" {
629+
return token, err
630+
}
631+
token, err = oauthClient.authenticateByOAuthAuthorizationCode()
632+
if err != nil {
633+
return "", err
634+
}
635+
valueAwaiter.done()
636+
return token, err
637+
}
638+
587639
// Generate a JWT token in string given the configuration
588640
func prepareJWTToken(config *Config) (string, error) {
589641
if config.PrivateKey == nil {
@@ -619,20 +671,60 @@ func prepareJWTToken(config *Config) (string, error) {
619671
return tokenString, err
620672
}
621673

622-
// Authenticate with sc.cfg
674+
type tokenLockKey struct {
675+
snowflakeHost string
676+
user string
677+
tokenType string
678+
}
679+
680+
func newMfaTokenLockKey(snowflakeHost, user string) *tokenLockKey {
681+
return &tokenLockKey{
682+
snowflakeHost: snowflakeHost,
683+
user: user,
684+
tokenType: "MFA",
685+
}
686+
}
687+
688+
func newIDTokenLockKey(snowflakeHost, user string) *tokenLockKey {
689+
return &tokenLockKey{
690+
snowflakeHost: snowflakeHost,
691+
user: user,
692+
tokenType: "ID",
693+
}
694+
}
695+
696+
func (m *tokenLockKey) lockId() string {
697+
return m.snowflakeHost + "|" + m.user + "|" + m.tokenType
698+
}
699+
623700
func authenticateWithConfig(sc *snowflakeConn) error {
624701
var authData *authResponseMain
625702
var samlResponse []byte
626703
var proofKey []byte
627704
var err error
628-
//var consentCacheIdToken = true
705+
706+
mfaTokenLockKey := newMfaTokenLockKey(sc.cfg.Host, sc.cfg.User)
707+
idTokenLockKey := newIDTokenLockKey(sc.cfg.Host, sc.cfg.User)
629708

630709
if sc.cfg.Authenticator == AuthTypeExternalBrowser || sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode || sc.cfg.Authenticator == AuthTypeOAuthClientCredentials {
631710
if (runtime.GOOS == "windows" || runtime.GOOS == "darwin") && sc.cfg.ClientStoreTemporaryCredential == configBoolNotSet {
632711
sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
633712
}
634-
if sc.cfg.Authenticator == AuthTypeExternalBrowser && sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
635-
sc.cfg.IDToken = credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
713+
if sc.cfg.Authenticator == AuthTypeExternalBrowser {
714+
if isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
715+
valueAwaiter := valueAwaitHolder.get(idTokenLockKey)
716+
defer valueAwaiter.resumeOne()
717+
sc.cfg.IDToken, _ = awaitValue(valueAwaiter, func() (string, error) {
718+
credential := credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
719+
return credential, nil
720+
}, func(s string, err error) bool {
721+
return s != ""
722+
}, func() string {
723+
return ""
724+
})
725+
} else {
726+
sc.cfg.IDToken = credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
727+
}
636728
}
637729
// Disable console login by default
638730
if sc.cfg.DisableConsoleLogin == configBoolNotSet {
@@ -644,7 +736,18 @@ func authenticateWithConfig(sc *snowflakeConn) error {
644736
if (runtime.GOOS == "windows" || runtime.GOOS == "darwin") && sc.cfg.ClientRequestMfaToken == configBoolNotSet {
645737
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
646738
}
647-
if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
739+
if isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientRequestMfaToken) {
740+
valueAwaiter := valueAwaitHolder.get(mfaTokenLockKey)
741+
defer valueAwaiter.resumeOne()
742+
sc.cfg.MfaToken, _ = awaitValue(valueAwaiter, func() (string, error) {
743+
credential := credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
744+
return credential, nil
745+
}, func(s string, err error) bool {
746+
return s != ""
747+
}, func() string {
748+
return ""
749+
})
750+
} else {
648751
sc.cfg.MfaToken = credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
649752
}
650753
}
@@ -660,7 +763,6 @@ func authenticateWithConfig(sc *snowflakeConn) error {
660763
sc.cfg.Application,
661764
sc.cfg.Account,
662765
sc.cfg.User,
663-
sc.cfg.Password,
664766
sc.cfg.ExternalBrowserTimeout,
665767
sc.cfg.DisableConsoleLogin)
666768
if err != nil {
@@ -680,15 +782,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
680782
credentialsStorage.deleteCredential(newOAuthAccessTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
681783

682784
if sc.cfg.Authenticator == AuthTypeOAuthAuthorizationCode {
683-
var oauthClient *oauthClient
684-
if oauthClient, err = newOauthClient(sc.ctx, sc.cfg, sc); err != nil {
685-
logger.Warnf("failed to create oauth client. %v", err)
686-
} else {
687-
if err = oauthClient.refreshToken(); err != nil {
688-
logger.Warnf("cannot refresh token. %v", err)
689-
credentialsStorage.deleteCredential(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
690-
}
691-
}
785+
doRefreshTokenWithLock(sc)
692786
}
693787

694788
// if refreshing succeeds for authorization code, we will take a token from cache
@@ -700,7 +794,47 @@ func authenticateWithConfig(sc *snowflakeConn) error {
700794
return err
701795
}
702796
}
797+
if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA && isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientRequestMfaToken) {
798+
valueAwaiter := valueAwaitHolder.get(mfaTokenLockKey)
799+
valueAwaiter.done()
800+
}
801+
if sc.cfg.Authenticator == AuthTypeExternalBrowser && isEligibleForParallelLogin(sc.cfg, sc.cfg.ClientStoreTemporaryCredential) {
802+
valueAwaiter := valueAwaitHolder.get(idTokenLockKey)
803+
valueAwaiter.done()
804+
}
703805
sc.populateSessionParameters(authData.Parameters)
704806
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
705807
return nil
706808
}
809+
810+
func doRefreshTokenWithLock(sc *snowflakeConn) {
811+
if oauthClient, err := newOauthClient(sc.ctx, sc.cfg, sc); err != nil {
812+
logger.Warnf("failed to create oauth client. %v", err)
813+
} else {
814+
lockKey := newRefreshTokenLockKey(oauthClient.tokenURL(), sc.cfg.User)
815+
if _, err = getValueWithLock(chooseLockerForAuth(sc.cfg), lockKey, func() (string, error) {
816+
if err = oauthClient.refreshToken(); err != nil {
817+
logger.Warnf("cannot refresh token. %v", err)
818+
credentialsStorage.deleteCredential(newOAuthRefreshTokenSpec(sc.cfg.OauthTokenRequestURL, sc.cfg.User))
819+
return "", err
820+
}
821+
return "", nil
822+
}); err != nil {
823+
logger.Warnf("failed to refresh token with lock. %v", err)
824+
}
825+
}
826+
}
827+
828+
func chooseLockerForAuth(cfg *Config) locker {
829+
if cfg.SingleAuthenticationPrompt == ConfigBoolFalse {
830+
return noopLocker
831+
}
832+
if cfg.User == "" {
833+
return noopLocker
834+
}
835+
return exclusiveLocker
836+
}
837+
838+
func isEligibleForParallelLogin(cfg *Config, cacheEnabled ConfigBool) bool {
839+
return cfg.SingleAuthenticationPrompt != ConfigBoolFalse && cfg.User != "" && cacheEnabled == ConfigBoolTrue
840+
}

auth_oauth_test.go

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestUnitOAuthAuthorizationCode(t *testing.T) {
5757
credentialsStorage.deleteCredential(accessTokenSpec)
5858
credentialsStorage.deleteCredential(refreshTokenSpec)
5959
wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json"))
60-
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{}
60+
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
6161
client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
6262
return authCodeProvider
6363
}
@@ -71,7 +71,7 @@ func TestUnitOAuthAuthorizationCode(t *testing.T) {
7171
roundTripper.reset()
7272
credentialsStorage.setCredential(accessTokenSpec, "access-token-123")
7373
wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json"))
74-
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{}
74+
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
7575
for i := 0; i < 3; i++ {
7676
client, err := newOauthClient(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient), cfg, &snowflakeConn{})
7777
assertNilF(t, err)
@@ -91,6 +91,7 @@ func TestUnitOAuthAuthorizationCode(t *testing.T) {
9191
wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json"))
9292
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{
9393
tamperWithState: true,
94+
t: t,
9495
}
9596
client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
9697
return authCodeProvider
@@ -105,7 +106,7 @@ func TestUnitOAuthAuthorizationCode(t *testing.T) {
105106
credentialsStorage.deleteCredential(accessTokenSpec)
106107
credentialsStorage.deleteCredential(refreshTokenSpec)
107108
wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/error_from_idp.json"))
108-
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{}
109+
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
109110
client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
110111
return authCodeProvider
111112
}
@@ -130,7 +131,7 @@ func TestUnitOAuthAuthorizationCode(t *testing.T) {
130131
credentialsStorage.deleteCredential(accessTokenSpec)
131132
credentialsStorage.deleteCredential(refreshTokenSpec)
132133
wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/invalid_code.json"))
133-
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{}
134+
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{t: t}
134135
client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
135136
return authCodeProvider
136137
}
@@ -148,7 +149,9 @@ func TestUnitOAuthAuthorizationCode(t *testing.T) {
148149
wiremock.registerMappings(t, newWiremockMapping("oauth2/authorization_code/successful_flow.json"))
149150
client.cfg.ExternalBrowserTimeout = 2 * time.Second
150151
authCodeProvider := &nonInteractiveAuthorizationCodeProvider{
151-
sleepTime: 3 * time.Second,
152+
sleepTime: 3 * time.Second,
153+
triggerError: "timed out",
154+
t: t,
152155
}
153156
client.authorizationCodeProviderFactory = func() authorizationCodeProvider {
154157
return authCodeProvider
@@ -289,6 +292,48 @@ func TestAuthorizationCodeFlow(t *testing.T) {
289292
runSmokeQuery(t, db)
290293
})
291294

295+
t.Run("successful flow with multiple threads", func(t *testing.T) {
296+
for _, singleAuthenticationPrompt := range []ConfigBool{ConfigBoolFalse, ConfigBoolTrue, configBoolNotSet} {
297+
t.Run("singleAuthenticationPrompt="+singleAuthenticationPrompt.String(), func(t *testing.T) {
298+
currentDefaultAuthorizationCodeProviderFactory := defaultAuthorizationCodeProviderFactory
299+
defer func() {
300+
defaultAuthorizationCodeProviderFactory = currentDefaultAuthorizationCodeProviderFactory
301+
}()
302+
defaultAuthorizationCodeProviderFactory = func() authorizationCodeProvider {
303+
return &nonInteractiveAuthorizationCodeProvider{
304+
t: t,
305+
mu: sync.Mutex{},
306+
sleepTime: 500 * time.Millisecond,
307+
}
308+
}
309+
roundTripper.reset()
310+
wiremock.registerMappings(t,
311+
newWiremockMapping("oauth2/authorization_code/successful_flow.json"),
312+
newWiremockMapping("oauth2/login_request.json"),
313+
newWiremockMapping("select1.json"),
314+
newWiremockMapping("close_session.json"))
315+
cfg := wiremock.connectionConfig()
316+
cfg.Role = "ANALYST"
317+
cfg.Authenticator = AuthTypeOAuthAuthorizationCode
318+
cfg.Transporter = roundTripper
319+
cfg.SingleAuthenticationPrompt = singleAuthenticationPrompt
320+
oauthAccessTokenSpec := newOAuthAccessTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
321+
oauthRefreshTokenSpec := newOAuthRefreshTokenSpec(cfg.OauthTokenRequestURL, cfg.User)
322+
credentialsStorage.deleteCredential(oauthAccessTokenSpec)
323+
credentialsStorage.deleteCredential(oauthRefreshTokenSpec)
324+
connector := NewConnector(SnowflakeDriver{}, *cfg)
325+
db := sql.OpenDB(connector)
326+
initPoolWithSize(t, db, 20)
327+
println(roundTripper.postReqCount[cfg.OauthTokenRequestURL])
328+
if singleAuthenticationPrompt == ConfigBoolFalse {
329+
assertTrueE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL] > 1)
330+
} else {
331+
assertEqualE(t, roundTripper.postReqCount[cfg.OauthTokenRequestURL], 1)
332+
}
333+
})
334+
}
335+
})
336+
292337
t.Run("successful flow with single-use refresh token enabled", func(t *testing.T) {
293338
wiremock.registerMappings(t,
294339
newWiremockMapping("oauth2/authorization_code/successful_flow_with_single_use_refresh_token.json"),
@@ -683,7 +728,9 @@ type nonInteractiveAuthorizationCodeProvider struct {
683728
func (provider *nonInteractiveAuthorizationCodeProvider) run(authorizationURL string) error {
684729
if provider.sleepTime != 0 {
685730
time.Sleep(provider.sleepTime)
686-
return errors.New("ignore me")
731+
if provider.triggerError != "" {
732+
return errors.New(provider.triggerError)
733+
}
687734
}
688735
if provider.triggerError != "" {
689736
return errors.New(provider.triggerError)

0 commit comments

Comments
 (0)