Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions internal/tls/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ var defaultNextProtos = []string{
var defaultOptions = func() []Option {
return []Option{
WithInsecureSkipVerify(false),
// Hot reload is opt-in to avoid unexpected overhead by default.
WithServerCertHotReload(false),
WithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
NextProtos: defaultNextProtos,
Expand Down Expand Up @@ -145,6 +147,13 @@ func WithInsecureSkipVerify(insecure bool) Option {
}
}

// WithServerCertHotReload returns an Option that stores enabled in the
// credentials' hotReload flag, which toggles per-handshake certificate reload.
func WithServerCertHotReload(enabled bool) Option {
	apply := func(cred *credentials) error {
		cred.hotReload = enabled
		return nil
	}
	return apply
}

// WithClientAuth sets server-side client auth policy
func WithClientAuth(auth string) Option {
at := parseClientAuthType(auth)
Expand Down
139 changes: 101 additions & 38 deletions internal/tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/hex"
"encoding/pem"
"reflect"
"sync/atomic"
"time"

"github.com/vdaas/vald/internal/errors"
Expand Down Expand Up @@ -60,29 +61,33 @@ var (
// credentials holds TLS settings for server and client
// including certificate paths, CA bundle, and hot reload policies.
type credentials struct {
	// cfg is the TLS configuration being assembled from the options.
	cfg *Config
	// cert is the certificate file path handed to loadKeyPair.
	cert string
	// key is the private key file path handed to loadKeyPair.
	key string
	// ca is the CA bundle path handed to NewX509CertPool.
	ca string
	// sn is the server name copied into cfg.ServerName when non-empty.
	sn string
	// insecure is copied into cfg.InsecureSkipVerify.
	insecure bool
	// clientAuth is the server-side client authentication policy.
	clientAuth tls.ClientAuthType
	// hotReload toggles per-handshake reload using GetCertificate.
	hotReload bool
	// certPtr keeps the latest loaded certificate.
	certPtr atomic.Pointer[tls.Certificate]
}

// newCredential builds credentials from defaults and provided options.
func newCredential(opts ...Option) (*credentials, error) {
c := new(credentials)
for _, opt := range append(defaultOptions(), opts...) {
if err := opt(c); err != nil {
return nil, errors.ErrOptionFailed(err, reflect.ValueOf(opt))
}
}
if c.cfg == nil {
c.cfg = new(Config)
}
if c.sn != "" {
c.cfg.ServerName = c.sn
}
c := new(credentials)
for _, opt := range append(defaultOptions(), opts...) {
if err := opt(c); err != nil {
return nil, errors.ErrOptionFailed(err, reflect.ValueOf(opt))
}
}
if c.cfg == nil {
c.cfg = new(Config)
}
if c.sn != "" {
c.cfg.ServerName = c.sn
}
c.cfg.InsecureSkipVerify = c.insecure
if c.cfg.MinVersion == 0 {
c.cfg.MinVersion = tls.VersionTLS12
Expand Down Expand Up @@ -139,12 +144,48 @@ func NewServerConfig(opts ...Option) (*Config, error) {
c.sn = "vald-server"
c.cfg.ServerName = c.sn
}
// load cert pair
kp, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
return nil, err
}
c.cfg.Certificates = []tls.Certificate{kp}
// Configure certificate strategy.
if c.hotReload {
// Preload once for NameToCertificate mapping and fallback.
kp, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
return nil, err
}
c.cfg.Certificates = []tls.Certificate{kp}
c.certPtr.Store(&kp)

// Reload per-handshake.
c.cfg.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
kp2, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
// fall back to last good certificate
if cur := c.certPtr.Load(); cur != nil {
return cur, nil
}
return nil, err
}
c.certPtr.Store(&kp2)
return c.certPtr.Load(), nil
Copy link
Preview

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function should return the newly loaded certificate &kp2 directly instead of calling c.certPtr.Load(), which could return a different certificate if another goroutine updates the pointer between the Store and Load operations.

Suggested change
return c.certPtr.Load(), nil
return &kp2, nil

Copilot uses AI. Check for mistakes.

}

// Ensure NameToCertificate stays sensible by cloning config with latest cert.
c.cfg.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
cfg := c.cfg.Clone()
cfg.GetConfigForClient = nil
if cur := c.certPtr.Load(); cur != nil {
cfg.Certificates = []tls.Certificate{*cur}
}
return cfg, nil
}
} else {
// load once statically
kp, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
return nil, err
}
c.cfg.Certificates = []tls.Certificate{kp}
c.certPtr.Store(&kp)
}
// if CA provided, configure mTLS
if c.ca != "" {
pool, err := NewX509CertPool(c.ca)
Expand Down Expand Up @@ -186,19 +227,41 @@ func NewClientConfig(opts ...Option) (*Config, error) {
c.cfg.RootCAs = pool
}
}
// load client cert if provided
if c.cert != "" && c.key != "" {
if c.sn == "" {
c.sn = "vald-client"
c.cfg.ServerName = c.sn
}
kp, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
return nil, err
}
c.cfg.Certificates = []tls.Certificate{kp}
}
return c.cfg, nil
// load client cert if provided
if c.cert != "" && c.key != "" {
if c.sn == "" {
c.sn = "vald-client"
c.cfg.ServerName = c.sn
}
if c.hotReload {
// Preload once for initial handshake and SNI mapping if needed.
kp, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
return nil, err
}
c.cfg.Certificates = []tls.Certificate{kp}
c.certPtr.Store(&kp)
c.cfg.GetClientCertificate = func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
kp2, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
if cur := c.certPtr.Load(); cur != nil {
return cur, nil
}
return nil, err
}
c.certPtr.Store(&kp2)
return c.certPtr.Load(), nil
Copy link
Preview

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function should return the newly loaded certificate &kp2 directly instead of calling c.certPtr.Load(), which could return a different certificate if another goroutine updates the pointer between the Store and Load operations.

