@@ -537,11 +537,7 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
537
537
}
538
538
case AuthTypeOAuthAuthorizationCode :
539
539
logger .WithContext (sc .ctx ).Debug ("OAuth authorization code" )
540
- oauthClient , err := newOauthClient (sc .ctx , sc .cfg , sc )
541
- if err != nil {
542
- return nil , err
543
- }
544
- token , err := oauthClient .authenticateByOAuthAuthorizationCode ()
540
+ token , err := authenticateByAuthorizationCode (sc )
545
541
if err != nil {
546
542
return nil , err
547
543
}
@@ -584,6 +580,62 @@ func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface
584
580
return jsonBody , nil
585
581
}
586
582
583
+ type oauthLockKey struct {
584
+ tokenRequestUrl string
585
+ user string
586
+ flowType string
587
+ }
588
+
589
+ func newOAuthAuthorizationCodeLockKey (tokenRequestUrl , user string ) * oauthLockKey {
590
+ return & oauthLockKey {
591
+ tokenRequestUrl : tokenRequestUrl ,
592
+ user : user ,
593
+ flowType : "authorization_code" ,
594
+ }
595
+ }
596
+
597
+ func newRefreshTokenLockKey (tokenRequestUrl , user string ) * oauthLockKey {
598
+ return & oauthLockKey {
599
+ tokenRequestUrl : tokenRequestUrl ,
600
+ user : user ,
601
+ flowType : "refresh_token" ,
602
+ }
603
+ }
604
+
605
+ func (o * oauthLockKey ) lockId () string {
606
+ return o .tokenRequestUrl + "|" + o .user + "|" + o .flowType
607
+ }
608
+
609
+ func authenticateByAuthorizationCode (sc * snowflakeConn ) (string , error ) {
610
+ oauthClient , err := newOauthClient (sc .ctx , sc .cfg , sc )
611
+ if err != nil {
612
+ return "" , err
613
+ }
614
+ if ! isEligibleForParallelLogin (sc .cfg , sc .cfg .ClientStoreTemporaryCredential ) {
615
+ return oauthClient .authenticateByOAuthAuthorizationCode ()
616
+ }
617
+
618
+ lockKey := newOAuthAuthorizationCodeLockKey (oauthClient .tokenURL (), sc .cfg .User )
619
+ valueAwaiter := valueAwaitHolder .get (lockKey )
620
+ defer valueAwaiter .resumeOne ()
621
+ token , err := awaitValue (valueAwaiter , func () (string , error ) {
622
+ return credentialsStorage .getCredential (newOAuthAccessTokenSpec (oauthClient .tokenURL (), sc .cfg .User )), nil
623
+ }, func (s string , err error ) bool {
624
+ return s != ""
625
+ }, func () string {
626
+ return ""
627
+ })
628
+ if err != nil || token != "" {
629
+ return token , err
630
+ }
631
+ token , err = oauthClient .authenticateByOAuthAuthorizationCode ()
632
+ if err != nil {
633
+ return "" , err
634
+ }
635
+ valueAwaiter .done ()
636
+ return token , err
637
+ }
638
+
587
639
// Generate a JWT token in string given the configuration
588
640
func prepareJWTToken (config * Config ) (string , error ) {
589
641
if config .PrivateKey == nil {
@@ -619,20 +671,60 @@ func prepareJWTToken(config *Config) (string, error) {
619
671
return tokenString , err
620
672
}
621
673
622
- // Authenticate with sc.cfg
674
+ type tokenLockKey struct {
675
+ snowflakeHost string
676
+ user string
677
+ tokenType string
678
+ }
679
+
680
+ func newMfaTokenLockKey (snowflakeHost , user string ) * tokenLockKey {
681
+ return & tokenLockKey {
682
+ snowflakeHost : snowflakeHost ,
683
+ user : user ,
684
+ tokenType : "MFA" ,
685
+ }
686
+ }
687
+
688
+ func newIDTokenLockKey (snowflakeHost , user string ) * tokenLockKey {
689
+ return & tokenLockKey {
690
+ snowflakeHost : snowflakeHost ,
691
+ user : user ,
692
+ tokenType : "ID" ,
693
+ }
694
+ }
695
+
696
+ func (m * tokenLockKey ) lockId () string {
697
+ return m .snowflakeHost + "|" + m .user + "|" + m .tokenType
698
+ }
699
+
623
700
func authenticateWithConfig (sc * snowflakeConn ) error {
624
701
var authData * authResponseMain
625
702
var samlResponse []byte
626
703
var proofKey []byte
627
704
var err error
628
- //var consentCacheIdToken = true
705
+
706
+ mfaTokenLockKey := newMfaTokenLockKey (sc .cfg .Host , sc .cfg .User )
707
+ idTokenLockKey := newIDTokenLockKey (sc .cfg .Host , sc .cfg .User )
629
708
630
709
if sc .cfg .Authenticator == AuthTypeExternalBrowser || sc .cfg .Authenticator == AuthTypeOAuthAuthorizationCode || sc .cfg .Authenticator == AuthTypeOAuthClientCredentials {
631
710
if (runtime .GOOS == "windows" || runtime .GOOS == "darwin" ) && sc .cfg .ClientStoreTemporaryCredential == configBoolNotSet {
632
711
sc .cfg .ClientStoreTemporaryCredential = ConfigBoolTrue
633
712
}
634
- if sc .cfg .Authenticator == AuthTypeExternalBrowser && sc .cfg .ClientStoreTemporaryCredential == ConfigBoolTrue {
635
- sc .cfg .IDToken = credentialsStorage .getCredential (newIDTokenSpec (sc .cfg .Host , sc .cfg .User ))
713
+ if sc .cfg .Authenticator == AuthTypeExternalBrowser {
714
+ if isEligibleForParallelLogin (sc .cfg , sc .cfg .ClientStoreTemporaryCredential ) {
715
+ valueAwaiter := valueAwaitHolder .get (idTokenLockKey )
716
+ defer valueAwaiter .resumeOne ()
717
+ sc .cfg .IDToken , _ = awaitValue (valueAwaiter , func () (string , error ) {
718
+ credential := credentialsStorage .getCredential (newIDTokenSpec (sc .cfg .Host , sc .cfg .User ))
719
+ return credential , nil
720
+ }, func (s string , err error ) bool {
721
+ return s != ""
722
+ }, func () string {
723
+ return ""
724
+ })
725
+ } else {
726
+ sc .cfg .IDToken = credentialsStorage .getCredential (newIDTokenSpec (sc .cfg .Host , sc .cfg .User ))
727
+ }
636
728
}
637
729
// Disable console login by default
638
730
if sc .cfg .DisableConsoleLogin == configBoolNotSet {
@@ -644,7 +736,18 @@ func authenticateWithConfig(sc *snowflakeConn) error {
644
736
if (runtime .GOOS == "windows" || runtime .GOOS == "darwin" ) && sc .cfg .ClientRequestMfaToken == configBoolNotSet {
645
737
sc .cfg .ClientRequestMfaToken = ConfigBoolTrue
646
738
}
647
- if sc .cfg .ClientRequestMfaToken == ConfigBoolTrue {
739
+ if isEligibleForParallelLogin (sc .cfg , sc .cfg .ClientRequestMfaToken ) {
740
+ valueAwaiter := valueAwaitHolder .get (mfaTokenLockKey )
741
+ defer valueAwaiter .resumeOne ()
742
+ sc .cfg .MfaToken , _ = awaitValue (valueAwaiter , func () (string , error ) {
743
+ credential := credentialsStorage .getCredential (newMfaTokenSpec (sc .cfg .Host , sc .cfg .User ))
744
+ return credential , nil
745
+ }, func (s string , err error ) bool {
746
+ return s != ""
747
+ }, func () string {
748
+ return ""
749
+ })
750
+ } else {
648
751
sc .cfg .MfaToken = credentialsStorage .getCredential (newMfaTokenSpec (sc .cfg .Host , sc .cfg .User ))
649
752
}
650
753
}
@@ -660,7 +763,6 @@ func authenticateWithConfig(sc *snowflakeConn) error {
660
763
sc .cfg .Application ,
661
764
sc .cfg .Account ,
662
765
sc .cfg .User ,
663
- sc .cfg .Password ,
664
766
sc .cfg .ExternalBrowserTimeout ,
665
767
sc .cfg .DisableConsoleLogin )
666
768
if err != nil {
@@ -680,15 +782,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
680
782
credentialsStorage .deleteCredential (newOAuthAccessTokenSpec (sc .cfg .OauthTokenRequestURL , sc .cfg .User ))
681
783
682
784
if sc .cfg .Authenticator == AuthTypeOAuthAuthorizationCode {
683
- var oauthClient * oauthClient
684
- if oauthClient , err = newOauthClient (sc .ctx , sc .cfg , sc ); err != nil {
685
- logger .Warnf ("failed to create oauth client. %v" , err )
686
- } else {
687
- if err = oauthClient .refreshToken (); err != nil {
688
- logger .Warnf ("cannot refresh token. %v" , err )
689
- credentialsStorage .deleteCredential (newOAuthRefreshTokenSpec (sc .cfg .OauthTokenRequestURL , sc .cfg .User ))
690
- }
691
- }
785
+ doRefreshTokenWithLock (sc )
692
786
}
693
787
694
788
// if refreshing succeeds for authorization code, we will take a token from cache
@@ -700,7 +794,47 @@ func authenticateWithConfig(sc *snowflakeConn) error {
700
794
return err
701
795
}
702
796
}
797
+ if sc .cfg .Authenticator == AuthTypeUsernamePasswordMFA && isEligibleForParallelLogin (sc .cfg , sc .cfg .ClientRequestMfaToken ) {
798
+ valueAwaiter := valueAwaitHolder .get (mfaTokenLockKey )
799
+ valueAwaiter .done ()
800
+ }
801
+ if sc .cfg .Authenticator == AuthTypeExternalBrowser && isEligibleForParallelLogin (sc .cfg , sc .cfg .ClientStoreTemporaryCredential ) {
802
+ valueAwaiter := valueAwaitHolder .get (idTokenLockKey )
803
+ valueAwaiter .done ()
804
+ }
703
805
sc .populateSessionParameters (authData .Parameters )
704
806
sc .ctx = context .WithValue (sc .ctx , SFSessionIDKey , authData .SessionID )
705
807
return nil
706
808
}
809
+
810
+ func doRefreshTokenWithLock (sc * snowflakeConn ) {
811
+ if oauthClient , err := newOauthClient (sc .ctx , sc .cfg , sc ); err != nil {
812
+ logger .Warnf ("failed to create oauth client. %v" , err )
813
+ } else {
814
+ lockKey := newRefreshTokenLockKey (oauthClient .tokenURL (), sc .cfg .User )
815
+ if _ , err = getValueWithLock (chooseLockerForAuth (sc .cfg ), lockKey , func () (string , error ) {
816
+ if err = oauthClient .refreshToken (); err != nil {
817
+ logger .Warnf ("cannot refresh token. %v" , err )
818
+ credentialsStorage .deleteCredential (newOAuthRefreshTokenSpec (sc .cfg .OauthTokenRequestURL , sc .cfg .User ))
819
+ return "" , err
820
+ }
821
+ return "" , nil
822
+ }); err != nil {
823
+ logger .Warnf ("failed to refresh token with lock. %v" , err )
824
+ }
825
+ }
826
+ }
827
+
828
+ func chooseLockerForAuth (cfg * Config ) locker {
829
+ if cfg .SingleAuthenticationPrompt == ConfigBoolFalse {
830
+ return noopLocker
831
+ }
832
+ if cfg .User == "" {
833
+ return noopLocker
834
+ }
835
+ return exclusiveLocker
836
+ }
837
+
838
+ func isEligibleForParallelLogin (cfg * Config , cacheEnabled ConfigBool ) bool {
839
+ return cfg .SingleAuthenticationPrompt != ConfigBoolFalse && cfg .User != "" && cacheEnabled == ConfigBoolTrue
840
+ }
0 commit comments