Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions exp/api/remote/remote_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,9 @@ type writeStorage interface {
}

type handler struct {
store writeStorage
opts handlerOpts
store writeStorage
acceptedMessageTypes MessageTypes
opts handlerOpts
}

type handlerOpts struct {
Expand Down Expand Up @@ -455,15 +456,20 @@ func SnappyDecompressorMiddleware(logger *slog.Logger) func(http.Handler) http.H

// NewHandler returns HTTP handler that receives Remote Write 2.0
// protocol https://prometheus.io/docs/specs/remote_write_spec_2_0/.
func NewHandler(store writeStorage, opts ...HandlerOption) http.Handler {
func NewHandler(store writeStorage, acceptedMessageTypes MessageTypes, opts ...HandlerOption) http.Handler {
o := handlerOpts{
logger: slog.New(nopSlogHandler{}),
middlewares: []func(http.Handler) http.Handler{SnappyDecompressorMiddleware(slog.New(nopSlogHandler{}))},
}
for _, opt := range opts {
opt(&o)
}
h := &handler{opts: o, store: store}

h := &handler{
opts: o,
store: store,
acceptedMessageTypes: acceptedMessageTypes,
}

// Apply all middlewares in order
var handler http.Handler = h
Expand Down Expand Up @@ -524,6 +530,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

if !h.acceptedMessageTypes.Contains(msgType) {
err := fmt.Errorf("%v protobuf message is not accepted by this server; only accepts %v", msgType, h.acceptedMessageTypes.String())
h.opts.logger.Error("Unaccepted message type", "msgType", msgType, "err", err)
http.Error(w, err.Error(), http.StatusUnsupportedMediaType)
return
}

writeResponse, storeErr := h.store.Store(r.Context(), msgType, r)

// Set required X-Prometheus-Remote-Write-Written-* response headers, in all cases, alongwith any user-defined headers.
Expand Down
4 changes: 2 additions & 2 deletions exp/api/remote/remote_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
t.Run("success", func(t *testing.T) {
tLogger := slog.Default()
mStore := &mockStorage{}
srv := httptest.NewServer(NewHandler(mStore, WithHandlerLogger(tLogger)))
srv := httptest.NewServer(NewHandler(mStore, MessageTypes{WriteV2MessageType}, WithHandlerLogger(tLogger)))
t.Cleanup(srv.Close)

client, err := NewAPI(srv.URL,
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
mockErr: errors.New("storage error"),
mockCode: &mockCode,
}
srv := httptest.NewServer(NewHandler(mStore, WithHandlerLogger(tLogger)))
srv := httptest.NewServer(NewHandler(mStore, MessageTypes{WriteV2MessageType}, WithHandlerLogger(tLogger)))
t.Cleanup(srv.Close)

client, err := NewAPI(srv.URL,
Expand Down
13 changes: 9 additions & 4 deletions exp/api/remote/remote_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
)
Expand Down Expand Up @@ -59,24 +60,28 @@ func (n WriteMessageType) Validate() error {
case WriteV1MessageType, WriteV2MessageType:
return nil
default:
return fmt.Errorf("unknown type for remote write protobuf message %v, supported: %v", n, messageTypes{WriteV1MessageType, WriteV2MessageType}.String())
return fmt.Errorf("unknown type for remote write protobuf message %v, supported: %v", n, MessageTypes{WriteV1MessageType, WriteV2MessageType}.String())
}
}

type messageTypes []WriteMessageType
type MessageTypes []WriteMessageType

func (m messageTypes) Strings() []string {
func (m MessageTypes) Strings() []string {
ret := make([]string, 0, len(m))
for _, typ := range m {
ret = append(ret, string(typ))
}
return ret
}

func (m messageTypes) String() string {
func (m MessageTypes) String() string {
return strings.Join(m.Strings(), ", ")
}

func (m MessageTypes) Contains(mType WriteMessageType) bool {
return slices.Contains(m, mType)
}

var contentTypeHeaders = map[WriteMessageType]string{
WriteV1MessageType: appProtoContentType, // Also application/x-protobuf;proto=prometheus.WriteRequest but simplified for compatibility with 1.x spec.
WriteV2MessageType: appProtoContentType + ";proto=io.prometheus.write.v2.Request",
Expand Down
Loading