diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 060f464355..4bed494c4a 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -25,15 +25,14 @@ import ( var uriAdmin = os.Getenv("MONGODB_URI") var uriSingle = os.Getenv("MONGODB_URI_SINGLE") - -// var uriMulti = os.Getenv("MONGODB_URI_MULTI") +var uriMulti = os.Getenv("MONGODB_URI_MULTI") var oidcTokenDir = os.Getenv("OIDC_TOKEN_DIR") -//var oidcDomain = os.Getenv("OIDC_DOMAIN") +var oidcDomain = os.Getenv("OIDC_DOMAIN") -//func explicitUser(user string) string { -// return fmt.Sprintf("%s@%s", user, oidcDomain) -//} +func explicitUser(user string) string { + return fmt.Sprintf("%s@%s", user, oidcDomain) +} func tokenFile(user string) string { return path.Join(oidcTokenDir, user) @@ -50,6 +49,13 @@ func connectWithMachineCB(uri string, cb options.OIDCCallback) (*mongo.Client, e return mongo.Connect(context.Background(), opts) } +func connectWithHumanCB(uri string, cb options.OIDCCallback) (*mongo.Client, error) { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCHumanCallback = cb + return mongo.Connect(context.Background(), opts) +} + func connectWithMachineCBAndProperties(uri string, cb options.OIDCCallback, props map[string]string) (*mongo.Client, error) { opts := options.Client().ApplyURI(uri) @@ -88,6 +94,22 @@ func main() { aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds) aux("machine_4_2_readCommandsFailIfReauthenticationFails", machine42ReadCommandsFailIfReauthenticationFails) aux("machine_4_3_writeCommandsFailIfReauthenticationFails", machine43WriteCommandsFailIfReauthenticationFails) + aux("human_1_1_singlePrincipalImplictUsername", human11singlePrincipalImplictUsername) + aux("human_1_2_singlePrincipalExplicitUsername", human12singlePrincipalExplicitUsername) + aux("human_1_3_mulitplePrincipalUser1", human13mulitplePrincipalUser1) + aux("human_1_4_mulitplePrincipalUser2", human14mulitplePrincipalUser2) + aux("human_1_5_multiplPrincipalNoUser", human15mulitplePrincipalNoUser) + aux("human_1_6_allowedHostsBlocked", human16allowedHostsBlocked) + aux("human_1_7_allowedHostsInConnectionStringIgnored", human17AllowedHostsInConnectionStringIgnored) + aux("human_2_1_validCallbackInputs", human21validCallbackInputs) + aux("human_2_2_CallbackReturnsMissingData", human22CallbackReturnsMissingData) + aux("human_2_3_RefreshTokenIsPassedToCallback", human23RefreshTokenIsPassedToCallback) + aux("human_3_1_usesSpeculativeAuth", human31usesSpeculativeAuth) + aux("human_3_2_doesNotUseSpecualtiveAuth", human32doesNotUseSpecualtiveAuth) + aux("human_4_1_reauthenticationSucceeds", human41ReauthenticationSucceeds) + aux("human_4_2_reauthenticationSucceedsNoRefresh", human42ReauthenticationSucceedsNoRefreshToken) + aux("human_4_3_reauthenticationSucceedsAfterRefreshFails", human43ReauthenticationSucceedsAfterRefreshFails) + aux("human_4_4_reauthenticationFails", human44ReauthenticationFails) case "azure": aux("machine_5_1_azureWithNoUsername", machine51azureWithNoUsername) aux("machine_5_2_azureWithNoUsername", machine52azureWithBadUsername) @@ -403,11 +425,10 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { countMutex := sync.Mutex{} adminClient, err := connectAdminClinet() - defer adminClient.Disconnect(context.Background()) - if err != nil { return fmt.Errorf("machine_3_3: failed connecting admin client: %v", err) } + defer adminClient.Disconnect(context.Background()) client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() @@ -478,11 +499,10 @@ func machine41ReauthenticationSucceeds() error { countMutex := sync.Mutex{} adminClient, err := connectAdminClinet() - defer adminClient.Disconnect(context.Background()) - if err != nil { return fmt.Errorf("machine_4_1: failed connecting admin client: %v", err) } + defer adminClient.Disconnect(context.Background()) client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() @@ -544,11 +564,10 @@ func machine42ReadCommandsFailIfReauthenticationFails() error { countMutex := sync.Mutex{} adminClient, err := connectAdminClinet() - defer adminClient.Disconnect(context.Background()) - if err != nil { return fmt.Errorf("machine_4_2: failed connecting admin client: %v", err) } + defer adminClient.Disconnect(context.Background()) client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() @@ -625,11 +644,10 @@ func machine43WriteCommandsFailIfReauthenticationFails() error { countMutex := sync.Mutex{} adminClient, err := connectAdminClinet() - defer adminClient.Disconnect(context.Background()) - if err != nil { return fmt.Errorf("machine_4_3: failed connecting admin client: %v", err) } + defer adminClient.Disconnect(context.Background()) client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() @@ -698,6 +716,799 @@ func machine43WriteCommandsFailIfReauthenticationFails() error { return callbackFailed } +func human11singlePrincipalImplictUsername() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_1_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human12singlePrincipalExplicitUsername() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + opts := options.Client().ApplyURI(uriSingle) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.Username = explicitUser("test_user1") + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_2: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_2: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_2: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human13mulitplePrincipalUser1() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + opts := options.Client().ApplyURI(uriMulti) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.Username = explicitUser("test_user1") + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_3: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_3: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_3: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human14mulitplePrincipalUser2() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + opts := options.Client().ApplyURI(uriMulti) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user2") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_4: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.Username = explicitUser("test_user2") + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_4: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_4: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_4: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human15mulitplePrincipalNoUser() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriMulti, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_5: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + if err != nil { + return fmt.Errorf("human_1_5: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_1_5: Find succeeded when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 0 { + return fmt.Errorf("human_1_5: expected callback count to be 0, got %d", callbackCount) + } + return callbackFailed +} + +func human16allowedHostsBlocked() error { + var callbackFailed error + { + opts := options.Client().ApplyURI(uriSingle) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_6: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.AuthMechanismProperties = map[string]string{"ALLOWED_HOSTS": ""} + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_4: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_1_6: Find succeeded when it should fail with empty 'ALLOWED_HOSTS'") + } + } + { + opts := options.Client().ApplyURI("mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com") + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_6: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.AuthMechanismProperties = map[string]string{"ALLOWED_HOSTS": "example.com"} + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_4: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_1_6: Find succeeded when it should fail with 'ALLOWED_HOSTS' 'example.com'") + } + } + return callbackFailed +} + +func human17AllowedHostsInConnectionStringIgnored() error { + uri := "mongodb+srv://example.com/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D" + opts := options.Client().ApplyURI(uri) + err := opts.Validate() + if err == nil { + return fmt.Errorf("human_1_7: succeeded in applying URI which should produce an error") + } + return nil +} + +func human21validCallbackInputs() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if args.Version != 1 { + callbackFailed = fmt.Errorf("human_2_1: expected version to be 1, got %d", args.Version) + } + if args.IDPInfo == nil { + callbackFailed = fmt.Errorf("human_2_1: expected IDPInfo to be non-nil, previous error: (%v)", callbackFailed) + } + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_2_1: failed reading token file: %v, previous error: (%v)", err, callbackFailed) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_2_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_2_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_2_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human22CallbackReturnsMissingData() error { + callbackCount := 0 + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + return &options.OIDCCredential{}, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_2_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_2_2: Find succeeded when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_2_2: expected callback count to be 1, got %d", callbackCount) + } + return nil +} + +func human23RefreshTokenIsPassedToCallback() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_2_3: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + if callbackCount == 1 && args.RefreshToken != nil { + callbackFailed = fmt.Errorf("human_2_3: expected refresh token to be nil first time, got %v, previous error: (%v)", args.RefreshToken, callbackFailed) + } + if callbackCount == 2 && args.RefreshToken == nil { + callbackFailed = fmt.Errorf("human_2_3: expected refresh token to be non-nil second time, got %v, previous error: (%v)", args.RefreshToken, callbackFailed) + } + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_2_3: failed reading token file: %v", err) + } + rt := "this is fake" + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: &rt, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_2_3: failed connecting client: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_2_3: failed to set failpoint") + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_2_3: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("human_2_3: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func human31usesSpeculativeAuth() error { + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_3_1: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + // the callback should not even be called due to spec auth. + return &options.OIDCCredential{}, nil + }) + + if err != nil { + return fmt.Errorf("human_3_1: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + // We deviate from the Prose test since the failPoint on find with no error code does not seem to + // work. Rather we put an access token in the cache to force speculative auth. + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + return fmt.Errorf("human_3_1: failed reading token file: %v", err) + } + clientElem := reflect.ValueOf(client).Elem() + authenticatorField := clientElem.FieldByName("authenticator") + authenticatorField = reflect.NewAt( + authenticatorField.Type(), + unsafe.Pointer(authenticatorField.UnsafeAddr())).Elem() + // This is the only usage of the x packages in the test, showing the the public interface is + // correct. + authenticatorField.Interface().(*auth.OIDCAuthenticator).SetAccessToken(string(accessToken)) + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, + {Key: "errorCode", Value: 18}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_3_1: failed to set failpoint") + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_3_1: failed executing Find: %v", err) + } + + return nil +} + +func human32doesNotUseSpecualtiveAuth() error { + var callbackFailed error + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_3_2: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_3_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_3_2: failed connecting client: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, + {Key: "errorCode", Value: 18}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_3_2: failed to set failpoint") + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_3_2: Find succeeded when it should fail") + } + return callbackFailed +} + +func human41ReauthenticationSucceeds() error { + return nil +} + +func human42ReauthenticationSucceedsNoRefreshToken() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_4_2: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_4_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_4_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_2: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 1 { + return fmt.Errorf("human_4_2: expected callback count to be 1, got %d", callbackCount) + } + countMutex.Unlock() + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_4_2: failed to set failpoint") + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_2: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 2 { + return fmt.Errorf("human_4_2: expected callback count to be 2, got %d", callbackCount) + } + countMutex.Unlock() + return callbackFailed +} + +func human43ReauthenticationSucceedsAfterRefreshFails() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_4_3: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_4_3: failed reading token file: %v", err) + } + refreshToken := "bad token" + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: &refreshToken, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_4_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_3: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 1 { + return fmt.Errorf("human_4_3: expected callback count to be 1, got %d", callbackCount) + } + countMutex.Unlock() + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_4_3: failed to set failpoint") + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_3: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 2 { + return fmt.Errorf("human_4_3: expected callback count to be 2, got %d", callbackCount) + } + countMutex.Unlock() + return callbackFailed +} + +func human44ReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_4_4: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + badToken := "bad token" + t := time.Now().Add(time.Hour) + if callbackCount == 1 { + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_4_4: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: &badToken, + }, nil + } + return &options.OIDCCredential{ + AccessToken: badToken, + ExpiresAt: &t, + RefreshToken: &badToken, + }, fmt.Errorf("failed to refresh token") + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_4_4: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_4: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 1 { + return fmt.Errorf("human_4_4: expected callback count to be 1, got %d", callbackCount) + } + countMutex.Unlock() + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_4_4: failed to set failpoint") + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_4_4: Find succeeded when it should fail") + } + + countMutex.Lock() + if callbackCount != 3 { + return fmt.Errorf("human_4_4: expected callback count to be 3, got %d", callbackCount) + } + countMutex.Unlock() + return callbackFailed +} + func machine51azureWithNoUsername() error { opts := options.Client().ApplyURI(uriSingle) if opts == nil || opts.Auth == nil { diff --git a/x/mongo/driver/auth/internal/gssapi/gss.go b/x/mongo/driver/auth/internal/gssapi/gss.go index abfa4db47c..496057882d 100644 --- a/x/mongo/driver/auth/internal/gssapi/gss.go +++ b/x/mongo/driver/auth/internal/gssapi/gss.go @@ -19,6 +19,7 @@ package gssapi */ import "C" import ( + "context" "fmt" "runtime" "strings" @@ -91,12 +92,12 @@ func (sc *SaslClient) Start() (string, []byte, error) { return mechName, nil, sc.getError("unable to initialize client") } - payload, err := sc.Next(nil) + payload, err := sc.Next(nil, nil) return mechName, payload, err } -func (sc *SaslClient) Next(challenge []byte) ([]byte, error) { +func (sc *SaslClient) Next(_ context.Context, challenge []byte) ([]byte, error) { var buf unsafe.Pointer var bufLen C.size_t diff --git a/x/mongo/driver/auth/internal/gssapi/sspi.go b/x/mongo/driver/auth/internal/gssapi/sspi.go index 6e7d3ed8ad..d73da025bb 100644 --- a/x/mongo/driver/auth/internal/gssapi/sspi.go +++ b/x/mongo/driver/auth/internal/gssapi/sspi.go @@ -12,6 +12,7 @@ package gssapi // #include "sspi_wrapper.h" import "C" import ( + "context" "fmt" "net" "strconv" @@ -120,7 +121,7 @@ func (sc *SaslClient) Start() (string, []byte, error) { return mechName, payload, err } -func (sc *SaslClient) Next(challenge []byte) ([]byte, error) { +func (sc *SaslClient) Next(_ context.Context, challenge []byte) ([]byte, error) { var outBuf C.PVOID var outBufLen C.ULONG diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 2245bdb6fe..c5cebaa27f 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -82,7 +82,7 @@ func (a *awsSaslAdapter) Start() (string, []byte, error) { return MongoDBAWS, step, nil } -func (a *awsSaslAdapter) Next(challenge []byte) ([]byte, error) { +func (a *awsSaslAdapter) Next(_ context.Context, challenge []byte) ([]byte, error) { step, err := a.conversation.Step(challenge) if err != nil { return nil, err diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index fe0584eb14..454a1f635d 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -13,10 +13,12 @@ import ( "io" "net/http" "net/url" + "regexp" "strings" "sync" "time" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" ) @@ -24,14 +26,10 @@ import ( // MongoDBOIDC is the string constant for the MONGODB-OIDC authentication mechanism. const MongoDBOIDC = "MONGODB-OIDC" -// TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider // const tokenResourceProp = "TOKEN_RESOURCE" const environmentProp = "ENVIRONMENT" - const resourceProp = "TOKEN_RESOURCE" - -// GODRIVER-3249 OIDC: Handle all possible OIDC configuration errors -//const allowedHostsProp = "ALLOWED_HOSTS" +const allowedHostsProp = "ALLOWED_HOSTS" const azureEnvironmentValue = "azure" const gcpEnvironmentValue = "gcp" @@ -44,18 +42,18 @@ const invalidateSleepTimeout = 100 * time.Millisecond // ambiguous for the v1.x Go Driver because it could mean either "no timeout provided" or "CSOT not // enabled". Always use a maximum timeout duration of 1 minute, allowing us to ignore the ambiguity. // Contexts with a shorter timeout are unaffected. -const machineCallbackTimeout = 60 * time.Second - -//GODRIVER-3246 OIDC: Implement Human Callback Mechanism -//var defaultAllowedHosts = []string{ -// "*.mongodb.net", -// "*.mongodb-qa.net", -// "*.mongodb-dev.net", -// "*.mongodbgov.net", -// "localhost", -// "127.0.0.1", -// "::1", -//} +const machineCallbackTimeout = time.Minute +const humanCallbackTimeout = 5 * time.Minute + +var defaultAllowedHosts = []*regexp.Regexp{ + regexp.MustCompile(`^.*[.]mongodb[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-qa[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-dev[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodbgov[.]net(:\d+)?$`), + regexp.MustCompile(`^localhost(:\d+)?$`), + regexp.MustCompile(`^127[.]0[.]0[.]1(:\d+)?$`), + regexp.MustCompile(`^::1(:\d+)?$`), +} // OIDCCallback is a function that takes a context and OIDCArgs and returns an OIDCCredential. type OIDCCallback = driver.OIDCCallback @@ -72,6 +70,7 @@ type IDPInfo = driver.IDPInfo var _ driver.Authenticator = (*OIDCAuthenticator)(nil) var _ SpeculativeAuthenticator = (*OIDCAuthenticator)(nil) var _ SaslClient = (*oidcOneStep)(nil) +var _ SaslClient = (*oidcTwoStep)(nil) // OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, // and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality @@ -83,6 +82,7 @@ type OIDCAuthenticator struct { OIDCMachineCallback OIDCCallback OIDCHumanCallback OIDCCallback + allowedHosts *[]*regexp.Regexp userName string httpClient *http.Client accessToken string @@ -127,7 +127,59 @@ func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, e OIDCMachineCallback: cred.OIDCMachineCallback, OIDCHumanCallback: cred.OIDCHumanCallback, } - return oa, nil + err := oa.setAllowedHosts() + return oa, err +} + +func createPatternsForGlobs(hosts []string) ([]*regexp.Regexp, error) { + var err error + ret := make([]*regexp.Regexp, len(hosts)) + for i := range hosts { + hosts[i] = strings.ReplaceAll(hosts[i], ".", "[.]") + hosts[i] = strings.ReplaceAll(hosts[i], "*", ".*") + hosts[i] = "^" + hosts[i] + "(:\\d+)?$" + ret[i], err = regexp.Compile(hosts[i]) + if err != nil { + return nil, err + } + } + return ret, nil +} + +func (oa *OIDCAuthenticator) setAllowedHosts() error { + if oa.AuthMechanismProperties == nil { + oa.allowedHosts = &defaultAllowedHosts + return nil + } + allowedHosts, ok := oa.AuthMechanismProperties[allowedHostsProp] + if !ok { + oa.allowedHosts = &defaultAllowedHosts + return nil + } + globs := strings.Split(allowedHosts, ",") + ret, err := createPatternsForGlobs(globs) + if err != nil { + return err + } + oa.allowedHosts = &ret + return nil +} + +func (oa *OIDCAuthenticator) validateConnectionAddressWithAllowedHosts(conn driver.Connection) error { + if oa.allowedHosts == nil { + // should be unreachable, but this is a safety check. + return newAuthError(fmt.Sprintf("%q missing", allowedHostsProp), nil) + } + allowedHosts := *oa.allowedHosts + if len(allowedHosts) == 0 { + return newAuthError(fmt.Sprintf("empty %q specified", allowedHostsProp), nil) + } + for _, pattern := range allowedHosts { + if pattern.MatchString(string(conn.Address())) { + return nil + } + } + return newAuthError(fmt.Sprintf("address %q not allowed by %q: %v", conn.Address(), allowedHostsProp, allowedHosts), nil) } type oidcOneStep struct { @@ -135,26 +187,30 @@ type oidcOneStep struct { accessToken string } +type oidcTwoStep struct { + conn driver.Connection + oa *OIDCAuthenticator +} + func jwtStepRequest(accessToken string) []byte { return bsoncore.NewDocumentBuilder(). AppendString("jwt", accessToken). Build() } -// TODO GODRIVER-3246: Implement OIDC human flow -//func principalStepRequest(principal string) []byte { -// doc := bsoncore.NewDocumentBuilder() -// if principal != "" { -// doc.AppendString("n", principal) -// } -// return doc.Build() -//} +func principalStepRequest(principal string) []byte { + doc := bsoncore.NewDocumentBuilder() + if principal != "" { + doc.AppendString("n", principal) + } + return doc.Build() +} func (oos *oidcOneStep) Start() (string, []byte, error) { return MongoDBOIDC, jwtStepRequest(oos.accessToken), nil } -func (oos *oidcOneStep) Next([]byte) ([]byte, error) { +func (oos *oidcOneStep) Next(context.Context, []byte) ([]byte, error) { return nil, newAuthError("unexpected step in OIDC authentication", nil) } @@ -162,6 +218,36 @@ func (*oidcOneStep) Completed() bool { return true } +func (ots *oidcTwoStep) Start() (string, []byte, error) { + return MongoDBOIDC, principalStepRequest(ots.oa.userName), nil +} + +func (ots *oidcTwoStep) Next(ctx context.Context, msg []byte) ([]byte, error) { + var idpInfo IDPInfo + err := bson.Unmarshal(msg, &idpInfo) + if err != nil { + return nil, fmt.Errorf("error unmarshaling BSON document: %w", err) + } + + accessToken, err := ots.oa.getAccessToken(ctx, + ots.conn, + &OIDCArgs{ + Version: apiVersion, + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: &idpInfo, + // there is no way there could be a refresh token when there is no IDPInfo. + RefreshToken: nil, + }, + // two-step callbacks are always human callbacks. + ots.oa.OIDCHumanCallback) + + return jwtStepRequest(accessToken), err +} + +func (*oidcTwoStep) Completed() bool { + return true +} + func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { env, ok := oa.AuthMechanismProperties[environmentProp] if !ok { @@ -273,43 +359,40 @@ func (oa *OIDCAuthenticator) getAccessToken( return oa.accessToken, nil } + // Attempt to refresh the access token if a refresh token is available. + if args.RefreshToken != nil { + cred, err := callback(ctx, args) + if err == nil && cred != nil { + oa.accessToken = cred.AccessToken + oa.tokenGenID++ + conn.SetOIDCTokenGenID(oa.tokenGenID) + oa.refreshToken = cred.RefreshToken + return cred.AccessToken, nil + } + oa.refreshToken = nil + args.RefreshToken = nil + } + // If we get here this means there either was no refresh token or the refresh token failed. cred, err := callback(ctx, args) if err != nil { return "", err } + // This line should never occur, if go conventions are followed, but it is a safety check such + // that we do not throw nil pointer errors to our users if they abuse the API. + if cred == nil { + return "", newAuthError("OIDC callback returned nil credential with no specified error", nil) + } oa.accessToken = cred.AccessToken oa.tokenGenID++ conn.SetOIDCTokenGenID(oa.tokenGenID) - if cred.RefreshToken != nil { - oa.refreshToken = cred.RefreshToken - } + oa.refreshToken = cred.RefreshToken + // always set the IdPInfo, in most cases, this should just be recopying the same pointer, or nil + // in the machine flow. + oa.idpInfo = args.IDPInfo return cred.AccessToken, nil } -// TODO GODRIVER-3246: Implement OIDC human flow -// This should only be called with the Mutex held. -//func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( -// ctx context.Context, -// callback OIDCCallback, -// refreshToken string, -//) (string, error) { -// -// cred, err := callback(ctx, &OIDCArgs{ -// Version: apiVersion, -// IDPInfo: oa.idpInfo, -// RefreshToken: &refreshToken, -// }) -// if err != nil { -// return "", err -// } -// -// oa.accessToken = cred.AccessToken -// oa.tokenGenID++ -// oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) -// return cred.AccessToken, nil -//} - // invalidateAccessToken invalidates the access token, if the force flag is set to true (which is // only on a Reauth call) or if the tokenGenID of the connection is greater than or equal to the // tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, @@ -346,6 +429,8 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { oa.mu.Lock() cachedAccessToken := oa.accessToken + cachedRefreshToken := oa.refreshToken + cachedIDPInfo := oa.idpInfo oa.mu.Unlock() if cachedAccessToken != "" { @@ -364,7 +449,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { } if oa.OIDCHumanCallback != nil { - return oa.doAuthHuman(ctx, cfg, oa.OIDCHumanCallback) + return oa.doAuthHuman(ctx, cfg, oa.OIDCHumanCallback, cachedIDPInfo, cachedRefreshToken) } // Handle user provided or automatic provider machine callback. @@ -384,9 +469,41 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { return newAuthError("no OIDC callback provided", nil) } -func (oa *OIDCAuthenticator) doAuthHuman(_ context.Context, _ *Config, _ OIDCCallback) error { - // TODO GODRIVER-3246: Implement OIDC human flow - return newAuthError("OIDC", fmt.Errorf("human flow not implemented yet, %v", oa.idpInfo)) +func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *Config, humanCallback OIDCCallback, idpInfo *IDPInfo, refreshToken *string) error { + // Ensure that the connection address is allowed by the allowed hosts. + err := oa.validateConnectionAddressWithAllowedHosts(cfg.Connection) + if err != nil { + return err + } + subCtx, cancel := context.WithTimeout(ctx, humanCallbackTimeout) + defer cancel() + // If the idpInfo exists, we can just do one step + if idpInfo != nil { + accessToken, err := oa.getAccessToken(subCtx, + cfg.Connection, + &OIDCArgs{ + Version: apiVersion, + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: idpInfo, + RefreshToken: refreshToken, + }, + humanCallback) + if err != nil { + return err + } + return ConductSaslConversation( + subCtx, + cfg, + "$external", + &oidcOneStep{accessToken: accessToken}, + ) + } + // otherwise, we need the two step where we ask the server for the IdPInfo first. + ots := &oidcTwoStep{ + conn: cfg.Connection, + oa: oa, + } + return ConductSaslConversation(subCtx, cfg, "$external", ots) } func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { @@ -412,7 +529,7 @@ func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, mac ) } -// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. +// CreateSpeculativeConversation creates a speculative conversation for OIDC authentication. func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { oa.mu.Lock() defer oa.mu.Unlock() diff --git a/x/mongo/driver/auth/oidc_test.go b/x/mongo/driver/auth/oidc_test.go new file mode 100644 index 0000000000..dcb941aff1 --- /dev/null +++ b/x/mongo/driver/auth/oidc_test.go @@ -0,0 +1,44 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package auth + +import ( + "regexp" + "testing" + + "go.mongodb.org/mongo-driver/internal/assert" +) + +func TestCreatePatternsForGlobs(t *testing.T) { + t.Run("transform allowedHosts patterns", func(t *testing.T) { + + hosts := []string{ + "*.mongodb.net", + "*.mongodb-qa.net", + "*.mongodb-dev.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + "::1", + } + + check, err := createPatternsForGlobs(hosts) + assert.NoError(t, err) + assert.Equal(t, + []*regexp.Regexp{ + regexp.MustCompile(`^.*[.]mongodb[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-qa[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-dev[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodbgov[.]net(:\d+)?$`), + regexp.MustCompile(`^localhost(:\d+)?$`), + regexp.MustCompile(`^127[.]0[.]0[.]1(:\d+)?$`), + regexp.MustCompile(`^::1(:\d+)?$`), + }, + check, + ) + }) +} diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 3e4c5b4eb3..9fce7ec383 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -54,7 +54,7 @@ func (c *plainSaslClient) Start() (string, []byte, error) { return PLAIN, b, nil } -func (c *plainSaslClient) Next([]byte) ([]byte, error) { +func (c *plainSaslClient) Next(context.Context, []byte) ([]byte, error) { return nil, newAuthError("unexpected server challenge", nil) } diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 75f0c411bf..1ef67f02b0 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -19,7 +19,7 @@ import ( // SaslClient is the client piece of a sasl conversation. type SaslClient interface { Start() (string, []byte, error) - Next(challenge []byte) ([]byte, error) + Next(ctx context.Context, challenge []byte) ([]byte, error) Completed() bool } @@ -118,7 +118,7 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon return nil } - payload, err = sc.client.Next(saslResp.Payload) + payload, err = sc.client.Next(ctx, saslResp.Payload) if err != nil { return newError(err, sc.mechanism) } diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index 291492e6ff..8c04ce32cc 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -119,7 +119,7 @@ func (a *scramSaslAdapter) Start() (string, []byte, error) { return a.mechanism, []byte(step), nil } -func (a *scramSaslAdapter) Next(challenge []byte) ([]byte, error) { +func (a *scramSaslAdapter) Next(_ context.Context, challenge []byte) ([]byte, error) { step, err := a.conversation.Step(string(challenge)) if err != nil { return nil, err