diff --git a/mongo/integration/sdam_prose_test.go b/mongo/integration/sdam_prose_test.go index 615c77569b..3107dcb97d 100644 --- a/mongo/integration/sdam_prose_test.go +++ b/mongo/integration/sdam_prose_test.go @@ -11,6 +11,8 @@ import ( "net" "os" "runtime" + "sync" + "sync/atomic" "testing" "time" @@ -232,4 +234,45 @@ func TestServerHeartbeatStartedEvent(t *testing.T) { } assert.Equal(t, expectedEvents, actualEvents) }) + + mt := mtest.New(t) + + mt.Run("polling must await frequency", func(mt *mtest.T) { + var heartbeatStartedCount atomic.Int64 + + servers := map[string]bool{} + serversMu := sync.RWMutex{} // Guard the servers set + + serverMonitor := &event.ServerMonitor{ + ServerHeartbeatStarted: func(*event.ServerHeartbeatStartedEvent) { + heartbeatStartedCount.Add(1) + }, + TopologyDescriptionChanged: func(evt *event.TopologyDescriptionChangedEvent) { + serversMu.Lock() + defer serversMu.Unlock() + + for _, srv := range evt.NewDescription.Servers { + servers[srv.Addr.String()] = true + } + }, + } + + // Create a client with heartbeatFrequency=100ms, + // serverMonitoringMode=poll. Use SDAM to record the number of times the + // a heartbeat is started and the number of servers discovered. + mt.ResetClient(options.Client(). + SetServerMonitor(serverMonitor). + SetServerMonitoringMode(options.ServerMonitoringModePoll)) + + // Per specifications, minHeartbeatFrequencyMS=500ms. So, within the first + // 500ms the heartbeatStartedCount should be LEQ to the number of discovered + // servers. + time.Sleep(500 * time.Millisecond) + + serversMu.Lock() + serverCount := int64(len(servers)) + serversMu.Unlock() + + assert.LessOrEqual(mt, heartbeatStartedCount.Load(), serverCount) + }) } diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 99f8dd618b..a29eea4a6d 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -666,7 +666,7 @@ func (s *Server) update() { s.monitorOnce.Do(s.rttMonitor.connect) } - if isStreamable(s) || connectionIsStreaming || transitionedFromNetworkError { + if isStreamingEnabled(s) && (isStreamable(s) || connectionIsStreaming) || transitionedFromNetworkError { continue }