-
Notifications
You must be signed in to change notification settings - Fork 89
hot-reload (minimal test impl, opt-in) #3247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
811a3f7
c35783d
fb7b3c7
0dad688
a424de8
6fddf0c
b27c3e5
02bf17d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -23,6 +23,7 @@ import ( | |||||
"encoding/hex" | ||||||
"encoding/pem" | ||||||
"reflect" | ||||||
"sync/atomic" | ||||||
"time" | ||||||
|
||||||
"github.com/vdaas/vald/internal/errors" | ||||||
|
@@ -60,29 +61,33 @@ var ( | |||||
// credentials holds TLS settings for server and client | ||||||
// including certificate paths, CA bundle, and hot reload policies. | ||||||
type credentials struct { | ||||||
cfg *Config | ||||||
cert string | ||||||
key string | ||||||
ca string | ||||||
sn string | ||||||
insecure bool | ||||||
clientAuth tls.ClientAuthType | ||||||
cfg *Config | ||||||
cert string | ||||||
key string | ||||||
ca string | ||||||
sn string | ||||||
insecure bool | ||||||
clientAuth tls.ClientAuthType | ||||||
// hotReload toggles per-handshake reload using GetCertificate. | ||||||
hotReload bool | ||||||
// certPtr keeps the latest loaded certificate. | ||||||
certPtr atomic.Pointer[tls.Certificate] | ||||||
} | ||||||
|
||||||
// newCredential builds credentials from defaults and provided options. | ||||||
func newCredential(opts ...Option) (*credentials, error) { | ||||||
c := new(credentials) | ||||||
for _, opt := range append(defaultOptions(), opts...) { | ||||||
if err := opt(c); err != nil { | ||||||
return nil, errors.ErrOptionFailed(err, reflect.ValueOf(opt)) | ||||||
} | ||||||
} | ||||||
if c.cfg == nil { | ||||||
c.cfg = new(Config) | ||||||
} | ||||||
if c.sn != "" { | ||||||
c.cfg.ServerName = c.sn | ||||||
} | ||||||
c := new(credentials) | ||||||
for _, opt := range append(defaultOptions(), opts...) { | ||||||
if err := opt(c); err != nil { | ||||||
return nil, errors.ErrOptionFailed(err, reflect.ValueOf(opt)) | ||||||
} | ||||||
} | ||||||
if c.cfg == nil { | ||||||
c.cfg = new(Config) | ||||||
} | ||||||
if c.sn != "" { | ||||||
c.cfg.ServerName = c.sn | ||||||
} | ||||||
c.cfg.InsecureSkipVerify = c.insecure | ||||||
if c.cfg.MinVersion == 0 { | ||||||
c.cfg.MinVersion = tls.VersionTLS12 | ||||||
|
@@ -139,12 +144,48 @@ func NewServerConfig(opts ...Option) (*Config, error) { | |||||
c.sn = "vald-server" | ||||||
c.cfg.ServerName = c.sn | ||||||
} | ||||||
// load cert pair | ||||||
kp, err := loadKeyPair(c.sn, c.cert, c.key) | ||||||
if err != nil { | ||||||
return nil, err | ||||||
} | ||||||
c.cfg.Certificates = []tls.Certificate{kp} | ||||||
// Configure certificate strategy. | ||||||
if c.hotReload { | ||||||
// Preload once for NameToCertificate mapping and fallback. | ||||||
kp, err := loadKeyPair(c.sn, c.cert, c.key) | ||||||
if err != nil { | ||||||
return nil, err | ||||||
} | ||||||
c.cfg.Certificates = []tls.Certificate{kp} | ||||||
c.certPtr.Store(&kp) | ||||||
|
||||||
// Reload per-handshake. | ||||||
c.cfg.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { | ||||||
kp2, err := loadKeyPair(c.sn, c.cert, c.key) | ||||||
if err != nil { | ||||||
// fall back to last good certificate | ||||||
if cur := c.certPtr.Load(); cur != nil { | ||||||
return cur, nil | ||||||
} | ||||||
return nil, err | ||||||
} | ||||||
c.certPtr.Store(&kp2) | ||||||
return c.certPtr.Load(), nil | ||||||
} | ||||||
|
||||||
// Ensure NameToCertificate stays sensible by cloning config with latest cert. | ||||||
c.cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { | ||||||
cfg := c.cfg.Clone() | ||||||
cfg.GetConfigForClient = nil | ||||||
if cur := c.certPtr.Load(); cur != nil { | ||||||
cfg.Certificates = []tls.Certificate{*cur} | ||||||
} | ||||||
return cfg, nil | ||||||
} | ||||||
} else { | ||||||
// load once statically | ||||||
kp, err := loadKeyPair(c.sn, c.cert, c.key) | ||||||
if err != nil { | ||||||
return nil, err | ||||||
} | ||||||
c.cfg.Certificates = []tls.Certificate{kp} | ||||||
c.certPtr.Store(&kp) | ||||||
} | ||||||
// if CA provided, configure mTLS | ||||||
if c.ca != "" { | ||||||
pool, err := NewX509CertPool(c.ca) | ||||||
|
@@ -186,19 +227,41 @@ func NewClientConfig(opts ...Option) (*Config, error) { | |||||
c.cfg.RootCAs = pool | ||||||
} | ||||||
} | ||||||
// load client cert if provided | ||||||
if c.cert != "" && c.key != "" { | ||||||
if c.sn == "" { | ||||||
c.sn = "vald-client" | ||||||
c.cfg.ServerName = c.sn | ||||||
} | ||||||
kp, err := loadKeyPair(c.sn, c.cert, c.key) | ||||||
if err != nil { | ||||||
return nil, err | ||||||
} | ||||||
c.cfg.Certificates = []tls.Certificate{kp} | ||||||
} | ||||||
return c.cfg, nil | ||||||
// load client cert if provided | ||||||
if c.cert != "" && c.key != "" { | ||||||
if c.sn == "" { | ||||||
c.sn = "vald-client" | ||||||
c.cfg.ServerName = c.sn | ||||||
} | ||||||
if c.hotReload { | ||||||
// Preload once for initial handshake and SNI mapping if needed. | ||||||
kp, err := loadKeyPair(c.sn, c.cert, c.key) | ||||||
if err != nil { | ||||||
return nil, err | ||||||
} | ||||||
c.cfg.Certificates = []tls.Certificate{kp} | ||||||
c.certPtr.Store(&kp) | ||||||
c.cfg.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) { | ||||||
kp2, err := loadKeyPair(c.sn, c.cert, c.key) | ||||||
if err != nil { | ||||||
if cur := c.certPtr.Load(); cur != nil { | ||||||
return cur, nil | ||||||
} | ||||||
return nil, err | ||||||
} | ||||||
c.certPtr.Store(&kp2) | ||||||
return c.certPtr.Load(), nil | ||||||
|
return c.certPtr.Load(), nil | |
return &kp2, nil |
Copilot uses AI. Check for mistakes.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
package tls_test | ||
|
||
import ( | ||
"context" | ||
"path/filepath" | ||
"testing" | ||
"time" | ||
|
||
"github.com/vdaas/vald/apis/grpc/v1/payload" | ||
"github.com/vdaas/vald/apis/grpc/v1/vald" | ||
"github.com/vdaas/vald/internal/config" | ||
"github.com/vdaas/vald/internal/file" | ||
"github.com/vdaas/vald/internal/log" | ||
"github.com/vdaas/vald/internal/log/level" | ||
"github.com/vdaas/vald/internal/net" | ||
"github.com/vdaas/vald/internal/safety" | ||
"github.com/vdaas/vald/internal/servers/server" | ||
"github.com/vdaas/vald/internal/servers/starter" | ||
"github.com/vdaas/vald/internal/sync/errgroup" | ||
"github.com/vdaas/vald/internal/test" | ||
"github.com/vdaas/vald/internal/tls" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/credentials" | ||
) | ||
|
||
var ( | ||
activeCertPath string | ||
activeKeyPath string | ||
) | ||
|
||
func init() { | ||
log.Init(log.WithLevel(level.ERROR.String())) | ||
} | ||
|
||
func serverStarter(b *testing.B, hot bool) (ctx context.Context, stop context.CancelFunc, addr string) { | ||
b.Helper() | ||
ctx, stop = context.WithCancel(b.Context()) | ||
|
||
ln, err := net.Listen(net.TCP.String(), "127.0.0.1:0") | ||
if err != nil { | ||
b.Fatalf("listen: %v", err) | ||
} | ||
_, port, _ := net.SplitHostPort(ln.Addr().String()) | ||
_ = ln.Close() | ||
|
||
certPath := test.GetTestdataPath("tls/server.crt") | ||
keyPath := test.GetTestdataPath("tls/server.key") | ||
if hot { | ||
dir := b.TempDir() | ||
activeCertPath = filepath.Join(dir, "active.crt") | ||
activeKeyPath = filepath.Join(dir, "active.key") | ||
_, _ = file.CopyFile(ctx, certPath, activeCertPath) | ||
_, _ = file.CopyFile(ctx, keyPath, activeKeyPath) | ||
certPath, keyPath = activeCertPath, activeKeyPath | ||
} else { | ||
activeCertPath, activeKeyPath = "", "" | ||
} | ||
|
||
stls, err := tls.NewServerConfig( | ||
tls.WithCert(certPath), | ||
tls.WithKey(keyPath), | ||
tls.WithClientAuth("noclientcert"), | ||
tls.WithServerCertHotReload(hot), | ||
) | ||
if err != nil { | ||
b.Fatalf("server TLS config: %v", err) | ||
} | ||
|
||
srv, err := starter.New( | ||
starter.WithConfig((&config.Servers{ | ||
Servers: []*config.Server{{ | ||
Name: "bench-grpc", | ||
Mode: server.GRPC.String(), | ||
Host: "127.0.0.1", | ||
Port: port, | ||
GRPC: &config.GRPC{}, | ||
}}, | ||
}).Bind()), | ||
starter.WithGRPC(func(sc *config.Server) []server.Option { | ||
return []server.Option{ | ||
server.WithGRPCRegistFunc(func(gs *grpc.Server) { | ||
vald.RegisterIndexServer(gs, mockIndexInfoServer{}) | ||
}), | ||
server.WithGRPCOption(grpc.Creds(credentials.NewTLS(stls))), | ||
} | ||
}), | ||
) | ||
if err != nil { | ||
b.Error(err) | ||
} | ||
|
||
go func() { _ = srv.ListenAndServe(ctx) }() | ||
|
||
addr = net.JoinHostPort("127.0.0.1", port) | ||
deadline := time.Now().Add(3 * time.Second) | ||
for { | ||
dctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond) | ||
c, err := net.DialContext(dctx, net.TCP.String(), addr) | ||
cancel() | ||
if err == nil { | ||
_ = c.Close() | ||
break | ||
} | ||
if time.Now().After(deadline) { | ||
break | ||
} | ||
time.Sleep(50 * time.Millisecond) | ||
} | ||
return ctx, stop, addr | ||
} | ||
|
||
func reloadTLSCerts(b *testing.B) (stop context.CancelFunc) { | ||
b.Helper() | ||
|
||
var ctx context.Context | ||
ctx, stop = context.WithCancel(b.Context()) | ||
eg, egctx := errgroup.New(ctx) | ||
eg.Go(safety.RecoverFunc(func() error { | ||
tick := time.NewTicker(200 * time.Millisecond) | ||
defer tick.Stop() | ||
srcA := test.GetTestdataPath("tls/server.crt") | ||
srcB := test.GetTestdataPath("tls/server2.crt") | ||
if !file.Exists(srcB) { | ||
srcB = srcA | ||
} | ||
useA := false | ||
for { | ||
select { | ||
case <-egctx.Done(): | ||
return nil | ||
case <-tick.C: | ||
if activeCertPath == "" { | ||
continue | ||
} | ||
if useA { | ||
_, _ = file.CopyFile(egctx, srcA, activeCertPath) | ||
} else { | ||
_, _ = file.CopyFile(egctx, srcB, activeCertPath) | ||
} | ||
useA = !useA | ||
} | ||
} | ||
})) | ||
|
||
return stop | ||
} | ||
|
||
type mockIndexInfoServer struct { | ||
vald.UnimplementedIndexServer | ||
} | ||
|
||
func (m mockIndexInfoServer) IndexInfo(context.Context, *payload.Empty) (*payload.Info_Index_Count, error) { | ||
return &payload.Info_Index_Count{Stored: 100}, nil | ||
} | ||
|
||
func runTLSHandshakePerOp(b *testing.B, hot bool) { | ||
ctx, stop, addr := serverStarter(b, hot) | ||
defer stop() | ||
|
||
ccfg, err := tls.NewClientConfig( | ||
tls.WithCa(test.GetTestdataPath("tls/ca.pem")), | ||
tls.WithServerName("vald.vdaas.org"), | ||
) | ||
if err != nil { | ||
b.Fatalf("client tls: %v", err) | ||
} | ||
|
||
var stopReload context.CancelFunc | ||
if hot { | ||
stopReload = reloadTLSCerts(b) | ||
defer stopReload() | ||
} | ||
|
||
b.ReportAllocs() | ||
b.ResetTimer() | ||
b.RunParallel(func(pb *testing.PB) { | ||
for pb.Next() { | ||
dctx, cancel := context.WithTimeout(ctx, 3*time.Second) | ||
conn, err := grpc.DialContext( | ||
dctx, addr, | ||
grpc.WithTransportCredentials(credentials.NewTLS(ccfg)), | ||
grpc.WithBlock(), | ||
grpc.WithReturnConnectionError(), | ||
) | ||
cancel() | ||
if err != nil { | ||
b.Fatalf("dial: %v", err) | ||
} | ||
_, err = vald.NewIndexClient(conn).IndexInfo(ctx, &payload.Empty{}) | ||
_ = conn.Close() | ||
if err != nil { | ||
b.Fatalf("IndexInfo: %v", err) | ||
} | ||
} | ||
}) | ||
} | ||
coderabbitai[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
func Benchmark_TLS_HandshakePerOp_Static(b *testing.B) { runTLSHandshakePerOp(b, false) } | ||
func Benchmark_TLS_HandshakePerOp_HotReload(b *testing.B) { runTLSHandshakePerOp(b, true) } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function should return the newly loaded certificate
&kp2
directly instead of callingc.certPtr.Load()
which could potentially return a different certificate if another goroutine updates it between the Store and Load operations.Copilot uses AI. Check for mistakes.