Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 27 additions & 65 deletions network/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@ import (
"github.com/rs/zerolog"
)

type Callback func(key, value interface{}) bool

type Pool interface {
ForEach(callback func(client *Client) error)
ForEach(Callback)
Pool() *sync.Map
ClientIDs() []string
Put(client *Client) error
Pop(ID string) *Client
// ClientIDs() []string
Put(key, value interface{})
Pop(key interface{}) interface{}
Size() int
Close() error
Shutdown()
Clear()
// Close() error
// Shutdown()
}

type PoolImpl struct {
Expand All @@ -25,51 +28,23 @@ type PoolImpl struct {

var _ Pool = &PoolImpl{}

func (p *PoolImpl) ForEach(callback func(client *Client) error) {
p.pool.Range(func(key, value interface{}) bool {
if c, ok := value.(*Client); ok {
err := callback(c)
if err != nil {
p.logger.Debug().Err(err).Msg("an error occurred running the callback")
}
return true
}

return false
})
func (p *PoolImpl) ForEach(cb Callback) {
p.pool.Range(cb)
}

func (p *PoolImpl) Pool() *sync.Map {
return &p.pool
}

func (p *PoolImpl) ClientIDs() []string {
var ids []string
p.pool.Range(func(key, _ interface{}) bool {
if id, ok := key.(string); ok {
ids = append(ids, id)
return true
}
return false
})
return ids
func (p *PoolImpl) Put(key, value interface{}) {
p.pool.Store(key, value)
p.logger.Debug().Msg("Item has been put on the pool")
}

func (p *PoolImpl) Put(client *Client) error {
p.pool.Store(client.ID, client)
p.logger.Debug().Msgf("Client %s has been put on the pool", client.ID)

return nil
}

func (p *PoolImpl) Pop(id string) *Client {
if client, ok := p.pool.Load(id); ok {
p.pool.Delete(id)
p.logger.Debug().Msgf("Client %s has been popped from the pool", id)
if c, ok := client.(*Client); ok {
return c
}
return nil
func (p *PoolImpl) Pop(key interface{}) interface{} {
if value, ok := p.pool.LoadAndDelete(key); ok {
p.logger.Debug().Msg("Item has been popped from the pool")
return value
}

return nil
Expand All @@ -85,25 +60,14 @@ func (p *PoolImpl) Size() int {
return size
}

func (p *PoolImpl) Close() error {
p.ForEach(func(client *Client) error {
client.Close()
return nil
})

return nil
func (p *PoolImpl) Clear() {
p.pool = sync.Map{}
}

func (p *PoolImpl) Shutdown() {
p.pool.Range(func(key, value interface{}) bool {
if cl, ok := value.(*Client); ok {
cl.Close()
}
p.pool.Delete(key)
return true
})

p.pool = sync.Map{}
func NewEmptyPool(logger zerolog.Logger) Pool {
return &PoolImpl{
logger: logger,
}
}

func NewPool(
Expand Down Expand Up @@ -131,15 +95,13 @@ func NewPool(
}

if client != nil {
if err := pool.Put(client); err != nil {
logger.Panic().Err(err).Msg("Failed to add client to pool")
}
pool.Put(client.ID, client)
}
}

// Verify that the pool is properly populated
logger.Info().Msgf("There are %d clients in the pool", len(pool.ClientIDs()))
if len(pool.ClientIDs()) != poolSize {
logger.Info().Msgf("There are %d clients in the pool", pool.Size())
if pool.Size() != poolSize {
logger.Error().Msg(
"The pool size is incorrect, either because " +
"the clients are cannot connect (no network connectivity) " +
Expand Down
106 changes: 54 additions & 52 deletions network/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,13 @@ func TestNewPool(t *testing.T) {

logger := logging.NewLogger(cfg)
pool := NewPool(logger, 0, nil, nil)
defer pool.Close()
defer pool.Clear()
assert.NotNil(t, pool)
assert.NotNil(t, pool.Pool())
assert.Equal(t, 0, pool.Size())
}

func TestPool_Put(t *testing.T) {
postgres := embeddedpostgres.NewDatabase()
if err := postgres.Start(); err != nil {
t.Fatal(err)
}

defer func() {
if err := postgres.Stop(); err != nil {
t.Fatal(err)
}
}()

func TestNewEmptyPool(t *testing.T) {
cfg := logging.LoggerConfig{
Output: nil,
TimeFormat: zerolog.TimeFormatUnix,
Expand All @@ -45,19 +34,14 @@ func TestPool_Put(t *testing.T) {
}

logger := logging.NewLogger(cfg)

pool := NewPool(logger, 0, nil, nil)
defer pool.Close()
pool := NewEmptyPool(logger)
defer pool.Clear()
assert.NotNil(t, pool)
assert.NotNil(t, pool.Pool())
assert.Equal(t, 0, pool.Size())
assert.NoError(t, pool.Put(NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)))
assert.Equal(t, 1, pool.Size())
assert.NoError(t, pool.Put(NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)))
assert.Equal(t, 2, pool.Size())
}

func TestPool_Pop(t *testing.T) {
func TestPool_Put(t *testing.T) {
postgres := embeddedpostgres.NewDatabase()
if err := postgres.Start(); err != nil {
t.Fatal(err)
Expand All @@ -79,25 +63,19 @@ func TestPool_Pop(t *testing.T) {
logger := logging.NewLogger(cfg)

pool := NewPool(logger, 0, nil, nil)
defer pool.Close()
defer pool.Clear()
assert.NotNil(t, pool)
assert.NotNil(t, pool.Pool())
assert.Equal(t, 0, pool.Size())
client1 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client1))
pool.Put(client1.ID, client1)
assert.Equal(t, 1, pool.Size())
client2 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client2))
pool.Put(client2.ID, client2)
assert.Equal(t, 2, pool.Size())
client := pool.Pop(client1.ID)
assert.Equal(t, client1.ID, client.ID)
assert.Equal(t, 1, pool.Size())
client = pool.Pop(client2.ID)
assert.Equal(t, client2.ID, client.ID)
assert.Equal(t, 0, pool.Size())
}

func TestPool_Close(t *testing.T) {
func TestPool_Pop(t *testing.T) {
postgres := embeddedpostgres.NewDatabase()
if err := postgres.Start(); err != nil {
t.Fatal(err)
Expand All @@ -119,21 +97,31 @@ func TestPool_Close(t *testing.T) {
logger := logging.NewLogger(cfg)

pool := NewPool(logger, 0, nil, nil)
defer pool.Clear()
assert.NotNil(t, pool)
assert.NotNil(t, pool.Pool())
assert.Equal(t, 0, pool.Size())
client1 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client1))
pool.Put(client1.ID, client1)
assert.Equal(t, 1, pool.Size())
client2 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client2))
assert.Equal(t, 2, pool.Size())
err := pool.Close()
assert.Nil(t, err)
pool.Put(client2.ID, client2)
assert.Equal(t, 2, pool.Size())
if c1, ok := pool.Pop(client1.ID).(*Client); !ok {
assert.Equal(t, c1, client1)
} else {
assert.Equal(t, client1.ID, c1.ID)
assert.Equal(t, 1, pool.Size())
}
if c2, ok := pool.Pop(client2.ID).(*Client); !ok {
assert.Equal(t, c2, client2)
} else {
assert.Equal(t, client2.ID, c2.ID)
assert.Equal(t, 0, pool.Size())
}
}

func TestPool_Shutdown(t *testing.T) {
func TestPool_Clear(t *testing.T) {
postgres := embeddedpostgres.NewDatabase()
if err := postgres.Start(); err != nil {
t.Fatal(err)
Expand All @@ -155,17 +143,17 @@ func TestPool_Shutdown(t *testing.T) {
logger := logging.NewLogger(cfg)

pool := NewPool(logger, 0, nil, nil)
defer pool.Close()
defer pool.Clear()
assert.NotNil(t, pool)
assert.NotNil(t, pool.Pool())
assert.Equal(t, 0, pool.Size())
client1 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client1))
pool.Put(client1.ID, client1)
assert.Equal(t, 1, pool.Size())
client2 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client2))
pool.Put(client2.ID, client2)
assert.Equal(t, 2, pool.Size())
pool.Shutdown()
pool.Clear()
assert.Equal(t, 0, pool.Size())
}

Expand All @@ -191,23 +179,25 @@ func TestPool_ForEach(t *testing.T) {
logger := logging.NewLogger(cfg)

pool := NewPool(logger, 0, nil, nil)
defer pool.Close()
defer pool.Clear()
assert.NotNil(t, pool)
assert.NotNil(t, pool.Pool())
assert.Equal(t, 0, pool.Size())
client1 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client1))
pool.Put(client1.ID, client1)
assert.Equal(t, 1, pool.Size())
client2 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client2))
pool.Put(client2.ID, client2)
assert.Equal(t, 2, pool.Size())
pool.ForEach(func(client *Client) error {
assert.NotNil(t, client)
return nil
pool.ForEach(func(key, value interface{}) bool {
if c, ok := value.(*Client); ok {
assert.NotNil(t, c)
}
return true
})
}

func TestPool_ClientIDs(t *testing.T) {
func TestPool_GetClientIDs(t *testing.T) {
postgres := embeddedpostgres.NewDatabase()
if err := postgres.Start(); err != nil {
t.Fatal(err)
Expand All @@ -229,16 +219,28 @@ func TestPool_ClientIDs(t *testing.T) {
logger := logging.NewLogger(cfg)

pool := NewPool(logger, 0, nil, nil)
defer pool.Close()
defer pool.Clear()
assert.NotNil(t, pool)
assert.NotNil(t, pool.Pool())
assert.Equal(t, 0, pool.Size())
client1 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client1))
pool.Put(client1.ID, client1)
assert.Equal(t, 1, pool.Size())
client2 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger)
assert.NoError(t, pool.Put(client2))
pool.Put(client2.ID, client2)
assert.Equal(t, 2, pool.Size())
ids := pool.ClientIDs()

var ids []string
pool.ForEach(func(key, value interface{}) bool {
if id, ok := key.(string); ok {
ids = append(ids, id)
}
return true
})
assert.Equal(t, 2, len(ids))
assert.Contains(t, client1.ID, ids[0])
assert.Contains(t, client2.ID, ids[1])
client1.Close()
client2.Close()
pool.Clear()
}
Loading