Suggested change
return c.certPtr.Load(), nil
return &kp2, nil

Copilot uses AI. Check for mistakes.

}
} else {
kp, err := loadKeyPair(c.sn, c.cert, c.key)
if err != nil {
return nil, err
}
c.cfg.Certificates = []tls.Certificate{kp}
c.certPtr.Store(&kp)
}
}
return c.cfg, nil
}

// NewX509CertPool loads one or more PEM files into a CertPool
Expand Down
199 changes: 199 additions & 0 deletions internal/tls/tls_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package tls_test

import (
"context"
"path/filepath"
"testing"
"time"

"github.com/vdaas/vald/apis/grpc/v1/payload"
"github.com/vdaas/vald/apis/grpc/v1/vald"
"github.com/vdaas/vald/internal/config"
"github.com/vdaas/vald/internal/file"
"github.com/vdaas/vald/internal/log"
"github.com/vdaas/vald/internal/log/level"
"github.com/vdaas/vald/internal/net"
"github.com/vdaas/vald/internal/safety"
"github.com/vdaas/vald/internal/servers/server"
"github.com/vdaas/vald/internal/servers/starter"
"github.com/vdaas/vald/internal/sync/errgroup"
"github.com/vdaas/vald/internal/test"
"github.com/vdaas/vald/internal/tls"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// activeCertPath and activeKeyPath are the locations of the certificate and
// key copies served in hot-reload mode. serverStarter assigns them (or resets
// both to "" in static mode) and reloadTLSCerts reads activeCertPath when it
// swaps the certificate file.
var activeCertPath, activeKeyPath string

func init() {
log.Init(log.WithLevel(level.ERROR.String()))
}

// serverStarter launches a TLS-enabled gRPC server on a free loopback port and
// waits (up to 3 seconds) until that port accepts TCP connections.
//
// When hot is true, the testdata certificate and key are copied into a
// temporary directory, the copies' paths are recorded in activeCertPath and
// activeKeyPath, and the server config is built with hot reload enabled.
// When hot is false, the testdata files are used directly and both active
// paths are reset to "".
//
// It returns the server's context, the function cancelling it, and the
// host:port address to dial. Any setup failure aborts the benchmark.
func serverStarter(b *testing.B, hot bool) (ctx context.Context, stop context.CancelFunc, addr string) {
	b.Helper()
	ctx, stop = context.WithCancel(b.Context())

	// Discover a free port by listening on port 0, then release it so the
	// benchmark server can bind to it.
	ln, err := net.Listen(net.TCP.String(), "127.0.0.1:0")
	if err != nil {
		b.Fatalf("listen: %v", err)
	}
	_, port, err := net.SplitHostPort(ln.Addr().String())
	_ = ln.Close()
	if err != nil {
		b.Fatalf("split host port: %v", err)
	}

	certPath := test.GetTestdataPath("tls/server.crt")
	keyPath := test.GetTestdataPath("tls/server.key")
	if hot {
		// Serve from private copies so the originals are never overwritten
		// while certificates are being swapped.
		dir := b.TempDir()
		activeCertPath = filepath.Join(dir, "active.crt")
		activeKeyPath = filepath.Join(dir, "active.key")
		if _, err = file.CopyFile(ctx, certPath, activeCertPath); err != nil {
			b.Fatalf("copy cert: %v", err)
		}
		if _, err = file.CopyFile(ctx, keyPath, activeKeyPath); err != nil {
			b.Fatalf("copy key: %v", err)
		}
		certPath, keyPath = activeCertPath, activeKeyPath
	} else {
		activeCertPath, activeKeyPath = "", ""
	}

	stls, err := tls.NewServerConfig(
		tls.WithCert(certPath),
		tls.WithKey(keyPath),
		tls.WithClientAuth("noclientcert"),
		tls.WithServerCertHotReload(hot),
	)
	if err != nil {
		b.Fatalf("server TLS config: %v", err)
	}

	srv, err := starter.New(
		starter.WithConfig((&config.Servers{
			Servers: []*config.Server{{
				Name: "bench-grpc",
				Mode: server.GRPC.String(),
				Host: "127.0.0.1",
				Port: port,
				GRPC: &config.GRPC{},
			}},
		}).Bind()),
		starter.WithGRPC(func(sc *config.Server) []server.Option {
			return []server.Option{
				server.WithGRPCRegistFunc(func(gs *grpc.Server) {
					vald.RegisterIndexServer(gs, mockIndexInfoServer{})
				}),
				server.WithGRPCOption(grpc.Creds(credentials.NewTLS(stls))),
			}
		}),
	)
	if err != nil {
		// Fatal rather than Error: continuing would dereference a nil srv.
		b.Fatalf("server starter: %v", err)
	}

	go func() { _ = srv.ListenAndServe(ctx) }()

	// Poll the port until the server accepts connections or the deadline passes.
	addr = net.JoinHostPort("127.0.0.1", port)
	deadline := time.Now().Add(3 * time.Second)
	for {
		dctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
		c, err := net.DialContext(dctx, net.TCP.String(), addr)
		cancel()
		if err == nil {
			_ = c.Close()
			break
		}
		if time.Now().After(deadline) {
			b.Fatalf("server %s not ready: %v", addr, err)
		}
		time.Sleep(50 * time.Millisecond)
	}
	return ctx, stop, addr
}

// reloadTLSCerts starts a background goroutine that, every 200ms, overwrites
// the file at activeCertPath with one of two testdata certificates,
// alternating between them (starting with tls/server2.crt, or tls/server.crt
// if the former does not exist). Ticks are skipped while activeCertPath is
// empty, and copy errors are ignored.
//
// The returned function cancels the goroutine's context, which ends the loop.
func reloadTLSCerts(b *testing.B) (stop context.CancelFunc) {
	b.Helper()

	parent, cancel := context.WithCancel(b.Context())
	stop = cancel

	eg, egctx := errgroup.New(parent)
	rotate := func() error {
		ticker := time.NewTicker(200 * time.Millisecond)
		defer ticker.Stop()

		primary := test.GetTestdataPath("tls/server.crt")
		alternate := test.GetTestdataPath("tls/server2.crt")
		if !file.Exists(alternate) {
			alternate = primary
		}
		// The alternate certificate is written first, then the two take turns.
		order := [2]string{alternate, primary}
		next := 0
		for {
			select {
			case <-egctx.Done():
				return nil
			case <-ticker.C:
				if activeCertPath != "" {
					_, _ = file.CopyFile(egctx, order[next], activeCertPath)
					next = 1 - next
				}
			}
		}
	}
	eg.Go(safety.RecoverFunc(rotate))

	return stop
}

// mockIndexInfoServer is the Index service registered on the benchmark gRPC
// server. It embeds vald.UnimplementedIndexServer and defines IndexInfo.
type mockIndexInfoServer struct {
vald.UnimplementedIndexServer
}

// IndexInfo ignores its arguments and always reports a stored count of 100
// with a nil error.
func (mockIndexInfoServer) IndexInfo(_ context.Context, _ *payload.Empty) (*payload.Info_Index_Count, error) {
	count := &payload.Info_Index_Count{Stored: 100}
	return count, nil
}

// runTLSHandshakePerOp benchmarks opening a new TLS-secured gRPC connection
// and issuing one IndexInfo call per iteration, against a server started by
// serverStarter.
//
// When hot is true the server is started with hot reload enabled and
// reloadTLSCerts swaps the certificate file in the background for the
// duration of the benchmark.
func runTLSHandshakePerOp(b *testing.B, hot bool) {
	ctx, stop, addr := serverStarter(b, hot)
	defer stop()

	ccfg, err := tls.NewClientConfig(
		tls.WithCa(test.GetTestdataPath("tls/ca.pem")),
		tls.WithServerName("vald.vdaas.org"),
	)
	if err != nil {
		b.Fatalf("client tls: %v", err)
	}

	if hot {
		stopReload := reloadTLSCerts(b)
		defer stopReload()
	}

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		// This body runs on worker goroutines. Fatalf/FailNow may only be
		// called from the goroutine running the benchmark function, so
		// failures are reported with Errorf and the worker returns.
		for pb.Next() {
			dctx, cancel := context.WithTimeout(ctx, 3*time.Second)
			conn, err := grpc.DialContext(
				dctx, addr,
				grpc.WithTransportCredentials(credentials.NewTLS(ccfg)),
				grpc.WithBlock(),
				grpc.WithReturnConnectionError(),
			)
			cancel()
			if err != nil {
				b.Errorf("dial: %v", err)
				return
			}
			_, err = vald.NewIndexClient(conn).IndexInfo(ctx, &payload.Empty{})
			_ = conn.Close()
			if err != nil {
				b.Errorf("IndexInfo: %v", err)
				return
			}
		}
	})
}

// Benchmark_TLS_HandshakePerOp_Static runs the handshake-per-operation
// benchmark with certificate hot reload disabled.
func Benchmark_TLS_HandshakePerOp_Static(b *testing.B) {
	const hotReload = false
	runTLSHandshakePerOp(b, hotReload)
}
// Benchmark_TLS_HandshakePerOp_HotReload runs the handshake-per-operation
// benchmark with certificate hot reload enabled.
func Benchmark_TLS_HandshakePerOp_HotReload(b *testing.B) {
	const hotReload = true
	runTLSHandshakePerOp(b, hotReload)
}
Loading