From e8bc4946877c5e7685e3064eecd809923f0839a2 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 09:49:09 -0500 Subject: [PATCH 01/21] add http token parsing --- rules/rules.go | 215 ++++++++++++++++---------------------------- rules/rules_test.go | 130 ++++++++++++++++++++++++++- 2 files changed, 204 insertions(+), 141 deletions(-) diff --git a/rules/rules.go b/rules/rules.go index ab64cc4..5baf368 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -1,9 +1,9 @@ package rules import ( + "errors" "fmt" "log/slog" - "strings" ) type Evaluator interface { @@ -12,16 +12,84 @@ type Evaluator interface { // Rule represents an allow rule with optional HTTP method restrictions type Rule struct { - Pattern string // wildcard pattern for matching - Methods map[string]bool // nil means all methods allowed - Raw string // rule string for logging + + // The path segments of the url + // nil means all paths allowed + // a path segment of `*` acts as a wild card. + Path []string + + // The labels of the host, i.e. ["google", "com"] + // nil means no hosts allowed + // subdomains automatically match + Host []string + + // The allowed http methods + // nil means all methods allowed + Methods map[string]struct{} + + // Raw rule string for logging + Raw string +} + +type httpToken string + +// Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by +// various systems. +// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6 +func parseHTTPToken(token string) (httpToken, string, error) { + if token == "" { + return "", "", errors.New("expected http token, got empty string") + } + return doParseHTTPToken(token, nil) +} + +func doParseHTTPToken(token string, acc []byte) (httpToken, string, error) { + // BASE CASE: if the token passed in is empty, we're done parsing + if token == "" { + return httpToken(acc), "", nil + } + + // If the next byte in the string is not a valid http token character, we're done parsing. + if !isHTTPTokenChar(token[0]) { + return httpToken(acc), token, nil + } + + // The next character is valid, so the http token continues + acc = append(acc, token[0]) + return doParseHTTPToken(token[1:], acc) +} + +// The valid characters that can be in an http token (like the lexer/parser kind of token). +func isHTTPTokenChar(c byte) bool { + switch { + // Alpha numeric is fine. + case c >= 'A' && c <= 'Z': + return true + case c >= 'a' && c <= 'z': + return true + case c >= '0' && c <= '9': + return true + + // These special characters are also allowed unbelievably. + case c == '!' || c == '#' || c == '$' || c == '%' || c == '&' || + c == '\'' || c == '*' || c == '+' || c == '-' || c == '.' || + c == '^' || c == '_' || c == '`' || c == '|' || c == '~': + return true + + default: + return false + } +} + +func parseAllowRule(string) (Rule, error) { + return Rule{}, nil } // ParseAllowSpecs parses a slice of --allow specs into allow Rules. func ParseAllowSpecs(allowStrings []string) ([]Rule, error) { var out []Rule for _, s := range allowStrings { - r, err := newAllowRule(s) + r, err := parseAllowRule(s) if err != nil { return nil, fmt.Errorf("failed to parse allow '%s': %v", s, err) } @@ -71,142 +139,15 @@ func (re *Engine) Evaluate(method, url string) Result { // Matches checks if the rule matches the given method and URL using wildcard patterns func (re *Engine) matches(r Rule, method, url string) bool { - // Check method if specified - if r.Methods != nil && !r.Methods[strings.ToUpper(method)] { - return false - } - - // Check URL pattern using wildcard matching - // Try exact match first - if wildcardMatch(r.Pattern, url) { - return true - } - - // If pattern doesn't start with protocol, try matching against the URL without protocol - if !strings.HasPrefix(r.Pattern, "http://") && !strings.HasPrefix(r.Pattern, "https://") { - // Extract domain and path from URL - urlWithoutProtocol := url - if strings.HasPrefix(url, "https://") { - urlWithoutProtocol = url[8:] // Remove "https://" - } else if strings.HasPrefix(url, "http://") { - urlWithoutProtocol = url[7:] // Remove "http://" - } - - // Try matching against URL without protocol - if wildcardMatch(r.Pattern, urlWithoutProtocol) { - return true - } - - // Also try matching just the domain part - domainEnd := strings.Index(urlWithoutProtocol, "/") - if domainEnd > 0 { - domain := urlWithoutProtocol[:domainEnd] - if wildcardMatch(r.Pattern, domain) { - return true - } - } else { - // No path, just domain - if wildcardMatch(r.Pattern, urlWithoutProtocol) { - return true - } - } - } - - return false -} - -// wildcardMatch performs wildcard pattern matching -// Supports * (matches any sequence of characters) -func wildcardMatch(pattern, text string) bool { - pattern = strings.ToLower(pattern) - text = strings.ToLower(text) - - // Handle simple case - if pattern == "*" { + // If the rule doesn't have any method filters, don't restrict the allowed methods + if r.Methods == nil { return true } - // Split pattern by '*' and check each part exists in order - parts := strings.Split(pattern, "*") - - // If no wildcards, must be exact match - if len(parts) == 1 { - return pattern == text - } - - textPos := 0 - for i, part := range parts { - if part == "" { - continue // Skip empty parts from consecutive '*' - } - - if i == 0 { - // First part must be at the beginning - if !strings.HasPrefix(text, part) { - return false - } - textPos = len(part) - } else if i == len(parts)-1 { - // Last part must be at the end - if !strings.HasSuffix(text[textPos:], part) { - return false - } - } else { - // Middle parts must exist in order - idx := strings.Index(text[textPos:], part) - if idx == -1 { - return false - } - textPos += idx + len(part) - } + // If the rule has method filters and the provided method is not one of them, block the request. + if _, methodIsAllowed := r.Methods[method]; !methodIsAllowed { + return false } return true } - -// newAllowRule creates an allow Rule from a spec string used by --allow. -// Supported formats: -// -// "pattern" -> allow all methods to pattern -// "GET,HEAD pattern" -> allow only listed methods to pattern -func newAllowRule(spec string) (Rule, error) { - s := strings.TrimSpace(spec) - if s == "" { - return Rule{}, fmt.Errorf("invalid allow spec: empty") - } - - var methods map[string]bool - pattern := s - - // Detect optional leading methods list separated by commas and a space before pattern - // e.g., "GET,HEAD github.com" - if idx := strings.IndexFunc(s, func(r rune) bool { return r == ' ' || r == '\t' }); idx > 0 { - left := strings.TrimSpace(s[:idx]) - right := strings.TrimSpace(s[idx:]) - // methods part is valid if it only contains letters and commas - valid := left != "" && strings.IndexFunc(left, func(r rune) bool { - return r != ',' && (r < 'A' || r > 'Z') && (r < 'a' || r > 'z') - }) == -1 - if valid { - methods = make(map[string]bool) - for _, m := range strings.Split(left, ",") { - m = strings.TrimSpace(m) - if m == "" { - continue - } - methods[strings.ToUpper(m)] = true - } - pattern = right - } - } - - if pattern == "" { - return Rule{}, fmt.Errorf("invalid allow spec: missing pattern") - } - - return Rule{ - Pattern: pattern, - Methods: methods, - Raw: "allow " + spec, - }, nil -} diff --git a/rules/rules_test.go b/rules/rules_test.go index eb702fe..1bbf7a3 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -2,8 +2,130 @@ package rules import "testing" -// Stub test file - tests removed -func TestStub(t *testing.T) { - // This is a stub test - t.Skip("stub test file") +func TestParseHTTPToken(t *testing.T) { + tests := []struct { + name string + input string + expectedToken httpToken + expectedRemain string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedToken: "", + expectedRemain: "", + expectError: true, + }, + { + name: "simple method GET", + input: "GET", + expectedToken: "GET", + expectedRemain: "", + expectError: false, + }, + { + name: "simple method POST", + input: "POST", + expectedToken: "POST", + expectedRemain: "", + expectError: false, + }, + { + name: "method with trailing space", + input: "GET ", + expectedToken: "GET", + expectedRemain: " ", + expectError: false, + }, + { + name: "method with trailing content", + input: "POST /api/users", + expectedToken: "POST", + expectedRemain: " /api/users", + expectError: false, + }, + { + name: "all valid special characters", + input: "!#$%&'*+-.^_`|~", + expectedToken: "!#$%&'*+-.^_`|~", + expectedRemain: "", + expectError: false, + }, + { + name: "alphanumeric token", + input: "ABC123xyz", + expectedToken: "ABC123xyz", + expectedRemain: "", + expectError: false, + }, + { + name: "token with invalid character", + input: "GET@test", + expectedToken: "GET", + expectedRemain: "@test", + expectError: false, + }, + { + name: "token starting with invalid character", + input: "@GET", + expectedToken: "", + expectedRemain: "@GET", + expectError: false, + }, + { + name: "single character token", + input: "A", + expectedToken: "A", + expectedRemain: "", + expectError: false, + }, + { + name: "token with underscore and dash", + input: "CUSTOM-METHOD_1", + expectedToken: "CUSTOM-METHOD_1", + expectedRemain: "", + expectError: false, + }, + { + name: "token stops at comma", + input: "GET,POST", + expectedToken: "GET", + expectedRemain: ",POST", + expectError: false, + }, + { + name: "token stops at semicolon", + input: "GET;charset=utf-8", + expectedToken: "GET", + expectedRemain: ";charset=utf-8", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, remain, err := parseHTTPToken(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if token != tt.expectedToken { + t.Errorf("expected token %q, got %q", tt.expectedToken, token) + } + + if remain != tt.expectedRemain { + t.Errorf("expected remaining %q, got %q", tt.expectedRemain, remain) + } + }) + } } From 18aa4a4f7278998358fe752d69f2d7119dafa419 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 09:54:53 -0500 Subject: [PATCH 02/21] add parsing hosts --- rules/rules.go | 84 +++++++++++++ rules/rules_test.go | 298 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 382 insertions(+) diff --git a/rules/rules.go b/rules/rules.go index 5baf368..765e0be 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log/slog" + "strings" ) type Evaluator interface { @@ -81,6 +82,89 @@ func isHTTPTokenChar(c byte) bool { } } +// Represents a valid host. +// https://datatracker.ietf.org/doc/html/rfc952 +// https://datatracker.ietf.org/doc/html/rfc1123#page-13 +type host []label + +func parseHost(input string) (host host, rest string, err error) { + rest = input + var label label + + if input == "" { + return nil, "", errors.New("expected host, got empty string") + } + + // There should be at least one label. + label, rest, err = parseLabel(rest) + if err != nil { + return nil, "", err + } + host = append(host, label) + + // A host is just a bunch of labels separated by `.` characters. + var found bool + for { + rest, found = strings.CutPrefix(rest, ".") + if !found { + break + } + + label, rest, err = parseLabel(rest) + if err != nil { + return nil, "", err + } + host = append(host, label) + } + + return host, rest, nil +} + +// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`. +type label string + +func parseLabel(rest string) (label, string, error) { + if rest == "" { + return "", "", errors.New("expected label, got empty string") + } + + // First try to get a valid leading char. Leading char in a label cannot be a hyphen. + if !isValidLabelChar(rest[0]) || rest[0] == '-' { + return "", "", fmt.Errorf("could not pull label from front of string: %s", rest) + } + + // Go until the next character is not a valid char + var i int + for i = 1; i < len(rest) && isValidLabelChar(rest[i]); i += 1 { + } + + // Final char in a label cannot be a hyphen. + if rest[i-1] == '-' { + return "", "", fmt.Errorf("invalid label: %s", rest[:i]) + } + + return label(rest[:i]), rest[i:], nil +} + +func isValidLabelChar(c byte) bool { + switch { + // Alpha numeric is fine. + case c >= 'A' && c <= 'Z': + return true + case c >= 'a' && c <= 'z': + return true + case c >= '0' && c <= '9': + return true + + // Hyphens are good + case c == '-': + return true + + default: + return false + } +} + func parseAllowRule(string) (Rule, error) { return Rule{}, nil } diff --git a/rules/rules_test.go b/rules/rules_test.go index 1bbf7a3..6882b96 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -129,3 +129,301 @@ func TestParseHTTPToken(t *testing.T) { }) } } + +func TestParseHost(t *testing.T) { + tests := []struct { + name string + input string + expectedHost host + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "simple domain", + input: "google.com", + expectedHost: host{label("google"), label("com")}, + expectedRest: "", + expectError: false, + }, + { + name: "subdomain", + input: "api.google.com", + expectedHost: host{label("api"), label("google"), label("com")}, + expectedRest: "", + expectError: false, + }, + { + name: "single label", + input: "localhost", + expectedHost: host{label("localhost")}, + expectedRest: "", + expectError: false, + }, + { + name: "domain with trailing content", + input: "example.org/path", + expectedHost: host{label("example"), label("org")}, + expectedRest: "/path", + expectError: false, + }, + { + name: "domain with port", + input: "localhost:8080", + expectedHost: host{label("localhost")}, + expectedRest: ":8080", + expectError: false, + }, + { + name: "numeric labels", + input: "192.168.1.1", + expectedHost: host{label("192"), label("168"), label("1"), label("1")}, + expectedRest: "", + expectError: false, + }, + { + name: "hyphenated domain", + input: "my-site.example-domain.co.uk", + expectedHost: host{label("my-site"), label("example-domain"), label("co"), label("uk")}, + expectedRest: "", + expectError: false, + }, + { + name: "alphanumeric labels", + input: "a1b2c3.test123.com", + expectedHost: host{label("a1b2c3"), label("test123"), label("com")}, + expectedRest: "", + expectError: false, + }, + { + name: "starts with hyphen", + input: "-invalid.com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "ends with hyphen", + input: "invalid-.com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "label ends with hyphen", + input: "test.invalid-.com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "invalid character", + input: "test@example.com", + expectedHost: host{label("test")}, + expectedRest: "@example.com", + expectError: false, + }, + { + name: "empty label", + input: "test..com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "trailing dot", + input: "example.com.", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "single character labels", + input: "a.b.c", + expectedHost: host{label("a"), label("b"), label("c")}, + expectedRest: "", + expectError: false, + }, + { + name: "mixed case", + input: "Example.COM", + expectedHost: host{label("Example"), label("COM")}, + expectedRest: "", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hostResult, rest, err := parseHost(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if len(hostResult) != len(tt.expectedHost) { + t.Errorf("expected host length %d, got %d", len(tt.expectedHost), len(hostResult)) + return + } + + for i, expectedLabel := range tt.expectedHost { + if hostResult[i] != expectedLabel { + t.Errorf("expected label[%d] %q, got %q", i, expectedLabel, hostResult[i]) + } + } + + if rest != tt.expectedRest { + t.Errorf("expected remaining %q, got %q", tt.expectedRest, rest) + } + }) + } +} + +func TestParseLabel(t *testing.T) { + tests := []struct { + name string + input string + expectedLabel label + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "simple label", + input: "test", + expectedLabel: "test", + expectedRest: "", + expectError: false, + }, + { + name: "label with dot", + input: "test.com", + expectedLabel: "test", + expectedRest: ".com", + expectError: false, + }, + { + name: "label with hyphen", + input: "my-site", + expectedLabel: "my-site", + expectedRest: "", + expectError: false, + }, + { + name: "alphanumeric label", + input: "test123", + expectedLabel: "test123", + expectedRest: "", + expectError: false, + }, + { + name: "starts with hyphen", + input: "-invalid", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "ends with hyphen", + input: "invalid-", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "ends with hyphen followed by dot", + input: "invalid-.com", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "single character", + input: "a", + expectedLabel: "a", + expectedRest: "", + expectError: false, + }, + { + name: "numeric label", + input: "123", + expectedLabel: "123", + expectedRest: "", + expectError: false, + }, + { + name: "mixed case", + input: "Test", + expectedLabel: "Test", + expectedRest: "", + expectError: false, + }, + { + name: "invalid character", + input: "test@invalid", + expectedLabel: "test", + expectedRest: "@invalid", + expectError: false, + }, + { + name: "starts with number", + input: "1test", + expectedLabel: "1test", + expectedRest: "", + expectError: false, + }, + { + name: "label with trailing slash", + input: "api/path", + expectedLabel: "api", + expectedRest: "/path", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + labelResult, rest, err := parseLabel(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if labelResult != tt.expectedLabel { + t.Errorf("expected label %q, got %q", tt.expectedLabel, labelResult) + } + + if rest != tt.expectedRest { + t.Errorf("expected remaining %q, got %q", tt.expectedRest, rest) + } + }) + } +} From 3fb4ef05324f1ef5258e0c084c423d5e22b7520a Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 10:05:00 -0500 Subject: [PATCH 03/21] parse path --- rules/rules.go | 127 +++++++++++++++- rules/rules_test.go | 350 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 462 insertions(+), 15 deletions(-) diff --git a/rules/rules.go b/rules/rules.go index 765e0be..6bca613 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -85,9 +85,7 @@ func isHTTPTokenChar(c byte) bool { // Represents a valid host. // https://datatracker.ietf.org/doc/html/rfc952 // https://datatracker.ietf.org/doc/html/rfc1123#page-13 -type host []label - -func parseHost(input string) (host host, rest string, err error) { +func parseHost(input string) (host []label, rest string, err error) { rest = input var label label @@ -165,6 +163,129 @@ func isValidLabelChar(c byte) bool { } } +func parsePath(input string) ([]segment, string, error) { + if input == "" { + return nil, "", nil + } + + var segments []segment + rest := input + + // If the path doesn't start with '/', it's not a valid absolute path + // But we'll be flexible and parse relative paths too + for { + // Skip leading slash if present + if rest != "" && rest[0] == '/' { + rest = rest[1:] + } + + // If we've consumed all input, we're done + if rest == "" { + break + } + + // Parse the next segment + seg, remaining, err := parsePathSegment(rest) + if err != nil { + return nil, "", err + } + + // If we got an empty segment and there's still input, + // it means we hit an invalid character + if seg == "" && remaining != "" { + break + } + + segments = append(segments, seg) + rest = remaining + + // If there's no slash after the segment, we're done parsing the path + if rest == "" || rest[0] != '/' { + break + } + } + + return segments, rest, nil +} + +// Represents a valid url path segment. +type segment string + +func parsePathSegment(input string) (segment, string, error) { + if input == "" { + return "", "", nil + } + + var i int + for i = 0; i < len(input); i++ { + c := input[i] + + // Check for percent-encoded characters (%XX) + if c == '%' { + if i+2 >= len(input) || !isHexDigit(input[i+1]) || !isHexDigit(input[i+2]) { + break + } + i += 2 + continue + } + + // Check for valid pchar characters + if !isPChar(c) { + break + } + } + + return segment(input[:i]), input[i:], nil +} + +// isUnreserved returns true if the character is unreserved per RFC 3986 +// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" +func isUnreserved(c byte) bool { + return (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || + c == '-' || c == '.' || c == '_' || c == '~' +} + +// isSubDelim returns true if the character is a sub-delimiter per RFC 3986 +// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" +func isSubDelim(c byte) bool { + return c == '!' || c == '$' || c == '&' || c == '\'' || + c == '(' || c == ')' || c == '*' || c == '+' || + c == ',' || c == ';' || c == '=' +} + +// isPChar returns true if the character is valid in a path segment (excluding percent-encoded) +// pchar = unreserved / sub-delims / ":" / "@" +func isPChar(c byte) bool { + return isUnreserved(c) || isSubDelim(c) || c == ':' || c == '@' +} + +// isHexDigit returns true if the character is a hexadecimal digit +func isHexDigit(c byte) bool { + return (c >= '0' && c <= '9') || + (c >= 'A' && c <= 'F') || + (c >= 'a' && c <= 'f') +} + +// parseKey parses the predefined keys that the cli can handle. Also strips the `=` following the key. +func parseKey(rule string) (string, string, error) { + if rule == "" { + return "", "", errors.New("expected key") + } + + // These are the current keys we support. + keys := []string{"method", "domain", "path"} + + for _, key := range keys { + if rest, found := strings.CutPrefix(rule, key+"="); found { + return key, rest, nil + } + } + + return "", "", errors.New("expected key") +} + func parseAllowRule(string) (Rule, error) { return Rule{}, nil } diff --git a/rules/rules_test.go b/rules/rules_test.go index 6882b96..22de729 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -134,7 +134,7 @@ func TestParseHost(t *testing.T) { tests := []struct { name string input string - expectedHost host + expectedHost []label expectedRest string expectError bool }{ @@ -148,56 +148,56 @@ func TestParseHost(t *testing.T) { { name: "simple domain", input: "google.com", - expectedHost: host{label("google"), label("com")}, + expectedHost: []label{label("google"), label("com")}, expectedRest: "", expectError: false, }, { name: "subdomain", input: "api.google.com", - expectedHost: host{label("api"), label("google"), label("com")}, + expectedHost: []label{label("api"), label("google"), label("com")}, expectedRest: "", expectError: false, }, { name: "single label", input: "localhost", - expectedHost: host{label("localhost")}, + expectedHost: []label{label("localhost")}, expectedRest: "", expectError: false, }, { name: "domain with trailing content", input: "example.org/path", - expectedHost: host{label("example"), label("org")}, + expectedHost: []label{label("example"), label("org")}, expectedRest: "/path", expectError: false, }, { name: "domain with port", input: "localhost:8080", - expectedHost: host{label("localhost")}, + expectedHost: []label{label("localhost")}, expectedRest: ":8080", expectError: false, }, { name: "numeric labels", input: "192.168.1.1", - expectedHost: host{label("192"), label("168"), label("1"), label("1")}, + expectedHost: []label{label("192"), label("168"), label("1"), label("1")}, expectedRest: "", expectError: false, }, { name: "hyphenated domain", input: "my-site.example-domain.co.uk", - expectedHost: host{label("my-site"), label("example-domain"), label("co"), label("uk")}, + expectedHost: []label{label("my-site"), label("example-domain"), label("co"), label("uk")}, expectedRest: "", expectError: false, }, { name: "alphanumeric labels", input: "a1b2c3.test123.com", - expectedHost: host{label("a1b2c3"), label("test123"), label("com")}, + expectedHost: []label{label("a1b2c3"), label("test123"), label("com")}, expectedRest: "", expectError: false, }, @@ -225,7 +225,7 @@ func TestParseHost(t *testing.T) { { name: "invalid character", input: "test@example.com", - expectedHost: host{label("test")}, + expectedHost: []label{label("test")}, expectedRest: "@example.com", expectError: false, }, @@ -246,14 +246,14 @@ func TestParseHost(t *testing.T) { { name: "single character labels", input: "a.b.c", - expectedHost: host{label("a"), label("b"), label("c")}, + expectedHost: []label{label("a"), label("b"), label("c")}, expectedRest: "", expectError: false, }, { name: "mixed case", input: "Example.COM", - expectedHost: host{label("Example"), label("COM")}, + expectedHost: []label{label("Example"), label("COM")}, expectedRest: "", expectError: false, }, @@ -427,3 +427,329 @@ func TestParseLabel(t *testing.T) { }) } } + +func TestParsePathSegment(t *testing.T) { + tests := []struct { + name string + input string + expectedSegment segment + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedSegment: "", + expectedRest: "", + expectError: false, + }, + { + name: "simple segment", + input: "api", + expectedSegment: "api", + expectedRest: "", + expectError: false, + }, + { + name: "segment with slash", + input: "api/users", + expectedSegment: "api", + expectedRest: "/users", + expectError: false, + }, + { + name: "segment with unreserved chars", + input: "my-file.txt_version~1", + expectedSegment: "my-file.txt_version~1", + expectedRest: "", + expectError: false, + }, + { + name: "segment with sub-delims", + input: "filter='test'&sort=name", + expectedSegment: "filter='test'&sort=name", + expectedRest: "", + expectError: false, + }, + { + name: "segment with colon and at", + input: "user:password@domain", + expectedSegment: "user:password@domain", + expectedRest: "", + expectError: false, + }, + { + name: "percent encoded segment", + input: "hello%20world", + expectedSegment: "hello%20world", + expectedRest: "", + expectError: false, + }, + { + name: "multiple percent encoded", + input: "%3Fkey%3Dvalue%26other%3D123", + expectedSegment: "%3Fkey%3Dvalue%26other%3D123", + expectedRest: "", + expectError: false, + }, + { + name: "invalid percent encoding incomplete", + input: "test%2", + expectedSegment: "test", + expectedRest: "%2", + expectError: false, + }, + { + name: "invalid percent encoding non-hex", + input: "test%ZZ", + expectedSegment: "test", + expectedRest: "%ZZ", + expectError: false, + }, + { + name: "segment stops at space", + input: "test hello", + expectedSegment: "test", + expectedRest: " hello", + expectError: false, + }, + { + name: "segment with question mark stops", + input: "path?query=value", + expectedSegment: "path", + expectedRest: "?query=value", + expectError: false, + }, + { + name: "segment with hash stops", + input: "path#fragment", + expectedSegment: "path", + expectedRest: "#fragment", + expectError: false, + }, + { + name: "numeric segment", + input: "123456", + expectedSegment: "123456", + expectedRest: "", + expectError: false, + }, + { + name: "mixed alphanumeric", + input: "abc123XYZ", + expectedSegment: "abc123XYZ", + expectedRest: "", + expectError: false, + }, + { + name: "all sub-delims", + input: "!$&'()*+,;=", + expectedSegment: "!$&'()*+,;=", + expectedRest: "", + expectError: false, + }, + { + name: "segment with brackets", + input: "test[bracket]", + expectedSegment: "test", + expectedRest: "[bracket]", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + segment, rest, err := parsePathSegment(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if segment != tt.expectedSegment { + t.Errorf("expected segment %q, got %q", tt.expectedSegment, segment) + } + + if rest != tt.expectedRest { + t.Errorf("expected rest %q, got %q", tt.expectedRest, rest) + } + }) + } +} + +func TestParsePath(t *testing.T) { + tests := []struct { + name string + input string + expectedSegments []segment + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedSegments: nil, + expectedRest: "", + expectError: false, + }, + { + name: "single segment", + input: "/api", + expectedSegments: []segment{"api"}, + expectedRest: "", + expectError: false, + }, + { + name: "multiple segments", + input: "/api/v1/users", + expectedSegments: []segment{"api", "v1", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "relative path", + input: "api/users", + expectedSegments: []segment{"api", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with trailing slash", + input: "/api/users/", + expectedSegments: []segment{"api", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with query string", + input: "/api/users?limit=10", + expectedSegments: []segment{"api", "users"}, + expectedRest: "?limit=10", + expectError: false, + }, + { + name: "path with fragment", + input: "/docs/api#authentication", + expectedSegments: []segment{"docs", "api"}, + expectedRest: "#authentication", + expectError: false, + }, + { + name: "path with encoded segments", + input: "/api/hello%20world/test", + expectedSegments: []segment{"api", "hello%20world", "test"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with special chars", + input: "/api/filter='test'&sort=name/results", + expectedSegments: []segment{"api", "filter='test'&sort=name", "results"}, + expectedRest: "", + expectError: false, + }, + { + name: "just slash", + input: "/", + expectedSegments: nil, + expectedRest: "", + expectError: false, + }, + { + name: "empty segments", + input: "/api//users", + expectedSegments: []segment{"api"}, + expectedRest: "/users", + expectError: false, + }, + { + name: "path with port-like segment", + input: "/host:8080/status", + expectedSegments: []segment{"host:8080", "status"}, + expectedRest: "", + expectError: false, + }, + { + name: "path stops at space", + input: "/api/test hello", + expectedSegments: []segment{"api", "test"}, + expectedRest: " hello", + expectError: false, + }, + { + name: "path with hyphens and underscores", + input: "/my-api/user_data/file-name.txt", + expectedSegments: []segment{"my-api", "user_data", "file-name.txt"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with tildes", + input: "/api/~user/docs~backup", + expectedSegments: []segment{"api", "~user", "docs~backup"}, + expectedRest: "", + expectError: false, + }, + { + name: "numeric segments", + input: "/api/v2/users/12345", + expectedSegments: []segment{"api", "v2", "users", "12345"}, + expectedRest: "", + expectError: false, + }, + { + name: "single character segments", + input: "/a/b/c", + expectedSegments: []segment{"a", "b", "c"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with at symbol", + input: "/user@domain.com/profile", + expectedSegments: []segment{"user@domain.com", "profile"}, + expectedRest: "", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + segments, rest, err := parsePath(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if len(segments) != len(tt.expectedSegments) { + t.Errorf("expected %d segments, got %d", len(tt.expectedSegments), len(segments)) + return + } + + for i, expectedSeg := range tt.expectedSegments { + if segments[i] != expectedSeg { + t.Errorf("expected segment[%d] %q, got %q", i, expectedSeg, segments[i]) + } + } + + if rest != tt.expectedRest { + t.Errorf("expected rest %q, got %q", tt.expectedRest, rest) + } + }) + } +} From ec944b69f7535df8e23b957456a6cfde2a66a233 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 10:41:08 -0500 Subject: [PATCH 04/21] update names to reflect we're pattern parsing --- rules/rules.go | 217 +++++++++++++++++++++++++++++++++-------- rules/rules_test.go | 229 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 369 insertions(+), 77 deletions(-) diff --git a/rules/rules.go b/rules/rules.go index 6bca613..5be307d 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -17,47 +17,52 @@ type Rule struct { // The path segments of the url // nil means all paths allowed // a path segment of `*` acts as a wild card. - Path []string + PathPattern []segmentPattern // The labels of the host, i.e. ["google", "com"] - // nil means no hosts allowed - // subdomains automatically match - Host []string + // nil means all hosts allowed + // A label of `*` acts as a wild card. + HostPattern []labelPattern // The allowed http methods // nil means all methods allowed - Methods map[string]struct{} + MethodPatterns map[methodPattern]struct{} // Raw rule string for logging - Raw string + Raw string } -type httpToken string +type methodPattern string + +// An asterisk is treated as matching any method +func (t methodPattern) matches(input string) bool { + return t == "*" || string(t) == input +} // Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by // various systems. // https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6 -func parseHTTPToken(token string) (httpToken, string, error) { +func parseMethodPattern(token string) (methodPattern, string, error) { if token == "" { return "", "", errors.New("expected http token, got empty string") } - return doParseHTTPToken(token, nil) + return doParseMethodPattern(token, nil) } -func doParseHTTPToken(token string, acc []byte) (httpToken, string, error) { +func doParseMethodPattern(token string, acc []byte) (methodPattern, string, error) { // BASE CASE: if the token passed in is empty, we're done parsing if token == "" { - return httpToken(acc), "", nil + return methodPattern(acc), "", nil } // If the next byte in the string is not a valid http token character, we're done parsing. if !isHTTPTokenChar(token[0]) { - return httpToken(acc), token, nil + return methodPattern(acc), token, nil } // The next character is valid, so the http token continues acc = append(acc, token[0]) - return doParseHTTPToken(token[1:], acc) + return doParseMethodPattern(token[1:], acc) } // The valid characters that can be in an http token (like the lexer/parser kind of token). @@ -85,16 +90,16 @@ func isHTTPTokenChar(c byte) bool { // Represents a valid host. // https://datatracker.ietf.org/doc/html/rfc952 // https://datatracker.ietf.org/doc/html/rfc1123#page-13 -func parseHost(input string) (host []label, rest string, err error) { +func parseHostPattern(input string) (host []labelPattern, rest string, err error) { rest = input - var label label + var label labelPattern if input == "" { return nil, "", errors.New("expected host, got empty string") } // There should be at least one label. - label, rest, err = parseLabel(rest) + label, rest, err = parseLabelPattern(rest) if err != nil { return nil, "", err } @@ -108,7 +113,7 @@ func parseHost(input string) (host []label, rest string, err error) { break } - label, rest, err = parseLabel(rest) + label, rest, err = parseLabelPattern(rest) if err != nil { return nil, "", err } @@ -119,9 +124,9 @@ func parseHost(input string) (host []label, rest string, err error) { } // Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`. -type label string +type labelPattern string -func parseLabel(rest string) (label, string, error) { +func parseLabelPattern(rest string) (labelPattern, string, error) { if rest == "" { return "", "", errors.New("expected label, got empty string") } @@ -141,7 +146,7 @@ func parseLabel(rest string) (label, string, error) { return "", "", fmt.Errorf("invalid label: %s", rest[:i]) } - return label(rest[:i]), rest[i:], nil + return labelPattern(rest[:i]), rest[i:], nil } func isValidLabelChar(c byte) bool { @@ -163,12 +168,12 @@ func isValidLabelChar(c byte) bool { } } -func parsePath(input string) ([]segment, string, error) { +func parsePathPattern(input string) ([]segmentPattern, string, error) { if input == "" { return nil, "", nil } - var segments []segment + var segments []segmentPattern rest := input // If the path doesn't start with '/', it's not a valid absolute path @@ -185,12 +190,12 @@ func parsePath(input string) ([]segment, string, error) { } // Parse the next segment - seg, remaining, err := parsePathSegment(rest) + seg, remaining, err := parsePathSegmentPattern(rest) if err != nil { return nil, "", err } - // If we got an empty segment and there's still input, + // If we got an empty segment and there's still input, // it means we hit an invalid character if seg == "" && remaining != "" { break @@ -208,10 +213,10 @@ func parsePath(input string) ([]segment, string, error) { return segments, rest, nil } -// Represents a valid url path segment. -type segment string +// Represents a valid url path segmentPattern. +type segmentPattern string -func parsePathSegment(input string) (segment, string, error) { +func parsePathSegmentPattern(input string) (segmentPattern, string, error) { if input == "" { return "", "", nil } @@ -219,7 +224,7 @@ func parsePathSegment(input string) (segment, string, error) { var i int for i = 0; i < len(input); i++ { c := input[i] - + // Check for percent-encoded characters (%XX) if c == '%' { if i+2 >= len(input) || !isHexDigit(input[i+1]) || !isHexDigit(input[i+2]) { @@ -228,14 +233,14 @@ func parsePathSegment(input string) (segment, string, error) { i += 2 continue } - + // Check for valid pchar characters if !isPChar(c) { break } } - return segment(input[:i]), input[i:], nil + return segmentPattern(input[:i]), input[i:], nil } // isUnreserved returns true if the character is unreserved per RFC 3986 @@ -286,8 +291,72 @@ func parseKey(rule string) (string, string, error) { return "", "", errors.New("expected key") } -func parseAllowRule(string) (Rule, error) { - return Rule{}, nil +func parseAllowRule(ruleStr string) (Rule, error) { + rule := Rule{ + Raw: ruleStr, + } + + rest := ruleStr + + for rest != "" { + // Parse the key + key, valueRest, err := parseKey(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse key: %v", err) + } + + // Parse the value based on the key type + switch key { + case "method": + token, remaining, err := parseMethodPattern(valueRest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse method: %v", err) + } + + // Initialize Methods map if needed + if rule.MethodPatterns == nil { + rule.MethodPatterns = make(map[methodPattern]struct{}) + } + rule.MethodPatterns[token] = struct{}{} + rest = remaining + + case "domain": + hostLabels, remaining, err := parseHostPattern(valueRest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse domain: %v", err) + } + + // Convert labels to strings in reverse order (TLD first) + rule.HostPattern = make([]labelPattern, len(hostLabels)) + for i, label := range hostLabels { + rule.HostPattern[len(hostLabels)-1-i] = label + } + rest = remaining + + case "path": + segments, remaining, err := parsePathPattern(valueRest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse path: %v", err) + } + + // Convert segments to strings + rule.PathPattern = make([]segmentPattern, len(segments)) + for i, segment := range segments { + rule.PathPattern[i] = segment + } + rest = remaining + + default: + return Rule{}, fmt.Errorf("unknown key: %s", key) + } + + // Skip whitespace or comma separators + for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == ',') { + rest = rest[1:] + } + } + + return rule, nil } // ParseAllowSpecs parses a slice of --allow specs into allow Rules. @@ -342,17 +411,85 @@ func (re *Engine) Evaluate(method, url string) Result { } } -// Matches checks if the rule matches the given method and URL using wildcard patterns -func (re *Engine) matches(r Rule, method, url string) bool { - // If the rule doesn't have any method filters, don't restrict the allowed methods - if r.Methods == nil { - return true +type protocol string + +func parseProtocol(input string) (protocol, string, error) { + if input == "" { + return "", "", errors.New("expected protocol, got empty string") } - // If the rule has method filters and the provided method is not one of them, block the request. - if _, methodIsAllowed := r.Methods[method]; !methodIsAllowed { - return false + // Look for "://" separator + if idx := strings.Index(input, "://"); idx > 0 { + protocolPart := input[:idx] + rest := input[idx+3:] + + // Validate protocol characters (scheme per RFC 3986) + // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) + if len(protocolPart) == 0 { + return "", "", errors.New("empty protocol") + } + + // First character must be alpha + if !((protocolPart[0] >= 'A' && protocolPart[0] <= 'Z') || + (protocolPart[0] >= 'a' && protocolPart[0] <= 'z')) { + return "", "", errors.New("protocol must start with a letter") + } + + // Rest can be alphanumeric, +, -, or . + for i := 1; i < len(protocolPart); i++ { + c := protocolPart[i] + if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.') { + return "", "", fmt.Errorf("invalid character in protocol: %c", c) + } + } + + return protocol(protocolPart), rest, nil } + // No protocol found + return "", input, nil +} + +type port uint16 + +func parsePort(input string) (port, string, error) { + if input == "" { + return 0, "", nil + } + + // Port must start with ':' + if input[0] != ':' { + return 0, input, nil + } + + // Find the end of the port number + i := 1 + for i < len(input) && input[i] >= '0' && input[i] <= '9' { + i++ + } + + // No digits found after ':' + if i == 1 { + return 0, "", errors.New("expected port number after ':'") + } + + portStr := input[1:i] + rest := input[i:] + + // Convert to uint16 (port range is 0-65535) + portNum := 0 + for _, digit := range portStr { + portNum = portNum*10 + int(digit-'0') + if portNum > 65535 { + return 0, "", errors.New("port number too large (max 65535)") + } + } + + return port(portNum), rest, nil +} + +// Matches checks if the rule matches the given method and URL using wildcard patterns +func (re *Engine) matches(r Rule, method, url string) bool { return true } diff --git a/rules/rules_test.go b/rules/rules_test.go index 22de729..3cbc59f 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -6,7 +6,7 @@ func TestParseHTTPToken(t *testing.T) { tests := []struct { name string input string - expectedToken httpToken + expectedToken methodPattern expectedRemain string expectError bool }{ @@ -105,7 +105,7 @@ func TestParseHTTPToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - token, remain, err := parseHTTPToken(tt.input) + token, remain, err := parseMethodPattern(tt.input) if tt.expectError { if err == nil { @@ -134,7 +134,7 @@ func TestParseHost(t *testing.T) { tests := []struct { name string input string - expectedHost []label + expectedHost []labelPattern expectedRest string expectError bool }{ @@ -148,56 +148,56 @@ func TestParseHost(t *testing.T) { { name: "simple domain", input: "google.com", - expectedHost: []label{label("google"), label("com")}, + expectedHost: []labelPattern{labelPattern("google"), labelPattern("com")}, expectedRest: "", expectError: false, }, { name: "subdomain", input: "api.google.com", - expectedHost: []label{label("api"), label("google"), label("com")}, + expectedHost: []labelPattern{labelPattern("api"), labelPattern("google"), labelPattern("com")}, expectedRest: "", expectError: false, }, { name: "single label", input: "localhost", - expectedHost: []label{label("localhost")}, + expectedHost: []labelPattern{labelPattern("localhost")}, expectedRest: "", expectError: false, }, { name: "domain with trailing content", input: "example.org/path", - expectedHost: []label{label("example"), label("org")}, + expectedHost: []labelPattern{labelPattern("example"), labelPattern("org")}, expectedRest: "/path", expectError: false, }, { name: "domain with port", input: "localhost:8080", - expectedHost: []label{label("localhost")}, + expectedHost: []labelPattern{labelPattern("localhost")}, expectedRest: ":8080", expectError: false, }, { name: "numeric labels", input: "192.168.1.1", - expectedHost: []label{label("192"), label("168"), label("1"), label("1")}, + expectedHost: []labelPattern{labelPattern("192"), labelPattern("168"), labelPattern("1"), labelPattern("1")}, expectedRest: "", expectError: false, }, { name: "hyphenated domain", input: "my-site.example-domain.co.uk", - expectedHost: []label{label("my-site"), label("example-domain"), label("co"), label("uk")}, + expectedHost: []labelPattern{labelPattern("my-site"), labelPattern("example-domain"), labelPattern("co"), labelPattern("uk")}, expectedRest: "", expectError: false, }, { name: "alphanumeric labels", input: "a1b2c3.test123.com", - expectedHost: []label{label("a1b2c3"), label("test123"), label("com")}, + expectedHost: []labelPattern{labelPattern("a1b2c3"), labelPattern("test123"), labelPattern("com")}, expectedRest: "", expectError: false, }, @@ -225,7 +225,7 @@ func TestParseHost(t *testing.T) { { name: "invalid character", input: "test@example.com", - expectedHost: []label{label("test")}, + expectedHost: []labelPattern{labelPattern("test")}, expectedRest: "@example.com", expectError: false, }, @@ -246,14 +246,14 @@ func TestParseHost(t *testing.T) { { name: "single character labels", input: "a.b.c", - expectedHost: []label{label("a"), label("b"), label("c")}, + expectedHost: []labelPattern{labelPattern("a"), labelPattern("b"), labelPattern("c")}, expectedRest: "", expectError: false, }, { name: "mixed case", input: "Example.COM", - expectedHost: []label{label("Example"), label("COM")}, + expectedHost: []labelPattern{labelPattern("Example"), labelPattern("COM")}, expectedRest: "", expectError: false, }, @@ -261,7 +261,7 @@ func TestParseHost(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hostResult, rest, err := parseHost(tt.input) + hostResult, rest, err := parseHostPattern(tt.input) if tt.expectError { if err == nil { @@ -297,7 +297,7 @@ func TestParseLabel(t *testing.T) { tests := []struct { name string input string - expectedLabel label + expectedLabel labelPattern expectedRest string expectError bool }{ @@ -403,7 +403,7 @@ func TestParseLabel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - labelResult, rest, err := parseLabel(tt.input) + labelResult, rest, err := parseLabelPattern(tt.input) if tt.expectError { if err == nil { @@ -432,7 +432,7 @@ func TestParsePathSegment(t *testing.T) { tests := []struct { name string input string - expectedSegment segment + expectedSegment segmentPattern expectedRest string expectError bool }{ @@ -559,7 +559,7 @@ func TestParsePathSegment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - segment, rest, err := parsePathSegment(tt.input) + segment, rest, err := parsePathSegmentPattern(tt.input) if tt.expectError { if err == nil { @@ -588,7 +588,7 @@ func TestParsePath(t *testing.T) { tests := []struct { name string input string - expectedSegments []segment + expectedSegments []segmentPattern expectedRest string expectError bool }{ @@ -602,56 +602,56 @@ func TestParsePath(t *testing.T) { { name: "single segment", input: "/api", - expectedSegments: []segment{"api"}, + expectedSegments: []segmentPattern{"api"}, expectedRest: "", expectError: false, }, { name: "multiple segments", input: "/api/v1/users", - expectedSegments: []segment{"api", "v1", "users"}, + expectedSegments: []segmentPattern{"api", "v1", "users"}, expectedRest: "", expectError: false, }, { name: "relative path", input: "api/users", - expectedSegments: []segment{"api", "users"}, + expectedSegments: []segmentPattern{"api", "users"}, expectedRest: "", expectError: false, }, { name: "path with trailing slash", input: "/api/users/", - expectedSegments: []segment{"api", "users"}, + expectedSegments: []segmentPattern{"api", "users"}, expectedRest: "", expectError: false, }, { name: "path with query string", input: "/api/users?limit=10", - expectedSegments: []segment{"api", "users"}, + expectedSegments: []segmentPattern{"api", "users"}, expectedRest: "?limit=10", expectError: false, }, { name: "path with fragment", input: "/docs/api#authentication", - expectedSegments: []segment{"docs", "api"}, + expectedSegments: []segmentPattern{"docs", "api"}, expectedRest: "#authentication", expectError: false, }, { name: "path with encoded segments", input: "/api/hello%20world/test", - expectedSegments: []segment{"api", "hello%20world", "test"}, + expectedSegments: []segmentPattern{"api", "hello%20world", "test"}, expectedRest: "", expectError: false, }, { name: "path with special chars", input: "/api/filter='test'&sort=name/results", - expectedSegments: []segment{"api", "filter='test'&sort=name", "results"}, + expectedSegments: []segmentPattern{"api", "filter='test'&sort=name", "results"}, expectedRest: "", expectError: false, }, @@ -665,56 +665,56 @@ func TestParsePath(t *testing.T) { { name: "empty segments", input: "/api//users", - expectedSegments: []segment{"api"}, + expectedSegments: []segmentPattern{"api"}, expectedRest: "/users", expectError: false, }, { name: "path with port-like segment", input: "/host:8080/status", - expectedSegments: []segment{"host:8080", "status"}, + expectedSegments: []segmentPattern{"host:8080", "status"}, expectedRest: "", expectError: false, }, { name: "path stops at space", input: "/api/test hello", - expectedSegments: []segment{"api", "test"}, + expectedSegments: []segmentPattern{"api", "test"}, expectedRest: " hello", expectError: false, }, { name: "path with hyphens and underscores", input: "/my-api/user_data/file-name.txt", - expectedSegments: []segment{"my-api", "user_data", "file-name.txt"}, + expectedSegments: []segmentPattern{"my-api", "user_data", "file-name.txt"}, expectedRest: "", expectError: false, }, { name: "path with tildes", input: "/api/~user/docs~backup", - expectedSegments: []segment{"api", "~user", "docs~backup"}, + expectedSegments: []segmentPattern{"api", "~user", "docs~backup"}, expectedRest: "", expectError: false, }, { name: "numeric segments", input: "/api/v2/users/12345", - expectedSegments: []segment{"api", "v2", "users", "12345"}, + expectedSegments: []segmentPattern{"api", "v2", "users", "12345"}, expectedRest: "", expectError: false, }, { name: "single character segments", input: "/a/b/c", - expectedSegments: []segment{"a", "b", "c"}, + expectedSegments: []segmentPattern{"a", "b", "c"}, expectedRest: "", expectError: false, }, { name: "path with at symbol", input: "/user@domain.com/profile", - expectedSegments: []segment{"user@domain.com", "profile"}, + expectedSegments: []segmentPattern{"user@domain.com", "profile"}, expectedRest: "", expectError: false, }, @@ -722,7 +722,7 @@ func TestParsePath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - segments, rest, err := parsePath(tt.input) + segments, rest, err := parsePathPattern(tt.input) if tt.expectError { if err == nil { @@ -753,3 +753,158 @@ func TestParsePath(t *testing.T) { }) } } + +func TestParseAllowRule(t *testing.T) { + tests := []struct { + name string + input string + expectedRule Rule + expectError bool + }{ + { + name: "empty string", + input: "", + expectedRule: Rule{ + Raw: "", + }, + expectError: false, + }, + { + name: "method only", + input: "method=GET", + expectedRule: Rule{ + Raw: "method=GET", + MethodPatterns: map[methodPattern]struct{}{methodPattern("GET"): {}}, + }, + expectError: false, + }, + { + name: "domain only", + input: "domain=google.com", + expectedRule: Rule{ + Raw: "domain=google.com", + HostPattern: []labelPattern{labelPattern("com"), labelPattern("google")}, + }, + expectError: false, + }, + { + name: "path only", + input: "path=/api/v1", + expectedRule: Rule{ + Raw: "path=/api/v1", + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("v1")}, + }, + expectError: false, + }, + { + name: "method and domain", + input: "method=POST domain=api.example.com", + expectedRule: Rule{ + Raw: "method=POST domain=api.example.com", + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("com"), labelPattern("example"), labelPattern("api")}, + }, + expectError: false, + }, + { + name: "all three keys", + input: "method=DELETE domain=test.com path=/resources/456", + expectedRule: Rule{ + Raw: "method=DELETE domain=test.com path=/resources/456", + MethodPatterns: map[methodPattern]struct{}{methodPattern("DELETE"): {}}, + HostPattern: []labelPattern{labelPattern("com"), labelPattern("test")}, + PathPattern: []segmentPattern{segmentPattern("resources"), segmentPattern("456")}, + }, + expectError: false, + }, + { + name: "invalid key", + input: "invalid=value", + expectedRule: Rule{}, + expectError: true, + }, + { + name: "missing value", + input: "method=", + expectedRule: Rule{}, + expectError: true, + }, + { + name: "invalid method", + input: "method=@invalid", + expectedRule: Rule{}, + expectError: true, + }, + { + name: "invalid domain", + input: "domain=-invalid.com", + expectedRule: Rule{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule, err := parseAllowRule(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Check Raw field + if rule.Raw != tt.expectedRule.Raw { + t.Errorf("expected Raw %q, got %q", tt.expectedRule.Raw, rule.Raw) + } + + // Check MethodPatterns + if tt.expectedRule.MethodPatterns == nil { + if rule.MethodPatterns != nil { + t.Errorf("expected MethodPatterns to be nil, got %v", rule.MethodPatterns) + } + } else { + if rule.MethodPatterns == nil { + t.Errorf("expected MethodPatterns %v, got nil", tt.expectedRule.MethodPatterns) + } else { + if len(rule.MethodPatterns) != len(tt.expectedRule.MethodPatterns) { + t.Errorf("expected %d methods, got %d", len(tt.expectedRule.MethodPatterns), len(rule.MethodPatterns)) + } + for method := range tt.expectedRule.MethodPatterns { + if _, exists := rule.MethodPatterns[method]; !exists { + t.Errorf("expected method %q not found", method) + } + } + } + } + + // Check HostPattern + if len(rule.HostPattern) != len(tt.expectedRule.HostPattern) { + t.Errorf("expected HostPattern length %d, got %d", len(tt.expectedRule.HostPattern), len(rule.HostPattern)) + } else { + for i, expectedLabel := range tt.expectedRule.HostPattern { + if rule.HostPattern[i] != expectedLabel { + t.Errorf("expected HostPattern[%d] %q, got %q", i, expectedLabel, rule.HostPattern[i]) + } + } + } + + // Check PathPattern + if len(rule.PathPattern) != len(tt.expectedRule.PathPattern) { + t.Errorf("expected PathPattern length %d, got %d", len(tt.expectedRule.PathPattern), len(rule.PathPattern)) + } else { + for i, expectedSegment := range tt.expectedRule.PathPattern { + if rule.PathPattern[i] != expectedSegment { + t.Errorf("expected PathPattern[%d] %q, got %q", i, expectedSegment, rule.PathPattern[i]) + } + } + } + }) + } +} From 62edfb2f98169b1f10419a37d43835016db08089 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 11:00:25 -0500 Subject: [PATCH 05/21] wildcard tests --- rules/rules.go | 106 +++++++------------------------ rules/rules_test.go | 150 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 83 deletions(-) diff --git a/rules/rules.go b/rules/rules.go index 5be307d..a3cebb0 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -34,11 +34,6 @@ type Rule struct { type methodPattern string -// An asterisk is treated as matching any method -func (t methodPattern) matches(input string) bool { - return t == "*" || string(t) == input -} - // Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by // various systems. // https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6 @@ -126,11 +121,21 @@ func parseHostPattern(input string) (host []labelPattern, rest string, err error // Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`. type labelPattern string +// An `asterisk` is treated as matching anything +func (lp labelPattern) matches(input string) bool { + return lp == "*" || string(lp) == input +} + func parseLabelPattern(rest string) (labelPattern, string, error) { if rest == "" { return "", "", errors.New("expected label, got empty string") } + // If the label is simply an asterisk, good to go. + if rest[0] == '*' { + return "*", rest[1:], nil + } + // First try to get a valid leading char. Leading char in a label cannot be a hyphen. if !isValidLabelChar(rest[0]) || rest[0] == '-' { return "", "", fmt.Errorf("could not pull label from front of string: %s", rest) @@ -216,11 +221,24 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) { // Represents a valid url path segmentPattern. type segmentPattern string +// An `*` is treated as matching anything +func (sp segmentPattern) matches(input string) bool { + return sp == "*" || string(sp) == input +} + func parsePathSegmentPattern(input string) (segmentPattern, string, error) { if input == "" { return "", "", nil } + if len(input) > 0 && input[0] == '*' { + if len(input) > 1 && input[1] != '/' { + return "", "", fmt.Errorf("path segment wildcards must be for the entire segment, got: %s", input) + } + + return segmentPattern(input[0]), input[1:], nil + } + var i int for i = 0; i < len(input); i++ { c := input[i] @@ -411,84 +429,6 @@ func (re *Engine) Evaluate(method, url string) Result { } } -type protocol string - -func parseProtocol(input string) (protocol, string, error) { - if input == "" { - return "", "", errors.New("expected protocol, got empty string") - } - - // Look for "://" separator - if idx := strings.Index(input, "://"); idx > 0 { - protocolPart := input[:idx] - rest := input[idx+3:] - - // Validate protocol characters (scheme per RFC 3986) - // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) - if len(protocolPart) == 0 { - return "", "", errors.New("empty protocol") - } - - // First character must be alpha - if !((protocolPart[0] >= 'A' && protocolPart[0] <= 'Z') || - (protocolPart[0] >= 'a' && protocolPart[0] <= 'z')) { - return "", "", errors.New("protocol must start with a letter") - } - - // Rest can be alphanumeric, +, -, or . - for i := 1; i < len(protocolPart); i++ { - c := protocolPart[i] - if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || - (c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.') { - return "", "", fmt.Errorf("invalid character in protocol: %c", c) - } - } - - return protocol(protocolPart), rest, nil - } - - // No protocol found - return "", input, nil -} - -type port uint16 - -func parsePort(input string) (port, string, error) { - if input == "" { - return 0, "", nil - } - - // Port must start with ':' - if input[0] != ':' { - return 0, input, nil - } - - // Find the end of the port number - i := 1 - for i < len(input) && input[i] >= '0' && input[i] <= '9' { - i++ - } - - // No digits found after ':' - if i == 1 { - return 0, "", errors.New("expected port number after ':'") - } - - portStr := input[1:i] - rest := input[i:] - - // Convert to uint16 (port range is 0-65535) - portNum := 0 - for _, digit := range portStr { - portNum = portNum*10 + int(digit-'0') - if portNum > 65535 { - return 0, "", errors.New("port number too large (max 65535)") - } - } - - return port(portNum), rest, nil -} - // Matches checks if the rule matches the given method and URL using wildcard patterns func (re *Engine) matches(r Rule, method, url string) bool { return true diff --git a/rules/rules_test.go b/rules/rules_test.go index 3cbc59f..8e2f067 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -257,6 +257,34 @@ func TestParseHost(t *testing.T) { expectedRest: "", expectError: false, }, + { + name: "wildcard subdomain", + input: "*.example.com", + expectedHost: []labelPattern{labelPattern("*"), labelPattern("example"), labelPattern("com")}, + expectedRest: "", + expectError: false, + }, + { + name: "wildcard domain", + input: "api.*", + expectedHost: []labelPattern{labelPattern("api"), labelPattern("*")}, + expectedRest: "", + expectError: false, + }, + { + name: "multiple wildcards", + input: "*.*.com", + expectedHost: []labelPattern{labelPattern("*"), labelPattern("*"), labelPattern("com")}, + expectedRest: "", + expectError: false, + }, + { + name: "wildcard with trailing content", + input: "*.example.com/path", + expectedHost: []labelPattern{labelPattern("*"), labelPattern("example"), labelPattern("com")}, + expectedRest: "/path", + expectError: false, + }, } for _, tt := range tests { @@ -399,6 +427,27 @@ func TestParseLabel(t *testing.T) { expectedRest: "/path", expectError: false, }, + { + name: "wildcard label", + input: "*", + expectedLabel: "*", + expectedRest: "", + expectError: false, + }, + { + name: "wildcard with dot", + input: "*.com", + expectedLabel: "*", + expectedRest: ".com", + expectError: false, + }, + { + name: "wildcard with trailing content", + input: "*/path", + expectedLabel: "*", + expectedRest: "/path", + expectError: false, + }, } for _, tt := range tests { @@ -555,6 +604,34 @@ func TestParsePathSegment(t *testing.T) { expectedRest: "[bracket]", expectError: false, }, + { + name: "wildcard segment", + input: "*", + expectedSegment: "*", + expectedRest: "", + expectError: false, + }, + { + name: "wildcard with slash", + input: "*/users", + expectedSegment: "*", + expectedRest: "/users", + expectError: false, + }, + { + name: "wildcard at end with slash", + input: "*", + expectedSegment: "*", + expectedRest: "", + expectError: false, + }, + { + name: "invalid partial wildcard", + input: "*abc", + expectedSegment: "", + expectedRest: "", + expectError: true, + }, } for _, tt := range tests { @@ -718,6 +795,41 @@ func TestParsePath(t *testing.T) { expectedRest: "", expectError: false, }, + { + name: "path with wildcard segment", + input: "/api/*/users", + expectedSegments: []segmentPattern{"api", "*", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with multiple wildcards", + input: "/*/v1/*/profile", + expectedSegments: []segmentPattern{"*", "v1", "*", "profile"}, + expectedRest: "", + expectError: false, + }, + { + name: "path ending with wildcard", + input: "/api/users/*", + expectedSegments: []segmentPattern{"api", "users", "*"}, + expectedRest: "", + expectError: false, + }, + { + name: "path starting with wildcard", + input: "/*/users", + expectedSegments: []segmentPattern{"*", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with wildcard and query", + input: "/api/*/users?limit=10", + expectedSegments: []segmentPattern{"api", "*", "users"}, + expectedRest: "?limit=10", + expectError: false, + }, } for _, tt := range tests { @@ -817,6 +929,44 @@ func TestParseAllowRule(t *testing.T) { }, expectError: false, }, + { + name: "wildcard domain", + input: "domain=*.example.com", + expectedRule: Rule{ + Raw: "domain=*.example.com", + HostPattern: []labelPattern{labelPattern("com"), labelPattern("example"), labelPattern("*")}, + }, + expectError: false, + }, + { + name: "wildcard path", + input: "path=/api/*/users", + expectedRule: Rule{ + Raw: "path=/api/*/users", + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("*"), segmentPattern("users")}, + }, + expectError: false, + }, + { + name: "wildcard method", + input: "method=*", + expectedRule: Rule{ + Raw: "method=*", + MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + }, + expectError: false, + }, + { + name: "all wildcards", + input: "method=* domain=*.* path=/*/", + expectedRule: Rule{ + Raw: "method=* domain=*.* path=/*/", + MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, + PathPattern: []segmentPattern{segmentPattern("*")}, + }, + expectError: false, + }, { name: "invalid key", input: "invalid=value", From 3de487561719ef8714157193f05cb27697d8c8c8 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 12:05:03 -0500 Subject: [PATCH 06/21] implement top level matching --- rules/rules.go | 77 ++++++++++--- rules/rules_test.go | 270 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 332 insertions(+), 15 deletions(-) diff --git a/rules/rules.go b/rules/rules.go index a3cebb0..46be450 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log/slog" + neturl "net/url" "strings" ) @@ -17,11 +18,13 @@ type Rule struct { // The path segments of the url // nil means all paths allowed // a path segment of `*` acts as a wild card. + // sub paths automatically match PathPattern []segmentPattern // The labels of the host, i.e. ["google", "com"] // nil means all hosts allowed // A label of `*` acts as a wild card. + // subdomains automatically match HostPattern []labelPattern // The allowed http methods @@ -121,17 +124,12 @@ func parseHostPattern(input string) (host []labelPattern, rest string, err error // Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`. type labelPattern string -// An `asterisk` is treated as matching anything -func (lp labelPattern) matches(input string) bool { - return lp == "*" || string(lp) == input -} - func parseLabelPattern(rest string) (labelPattern, string, error) { if rest == "" { return "", "", errors.New("expected label, got empty string") } - // If the label is simply an asterisk, good to go. + // If the label is simply an asterisk, good to go. if rest[0] == '*' { return "*", rest[1:], nil } @@ -221,11 +219,6 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) { // Represents a valid url path segmentPattern. type segmentPattern string -// An `*` is treated as matching anything -func (sp segmentPattern) matches(input string) bool { - return sp == "*" || string(sp) == input -} - func parsePathSegmentPattern(input string) (segmentPattern, string, error) { if input == "" { return "", "", nil @@ -359,9 +352,7 @@ func parseAllowRule(ruleStr string) (Rule, error) { // Convert segments to strings rule.PathPattern = make([]segmentPattern, len(segments)) - for i, segment := range segments { - rule.PathPattern[i] = segment - } + copy(rule.PathPattern, segments) rest = remaining default: @@ -431,5 +422,63 @@ func (re *Engine) Evaluate(method, url string) Result { // Matches checks if the rule matches the given method and URL using wildcard patterns func (re *Engine) matches(r Rule, method, url string) bool { + + // Check method patterns if they exist + if r.MethodPatterns != nil { + methodMatches := false + for mp := range r.MethodPatterns { + if string(mp) == method || string(mp) == "*" { + methodMatches = true + break + } + } + if !methodMatches { + return false + } + } + + parsedUrl, err := neturl.Parse(url) + if err != nil { + return false + } + + if r.HostPattern != nil { + // For a host pattern to match, every label has to match or be an `*`. + // Subdomains also match automatically, meaning if the pattern is "wobble.com" + // and the real is "wibble.wobble.com", it should match. We check this by comparing + // from the end since patterns are stored in reverse order (TLD first). + + labels := strings.Split(parsedUrl.Hostname(), ".") + + // If the host pattern is longer than the actual host, it's definitely not a match + if len(r.HostPattern) > len(labels) { + return false + } + + // Compare from the end of both arrays since pattern is stored in reverse order + for i, lp := range r.HostPattern { + labelIndex := len(labels) - 1 - i + if string(lp) != labels[labelIndex] && lp != "*" { + return false + } + } + } + + if r.PathPattern != nil { + segments := strings.Split(parsedUrl.Path, "/") + + // If the path pattern is longer than the actual path, definitely not a match + if len(r.PathPattern) > len(segments) { + return false + } + + // Each segment in the pattern must be either as asterisk or match the actual path segment + for i, sp := range r.PathPattern { + if string(sp) != segments[i] && sp != "*" { + return false + } + } + } + return true } diff --git a/rules/rules_test.go b/rules/rules_test.go index 8e2f067..ac1a302 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -1,6 +1,9 @@ package rules -import "testing" +import ( + "log/slog" + "testing" +) func TestParseHTTPToken(t *testing.T) { tests := []struct { @@ -1058,3 +1061,268 @@ func TestParseAllowRule(t *testing.T) { }) } } + +func TestEngineMatches(t *testing.T) { + logger := &slog.Logger{} + engine := NewRuleEngine(nil, logger) + + tests := []struct { + name string + rule Rule + method string + url string + expected bool + }{ + // Method pattern tests + { + name: "method matches exact", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("GET"): {}}, + }, + method: "GET", + url: "https://example.com/api", + expected: true, + }, + { + name: "method does not match", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + }, + method: "GET", + url: "https://example.com/api", + expected: false, + }, + { + name: "method wildcard matches any", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + }, + method: "PUT", + url: "https://example.com/api", + expected: true, + }, + { + name: "no method pattern allows all methods", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + }, + method: "DELETE", + url: "https://example.com/api", + expected: true, + }, + + // Host pattern tests + { + name: "host matches exact", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + }, + method: "GET", + url: "https://example.com/api", + expected: true, + }, + { + name: "host does not match", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("org"), labelPattern("example")}, + }, + method: "GET", + url: "https://example.com/api", + expected: false, + }, + { + name: "subdomain matches", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: true, + }, + { + name: "host pattern too long", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("com"), labelPattern("example"), labelPattern("api"), labelPattern("v1")}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: false, + }, + { + name: "host wildcard matches", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("com"), labelPattern("*")}, + }, + method: "GET", + url: "https://test.com/api", + expected: true, + }, + { + name: "multiple host wildcards", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: true, + }, + + // Path pattern tests + { + name: "path matches exact", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("users")}, + }, + method: "GET", + url: "https://example.com/api/users", + expected: true, + }, + { + name: "path does not match", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("posts")}, + }, + method: "GET", + url: "https://example.com/api/users", + expected: false, + }, + { + name: "subpath matches", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api")}, + }, + method: "GET", + url: "https://example.com/api/users/123", + expected: true, + }, + { + name: "path pattern too long", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("v1"), segmentPattern("users"), segmentPattern("profile")}, + }, + method: "GET", + url: "https://example.com/api/v1/users", + expected: false, + }, + { + name: "path wildcard matches", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("*"), segmentPattern("profile")}, + }, + method: "GET", + url: "https://example.com/api/users/profile", + expected: true, + }, + { + name: "multiple path wildcards", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("*"), segmentPattern("*")}, + }, + method: "GET", + url: "https://example.com/api/users/123", + expected: true, + }, + + // Combined pattern tests + { + name: "all patterns match", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("com"), labelPattern("api")}, + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, + }, + method: "POST", + url: "https://api.com/users", + expected: true, + }, + { + name: "method fails combined test", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("com"), labelPattern("api")}, + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, + }, + method: "GET", + url: "https://api.com/users", + expected: false, + }, + { + name: "host fails combined test", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("org"), labelPattern("api")}, + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, + }, + method: "POST", + url: "https://api.com/users", + expected: false, + }, + { + name: "path fails combined test", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("com"), labelPattern("api")}, + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("posts")}, + }, + method: "POST", + url: "https://api.com/users", + expected: false, + }, + { + name: "all wildcards match", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, + PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("*"), segmentPattern("*")}, + }, + method: "PATCH", + url: "https://test.example.com/api/users/123", + expected: true, + }, + + // Edge cases + { + name: "empty rule matches everything", + rule: Rule{}, + method: "GET", + url: "https://example.com/api/users", + expected: true, + }, + { + name: "invalid URL", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + }, + method: "GET", + url: "not-a-valid-url", + expected: false, + }, + { + name: "root path", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern("")}, + }, + method: "GET", + url: "https://example.com/", + expected: true, + }, + { + name: "localhost host", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("localhost")}, + }, + method: "GET", + url: "http://localhost:8080/api", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := engine.matches(tt.rule, tt.method, tt.url) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} From 72502e543591b88ccf5560a520a474e09bbddfbf Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 12:21:19 -0500 Subject: [PATCH 07/21] don't reverse host pattern while parsing --- rules/rules.go | 20 ++++++++------------ rules/rules_test.go | 30 +++++++++++++++--------------- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/rules/rules.go b/rules/rules.go index 46be450..55963bb 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -337,11 +337,8 @@ func parseAllowRule(ruleStr string) (Rule, error) { return Rule{}, fmt.Errorf("failed to parse domain: %v", err) } - // Convert labels to strings in reverse order (TLD first) - rule.HostPattern = make([]labelPattern, len(hostLabels)) - for i, label := range hostLabels { - rule.HostPattern[len(hostLabels)-1-i] = label - } + // Convert labels to strings + rule.HostPattern = append(rule.HostPattern, hostLabels...) rest = remaining case "path": @@ -351,8 +348,7 @@ func parseAllowRule(ruleStr string) (Rule, error) { } // Convert segments to strings - rule.PathPattern = make([]segmentPattern, len(segments)) - copy(rule.PathPattern, segments) + rule.PathPattern = append(rule.PathPattern, segments...) rest = remaining default: @@ -444,9 +440,9 @@ func (re *Engine) matches(r Rule, method, url string) bool { if r.HostPattern != nil { // For a host pattern to match, every label has to match or be an `*`. - // Subdomains also match automatically, meaning if the pattern is "wobble.com" - // and the real is "wibble.wobble.com", it should match. We check this by comparing - // from the end since patterns are stored in reverse order (TLD first). + // Subdomains also match automatically, meaning if the pattern is "example.com" + // and the real is "api.example.com", it should match. We check this by comparing + // from the end of the actual hostname with the pattern (which is in normal order). labels := strings.Split(parsedUrl.Hostname(), ".") @@ -455,9 +451,9 @@ func (re *Engine) matches(r Rule, method, url string) bool { return false } - // Compare from the end of both arrays since pattern is stored in reverse order + // Compare pattern with the end of labels (allowing subdomains) for i, lp := range r.HostPattern { - labelIndex := len(labels) - 1 - i + labelIndex := len(labels) - len(r.HostPattern) + i if string(lp) != labels[labelIndex] && lp != "*" { return false } diff --git a/rules/rules_test.go b/rules/rules_test.go index ac1a302..7680a57 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -898,7 +898,7 @@ func TestParseAllowRule(t *testing.T) { input: "domain=google.com", expectedRule: Rule{ Raw: "domain=google.com", - HostPattern: []labelPattern{labelPattern("com"), labelPattern("google")}, + HostPattern: []labelPattern{labelPattern("google"), labelPattern("com")}, }, expectError: false, }, @@ -917,7 +917,7 @@ func TestParseAllowRule(t *testing.T) { expectedRule: Rule{ Raw: "method=POST domain=api.example.com", MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("com"), labelPattern("example"), labelPattern("api")}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("example"), labelPattern("com")}, }, expectError: false, }, @@ -927,7 +927,7 @@ func TestParseAllowRule(t *testing.T) { expectedRule: Rule{ Raw: "method=DELETE domain=test.com path=/resources/456", MethodPatterns: map[methodPattern]struct{}{methodPattern("DELETE"): {}}, - HostPattern: []labelPattern{labelPattern("com"), labelPattern("test")}, + HostPattern: []labelPattern{labelPattern("test"), labelPattern("com")}, PathPattern: []segmentPattern{segmentPattern("resources"), segmentPattern("456")}, }, expectError: false, @@ -937,7 +937,7 @@ func TestParseAllowRule(t *testing.T) { input: "domain=*.example.com", expectedRule: Rule{ Raw: "domain=*.example.com", - HostPattern: []labelPattern{labelPattern("com"), labelPattern("example"), labelPattern("*")}, + HostPattern: []labelPattern{labelPattern("*"), labelPattern("example"), labelPattern("com")}, }, expectError: false, }, @@ -1104,7 +1104,7 @@ func TestEngineMatches(t *testing.T) { { name: "no method pattern allows all methods", rule: Rule{ - HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, }, method: "DELETE", url: "https://example.com/api", @@ -1115,7 +1115,7 @@ func TestEngineMatches(t *testing.T) { { name: "host matches exact", rule: Rule{ - HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, }, method: "GET", url: "https://example.com/api", @@ -1124,7 +1124,7 @@ func TestEngineMatches(t *testing.T) { { name: "host does not match", rule: Rule{ - HostPattern: []labelPattern{labelPattern("org"), labelPattern("example")}, + HostPattern: []labelPattern{labelPattern("example"), labelPattern("org")}, }, method: "GET", url: "https://example.com/api", @@ -1133,7 +1133,7 @@ func TestEngineMatches(t *testing.T) { { name: "subdomain matches", rule: Rule{ - HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, }, method: "GET", url: "https://api.example.com/users", @@ -1142,7 +1142,7 @@ func TestEngineMatches(t *testing.T) { { name: "host pattern too long", rule: Rule{ - HostPattern: []labelPattern{labelPattern("com"), labelPattern("example"), labelPattern("api"), labelPattern("v1")}, + HostPattern: []labelPattern{labelPattern("v1"), labelPattern("api"), labelPattern("example"), labelPattern("com")}, }, method: "GET", url: "https://api.example.com/users", @@ -1151,7 +1151,7 @@ func TestEngineMatches(t *testing.T) { { name: "host wildcard matches", rule: Rule{ - HostPattern: []labelPattern{labelPattern("com"), labelPattern("*")}, + HostPattern: []labelPattern{labelPattern("*"), labelPattern("com")}, }, method: "GET", url: "https://test.com/api", @@ -1228,7 +1228,7 @@ func TestEngineMatches(t *testing.T) { name: "all patterns match", rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("com"), labelPattern("api")}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, }, method: "POST", @@ -1239,7 +1239,7 @@ func TestEngineMatches(t *testing.T) { name: "method fails combined test", rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("com"), labelPattern("api")}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, }, method: "GET", @@ -1250,7 +1250,7 @@ func TestEngineMatches(t *testing.T) { name: "host fails combined test", rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("org"), labelPattern("api")}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("org")}, PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, }, method: "POST", @@ -1261,7 +1261,7 @@ func TestEngineMatches(t *testing.T) { name: "path fails combined test", rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("com"), labelPattern("api")}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("posts")}, }, method: "POST", @@ -1291,7 +1291,7 @@ func TestEngineMatches(t *testing.T) { { name: "invalid URL", rule: Rule{ - HostPattern: []labelPattern{labelPattern("com"), labelPattern("example")}, + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, }, method: "GET", url: "not-a-valid-url", From 8e999bd57ac6f65c9db6d572cb31da6c0b66e871 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 12:29:23 -0500 Subject: [PATCH 08/21] use the logger --- jail/jail.go | 2 +- jail/linux_stub.go | 2 +- jail/macos_stub.go | 2 +- rules/rules.go | 7 +++++++ rules/rules_test.go | 6 +++--- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/jail/jail.go b/jail/jail.go index b59bf2d..2c2f217 100644 --- a/jail/jail.go +++ b/jail/jail.go @@ -34,4 +34,4 @@ func DefaultOS(config Config) (Jailer, error) { default: return nil, fmt.Errorf("unsupported operating system: %s", runtime.GOOS) } -} \ No newline at end of file +} diff --git a/jail/linux_stub.go b/jail/linux_stub.go index 19d32dc..fe8835e 100644 --- a/jail/linux_stub.go +++ b/jail/linux_stub.go @@ -9,4 +9,4 @@ import ( // NewLinuxJail is not available on non-Linux platforms func NewLinuxJail(_ Config) (Jailer, error) { return nil, fmt.Errorf("linux jail not supported on this platform") -} \ No newline at end of file +} diff --git a/jail/macos_stub.go b/jail/macos_stub.go index 89f86a0..656cdc2 100644 --- a/jail/macos_stub.go +++ b/jail/macos_stub.go @@ -7,4 +7,4 @@ import "fmt" // NewMacOSJail is not available on non-macOS platforms func NewMacOSJail(_ Config) (Jailer, error) { return nil, fmt.Errorf("macOS jail not supported on this platform") -} \ No newline at end of file +} diff --git a/rules/rules.go b/rules/rules.go index 55963bb..6d36dfb 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -429,12 +429,14 @@ func (re *Engine) matches(r Rule, method, url string) bool { } } if !methodMatches { + re.logger.Info("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url) return false } } parsedUrl, err := neturl.Parse(url) if err != nil { + re.logger.Info("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err) return false } @@ -448,6 +450,7 @@ func (re *Engine) matches(r Rule, method, url string) bool { // If the host pattern is longer than the actual host, it's definitely not a match if len(r.HostPattern) > len(labels) { + re.logger.Info("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels)) return false } @@ -455,6 +458,7 @@ func (re *Engine) matches(r Rule, method, url string) bool { for i, lp := range r.HostPattern { labelIndex := len(labels) - len(r.HostPattern) + i if string(lp) != labels[labelIndex] && lp != "*" { + re.logger.Info("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex]) return false } } @@ -465,16 +469,19 @@ func (re *Engine) matches(r Rule, method, url string) bool { // If the path pattern is longer than the actual path, definitely not a match if len(r.PathPattern) > len(segments) { + re.logger.Info("rule does not match", "reason", "path pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.PathPattern), "path_segments", len(segments)) return false } // Each segment in the pattern must be either as asterisk or match the actual path segment for i, sp := range r.PathPattern { if string(sp) != segments[i] && sp != "*" { + re.logger.Info("rule does not match", "reason", "path pattern segment mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(sp), "actual", segments[i]) return false } } } + re.logger.Info("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url) return true } diff --git a/rules/rules_test.go b/rules/rules_test.go index 7680a57..ad215d2 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -1063,7 +1063,7 @@ func TestParseAllowRule(t *testing.T) { } func TestEngineMatches(t *testing.T) { - logger := &slog.Logger{} + logger := slog.Default() engine := NewRuleEngine(nil, logger) tests := []struct { @@ -1282,8 +1282,8 @@ func TestEngineMatches(t *testing.T) { // Edge cases { - name: "empty rule matches everything", - rule: Rule{}, + name: "empty rule matches everything", + rule: Rule{}, method: "GET", url: "https://example.com/api/users", expected: true, From 9e92f3c7d5547010af0b7fd9b392e7905d6fa9aa Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 12:36:50 -0500 Subject: [PATCH 09/21] update proxy server tests to match new syntax --- proxy/proxy_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index fe61391..8b14f02 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -35,7 +35,7 @@ func TestProxyServerBasicHTTP(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"*"}) + testRules, err := rules.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } @@ -116,7 +116,7 @@ func TestProxyServerBasicHTTPS(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"*"}) + testRules, err := rules.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } @@ -210,7 +210,7 @@ func TestProxyServerCONNECT(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"*"}) + testRules, err := rules.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } From e400285121a4a280b9fd058ee011f554aebaa736 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 13:08:49 -0500 Subject: [PATCH 10/21] remove trailing * support for hostname --- README.md | 30 +++--- rules/rules.go | 53 ++++++++--- rules/rules_test.go | 218 +++++++++++++++++++++++++++++++++++++++----- 3 files changed, 250 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index 418dc25..1910a5c 100644 --- a/README.md +++ b/README.md @@ -25,12 +25,12 @@ curl -fsSL https://raw.githubusercontent.com/coder/boundary/main/install.sh | ba ```bash # Allow only requests to github.com -boundary --allow "github.com" -- curl https://github.com +boundary --allow "domain=github.com" -- curl https://github.com # Allow full access to GitHub issues API, but only GET/HEAD elsewhere on GitHub boundary \ - --allow "github.com/api/issues/*" \ - --allow "GET,HEAD github.com" \ + --allow "domain=github.com path=/api/issues/*" \ + --allow "method=GET,HEAD domain=github.com" \ -- npm install # Default deny-all: everything is blocked unless explicitly allowed @@ -41,16 +41,20 @@ boundary -- curl https://example.com ### Format ```text ---allow "pattern" # All HTTP methods allowed ---allow "METHOD[,METHOD] pattern" # Specific methods only +--allow "key=value [key=value ...]" ``` +**Keys:** +- `method` - HTTP method(s), comma-separated (GET, POST, etc.) +- `domain` - Domain/hostname pattern +- `path` - URL path pattern + ### Examples ```bash -boundary --allow "github.com" -- git pull -boundary --allow "*.github.com" -- npm install # GitHub subdomains -boundary --allow "api.*" -- ./app # Any API domain -boundary --allow "GET,HEAD api.github.com" -- curl https://api.github.com +boundary --allow "domain=github.com" -- git pull +boundary --allow "domain=*.github.com" -- npm install # GitHub subdomains +boundary --allow "method=GET,HEAD domain=api.github.com" -- curl https://api.github.com +boundary --allow "method=POST domain=api.example.com path=/users" -- ./app ``` Wildcards: `*` matches any characters. All traffic is denied unless explicitly allowed. @@ -58,8 +62,8 @@ Wildcards: `*` matches any characters. All traffic is denied unless explicitly a ## Logging ```bash -boundary --log-level info --allow "*" -- npm install # Show all requests -boundary --log-level debug --allow "github.com" -- git pull # Debug info +boundary --log-level info --allow "method=*" -- npm install # Show all requests +boundary --log-level debug --allow "domain=github.com" -- git pull # Debug info ``` **Log Levels:** `error`, `warn` (default), `info`, `debug` @@ -70,10 +74,10 @@ When you can't or don't want to run with sudo privileges, use `--unprivileged`: ```bash # Run without network isolation (uses HTTP_PROXY/HTTPS_PROXY environment variables) -boundary --unprivileged --allow "github.com" -- npm install +boundary --unprivileged --allow "domain=github.com" -- npm install # Useful in containers or restricted environments -boundary --unprivileged --allow "*.npmjs.org" --allow "registry.npmjs.org" -- npm install +boundary --unprivileged --allow "domain=*.npmjs.org" --allow "domain=registry.npmjs.org" -- npm install ``` **Unprivileged Mode:** diff --git a/rules/rules.go b/rules/rules.go index 6d36dfb..cc5beec 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -118,6 +118,11 @@ func parseHostPattern(input string) (host []labelPattern, rest string, err error host = append(host, label) } + // Validate: host patterns cannot end with asterisk + if len(host) > 0 && host[len(host)-1] == "*" { + return nil, "", errors.New("host patterns cannot end with asterisk") + } + return host, rest, nil } @@ -319,17 +324,31 @@ func parseAllowRule(ruleStr string) (Rule, error) { // Parse the value based on the key type switch key { case "method": - token, remaining, err := parseMethodPattern(valueRest) - if err != nil { - return Rule{}, fmt.Errorf("failed to parse method: %v", err) - } + // Handle comma-separated methods + methodsRest := valueRest // Initialize Methods map if needed if rule.MethodPatterns == nil { rule.MethodPatterns = make(map[methodPattern]struct{}) } - rule.MethodPatterns[token] = struct{}{} - rest = remaining + + for { + token, remaining, err := parseMethodPattern(methodsRest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse method: %v", err) + } + + rule.MethodPatterns[token] = struct{}{} + + // Check if there's a comma for more methods + if remaining != "" && remaining[0] == ',' { + methodsRest = remaining[1:] // Skip the comma + continue + } + + rest = remaining + break + } case "domain": hostLabels, remaining, err := parseHostPattern(valueRest) @@ -429,14 +448,14 @@ func (re *Engine) matches(r Rule, method, url string) bool { } } if !methodMatches { - re.logger.Info("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url) + re.logger.Debug("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url) return false } } parsedUrl, err := neturl.Parse(url) if err != nil { - re.logger.Info("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err) + re.logger.Debug("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err) return false } @@ -450,15 +469,16 @@ func (re *Engine) matches(r Rule, method, url string) bool { // If the host pattern is longer than the actual host, it's definitely not a match if len(r.HostPattern) > len(labels) { - re.logger.Info("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels)) + re.logger.Debug("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels)) return false } - // Compare pattern with the end of labels (allowing subdomains) + // Since host patterns cannot end with asterisk, we only need to handle: + // "example.com" or "*.example.com" - match from the end (allowing subdomains) for i, lp := range r.HostPattern { labelIndex := len(labels) - len(r.HostPattern) + i if string(lp) != labels[labelIndex] && lp != "*" { - re.logger.Info("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex]) + re.logger.Debug("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex]) return false } } @@ -467,21 +487,26 @@ func (re *Engine) matches(r Rule, method, url string) bool { if r.PathPattern != nil { segments := strings.Split(parsedUrl.Path, "/") + // Skip the first empty segment if the path starts with "/" + if len(segments) > 0 && segments[0] == "" { + segments = segments[1:] + } + // If the path pattern is longer than the actual path, definitely not a match if len(r.PathPattern) > len(segments) { - re.logger.Info("rule does not match", "reason", "path pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.PathPattern), "path_segments", len(segments)) + re.logger.Debug("rule does not match", "reason", "path pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.PathPattern), "path_segments", len(segments)) return false } // Each segment in the pattern must be either as asterisk or match the actual path segment for i, sp := range r.PathPattern { if string(sp) != segments[i] && sp != "*" { - re.logger.Info("rule does not match", "reason", "path pattern segment mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(sp), "actual", segments[i]) + re.logger.Debug("rule does not match", "reason", "path pattern segment mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(sp), "actual", segments[i]) return false } } } - re.logger.Info("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url) + re.logger.Debug("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url) return true } diff --git a/rules/rules_test.go b/rules/rules_test.go index ad215d2..dc5bb70 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -1,6 +1,7 @@ package rules import ( + "fmt" "log/slog" "testing" ) @@ -268,11 +269,11 @@ func TestParseHost(t *testing.T) { expectError: false, }, { - name: "wildcard domain", + name: "wildcard domain - should error", input: "api.*", - expectedHost: []labelPattern{labelPattern("api"), labelPattern("*")}, + expectedHost: nil, expectedRest: "", - expectError: false, + expectError: true, }, { name: "multiple wildcards", @@ -288,6 +289,13 @@ func TestParseHost(t *testing.T) { expectedRest: "/path", expectError: false, }, + { + name: "host pattern ending with asterisk - rejected", + input: "api.*", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, } for _, tt := range tests { @@ -960,15 +968,10 @@ func TestParseAllowRule(t *testing.T) { expectError: false, }, { - name: "all wildcards", - input: "method=* domain=*.* path=/*/", - expectedRule: Rule{ - Raw: "method=* domain=*.* path=/*/", - MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, - HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, - PathPattern: []segmentPattern{segmentPattern("*")}, - }, - expectError: false, + name: "all wildcards - domain ending with asterisk should error", + input: "method=* domain=*.* path=/*/", + expectedRule: Rule{}, + expectError: true, }, { name: "invalid key", @@ -1171,7 +1174,7 @@ func TestEngineMatches(t *testing.T) { { name: "path matches exact", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("users")}, + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("users")}, }, method: "GET", url: "https://example.com/api/users", @@ -1180,7 +1183,7 @@ func TestEngineMatches(t *testing.T) { { name: "path does not match", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("posts")}, + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("posts")}, }, method: "GET", url: "https://example.com/api/users", @@ -1189,7 +1192,7 @@ func TestEngineMatches(t *testing.T) { { name: "subpath matches", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api")}, + PathPattern: []segmentPattern{segmentPattern("api")}, }, method: "GET", url: "https://example.com/api/users/123", @@ -1198,7 +1201,7 @@ func TestEngineMatches(t *testing.T) { { name: "path pattern too long", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("v1"), segmentPattern("users"), segmentPattern("profile")}, + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("v1"), segmentPattern("users"), segmentPattern("profile")}, }, method: "GET", url: "https://example.com/api/v1/users", @@ -1207,7 +1210,7 @@ func TestEngineMatches(t *testing.T) { { name: "path wildcard matches", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("api"), segmentPattern("*"), segmentPattern("profile")}, + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("*"), segmentPattern("profile")}, }, method: "GET", url: "https://example.com/api/users/profile", @@ -1216,7 +1219,7 @@ func TestEngineMatches(t *testing.T) { { name: "multiple path wildcards", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("*"), segmentPattern("*")}, + PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, }, method: "GET", url: "https://example.com/api/users/123", @@ -1229,7 +1232,7 @@ func TestEngineMatches(t *testing.T) { rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, + PathPattern: []segmentPattern{segmentPattern("users")}, }, method: "POST", url: "https://api.com/users", @@ -1240,7 +1243,7 @@ func TestEngineMatches(t *testing.T) { rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, + PathPattern: []segmentPattern{segmentPattern("users")}, }, method: "GET", url: "https://api.com/users", @@ -1251,7 +1254,7 @@ func TestEngineMatches(t *testing.T) { rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, HostPattern: []labelPattern{labelPattern("api"), labelPattern("org")}, - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("users")}, + PathPattern: []segmentPattern{segmentPattern("users")}, }, method: "POST", url: "https://api.com/users", @@ -1262,7 +1265,7 @@ func TestEngineMatches(t *testing.T) { rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("posts")}, + PathPattern: []segmentPattern{segmentPattern("posts")}, }, method: "POST", url: "https://api.com/users", @@ -1273,7 +1276,7 @@ func TestEngineMatches(t *testing.T) { rule: Rule{ MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, - PathPattern: []segmentPattern{segmentPattern(""), segmentPattern("*"), segmentPattern("*")}, + PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, }, method: "PATCH", url: "https://test.example.com/api/users/123", @@ -1300,7 +1303,7 @@ func TestEngineMatches(t *testing.T) { { name: "root path", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("")}, + PathPattern: []segmentPattern{}, }, method: "GET", url: "https://example.com/", @@ -1326,3 +1329,170 @@ func TestEngineMatches(t *testing.T) { }) } } + +func TestReadmeExamples(t *testing.T) { + logger := slog.Default() + + tests := []struct { + name string + allowRule string + testCases []struct { + method string + url string + expected bool + } + }{ + { + name: "domain only - github.com", + allowRule: "domain=github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://github.com", true}, + {"POST", "https://github.com/user/repo", true}, + {"GET", "https://api.github.com", true}, // subdomain match + {"GET", "https://example.com", false}, + }, + }, + { + name: "domain with path - github.com/api/issues/*", + allowRule: "domain=github.com path=/api/issues/*", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://github.com/api/issues/123", true}, + {"POST", "https://github.com/api/issues/new", true}, + {"GET", "https://github.com/api/users", false}, // wrong path + {"GET", "https://example.com/api/issues/123", false}, // wrong domain + }, + }, + { + name: "method with domain - GET,HEAD github.com", + allowRule: "method=GET,HEAD domain=github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://github.com/user/repo", true}, + {"HEAD", "https://github.com/user/repo", true}, + {"POST", "https://github.com/user/repo", false}, // wrong method + {"GET", "https://example.com", false}, // wrong domain + }, + }, + { + name: "wildcard subdomain - *.github.com", + allowRule: "domain=*.github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://api.github.com", true}, + {"GET", "https://raw.github.com", true}, + {"GET", "https://github.com", false}, // no subdomain + {"GET", "https://example.com", false}, + }, + }, + { + name: "method with domain and specific host", + allowRule: "method=GET,HEAD domain=api.github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://api.github.com/users", true}, + {"HEAD", "https://api.github.com/repos", true}, + {"POST", "https://api.github.com/users", false}, // wrong method + {"GET", "https://github.com", false}, // wrong domain + }, + }, + { + name: "method with domain and path", + allowRule: "method=POST domain=api.example.com path=/users", + testCases: []struct { + method string + url string + expected bool + }{ + {"POST", "https://api.example.com/users", true}, + {"POST", "https://api.example.com/users/123", true}, // subpath match + {"GET", "https://api.example.com/users", false}, // wrong method + {"POST", "https://api.example.com/posts", false}, // wrong path + {"POST", "https://example.com/users", false}, // wrong domain + }, + }, + { + name: "method wildcard - all methods", + allowRule: "method=*", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://example.com", true}, + {"POST", "https://example.com", true}, + {"DELETE", "https://example.com", true}, + {"PATCH", "https://example.com", true}, + {"OPTIONS", "https://example.com", true}, + }, + }, + { + name: "multiple wildcards - wildcard subdomains", + allowRule: "domain=*.npmjs.org", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://registry.npmjs.org", true}, + {"GET", "https://api.npmjs.org", true}, + {"GET", "https://npmjs.org", false}, // no subdomain + {"GET", "https://example.com", false}, + }, + }, + { + name: "registry domain exact match", + allowRule: "domain=registry.npmjs.org", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://registry.npmjs.org", true}, + {"GET", "https://registry.npmjs.org/package", true}, + {"GET", "https://api.npmjs.org", false}, // different subdomain + {"GET", "https://npmjs.org", false}, // missing subdomain + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the allow rule + rule, err := parseAllowRule(tt.allowRule) + if err != nil { + t.Fatalf("Failed to parse allow rule %q: %v", tt.allowRule, err) + } + + // Create engine with the single rule + engine := NewRuleEngine([]Rule{rule}, logger) + + // Test each case + for i, tc := range tt.testCases { + t.Run(fmt.Sprintf("case_%d_%s_%s", i, tc.method, tc.url), func(t *testing.T) { + result := engine.matches(rule, tc.method, tc.url) + if result != tc.expected { + t.Errorf("Rule %q with method %q and URL %q: expected %v, got %v", + tt.allowRule, tc.method, tc.url, tc.expected, result) + } + }) + } + }) + } +} From b83d3cff3aaf0ab705cabd79131426186ff54463 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 15:44:34 -0500 Subject: [PATCH 11/21] update usage in error messages --- proxy/proxy.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index e2aa537..00b116b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -254,8 +254,8 @@ Request: %s %s Host: %s To allow this request, restart boundary with: - --allow "%s" # Allow all methods to this host - --allow "%s %s" # Allow only %s requests to this host + --allow "domain=%s" # Allow all methods to this host + --allow "method=%s domain=%s" # Allow only %s requests to this host For more help: https://github.com/coder/boundary `, @@ -639,7 +639,7 @@ func (p *Server) constructFullURL(req *http.Request, hostname string) string { // writeBlockedResponseStreaming writes a blocked response directly to the TLS connection func (p *Server) writeBlockedResponseStreaming(tlsConn *tls.Conn, req *http.Request) { - response := fmt.Sprintf("HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n🚫 Request Blocked by Boundary\n\nRequest: %s %s\nHost: %s\n\nTo allow this request, restart boundary with:\n --allow \"%s\"\n", + response := fmt.Sprintf("HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n🚫 Request Blocked by Boundary\n\nRequest: %s %s\nHost: %s\n\nTo allow this request, restart boundary with:\n --allow \"domain=%s\"\n", req.Method, req.URL.Path, req.Host, req.Host) _, _ = tlsConn.Write([]byte(response)) } From fe417b873ced0eef36a3ad22e4b25e4d6e1e9c30 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 15:45:51 -0500 Subject: [PATCH 12/21] update cli examples --- cli/cli.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 9e8b993..4304649 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -37,10 +37,10 @@ func NewCommand() *serpent.Command { // may be called something different when used as a subcommand / there will be a leading binary (i.e. `coder boundary` vs. `boundary`). cmd.Long += `Examples: # Allow only requests to github.com - boundary --allow "github.com" -- curl https://github.com + boundary --allow "domain=github.com" -- curl https://github.com # Monitor all requests to specific domains (allow only those) - boundary --allow "github.com/api/issues/*" --allow "GET,HEAD github.com" -- npm install + boundary --allow "domain=github.com path=/api/issues/*" --allow "method=GET,HEAD domain=github.com" -- npm install # Block everything by default (implicit)` From f743e6aa8e8f65bd8e6edbcc33f172ffb0553f87 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 16:33:00 -0500 Subject: [PATCH 13/21] adding curl test to debug separate issue in ci --- dns/dns_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 dns/dns_test.go diff --git a/dns/dns_test.go b/dns/dns_test.go new file mode 100644 index 0000000..9d44320 --- /dev/null +++ b/dns/dns_test.go @@ -0,0 +1,14 @@ +package dns + +import ( + "os/exec" + "testing" +) + +func TestDNSWithCurl(t *testing.T) { + out, err := exec.Command("curl", "--doh-url", "https://dns.google/dns-query", "http://coder.com", "-v").Output() + if err != nil { + t.Fatalf("error curling: %s", err) + } + t.Logf("output: %s", out) +} \ No newline at end of file From 0e868c9e5fcff09e0bf0015c50c0bea452001dc6 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 22 Sep 2025 16:57:16 -0500 Subject: [PATCH 14/21] Revert "adding curl test to debug separate issue in ci" This reverts commit f743e6aa8e8f65bd8e6edbcc33f172ffb0553f87. --- dns/dns_test.go | 14 -------------- 1 file changed, 14 deletions(-) delete mode 100644 dns/dns_test.go diff --git a/dns/dns_test.go b/dns/dns_test.go deleted file mode 100644 index 9d44320..0000000 --- a/dns/dns_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package dns - -import ( - "os/exec" - "testing" -) - -func TestDNSWithCurl(t *testing.T) { - out, err := exec.Command("curl", "--doh-url", "https://dns.google/dns-query", "http://coder.com", "-v").Output() - if err != nil { - t.Fatalf("error curling: %s", err) - } - t.Logf("output: %s", out) -} \ No newline at end of file From 26e2916a2779babfcaa0c98a07209ebf79e5c607 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Wed, 24 Sep 2025 13:52:01 -0500 Subject: [PATCH 15/21] update e2e test to use key/value syntax --- e2e_tests/boundary_integration_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/e2e_tests/boundary_integration_test.go b/e2e_tests/boundary_integration_test.go index 84ff284..d966a0c 100644 --- a/e2e_tests/boundary_integration_test.go +++ b/e2e_tests/boundary_integration_test.go @@ -78,8 +78,8 @@ func TestBoundaryIntegration(t *testing.T) { // Start boundary process with sudo boundaryCmd := exec.CommandContext(ctx, "/tmp/boundary-test", - "--allow", "dev.coder.com", - "--allow", "jsonplaceholder.typicode.com", + "--allow", "domain=dev.coder.com", + "--allow", "domain=jsonplaceholder.typicode.com", "--log-level", "debug", "--", "bash", "-c", "sleep 10 && echo 'Test completed'") From ced3bc86d3d89868a92851274a487d26d86b0f04 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Thu, 25 Sep 2025 09:24:24 -0500 Subject: [PATCH 16/21] remove superfluous interface from old impl --- boundary.go | 2 +- proxy/proxy.go | 4 ++-- rules/rules.go | 8 ++------ 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/boundary.go b/boundary.go index d3e98a6..de3fb55 100644 --- a/boundary.go +++ b/boundary.go @@ -15,7 +15,7 @@ import ( ) type Config struct { - RuleEngine rules.Evaluator + RuleEngine rules.Engine Auditor audit.Auditor TLSConfig *tls.Config Logger *slog.Logger diff --git a/proxy/proxy.go b/proxy/proxy.go index 00b116b..9655256 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -20,7 +20,7 @@ import ( // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rules.Evaluator + ruleEngine rules.Engine auditor audit.Auditor logger *slog.Logger tlsConfig *tls.Config @@ -33,7 +33,7 @@ type Server struct { // Config holds configuration for the proxy server type Config struct { HTTPPort int - RuleEngine rules.Evaluator + RuleEngine rules.Engine Auditor audit.Auditor Logger *slog.Logger TLSConfig *tls.Config diff --git a/rules/rules.go b/rules/rules.go index cc5beec..a486b4a 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -8,10 +8,6 @@ import ( "strings" ) -type Evaluator interface { - Evaluate(method, url string) Result -} - // Rule represents an allow rule with optional HTTP method restrictions type Rule struct { @@ -403,8 +399,8 @@ type Engine struct { } // NewRuleEngine creates a new rule engine -func NewRuleEngine(rules []Rule, logger *slog.Logger) *Engine { - return &Engine{ +func NewRuleEngine(rules []Rule, logger *slog.Logger) Engine { + return Engine{ rules: rules, logger: logger, } From 3a578be326399f8f1bb0836a18bf559109632de9 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Thu, 25 Sep 2025 09:39:34 -0500 Subject: [PATCH 17/21] split engine and rules code into separate files to make easier to read through --- boundary.go | 4 +- cli/cli.go | 6 +- proxy/proxy.go | 6 +- proxy/proxy_test.go | 14 +- rulesengine/engine.go | 122 ++++++++++++ rulesengine/engine_test.go | 271 +++++++++++++++++++++++++++ {rules => rulesengine}/rules.go | 119 +----------- {rules => rulesengine}/rules_test.go | 267 +------------------------- 8 files changed, 410 insertions(+), 399 deletions(-) create mode 100644 rulesengine/engine.go create mode 100644 rulesengine/engine_test.go rename {rules => rulesengine}/rules.go (71%) rename {rules => rulesengine}/rules_test.go (82%) diff --git a/boundary.go b/boundary.go index de3fb55..821f83b 100644 --- a/boundary.go +++ b/boundary.go @@ -11,11 +11,11 @@ import ( "github.com/coder/boundary/audit" "github.com/coder/boundary/jail" "github.com/coder/boundary/proxy" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" ) type Config struct { - RuleEngine rules.Engine + RuleEngine rulesengine.Engine Auditor audit.Auditor TLSConfig *tls.Config Logger *slog.Logger diff --git a/cli/cli.go b/cli/cli.go index f462002..376bb46 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -12,7 +12,7 @@ import ( "github.com/coder/boundary" "github.com/coder/boundary/audit" "github.com/coder/boundary/jail" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" "github.com/coder/boundary/tls" "github.com/coder/boundary/util" "github.com/coder/serpent" @@ -101,14 +101,14 @@ func Run(ctx context.Context, config Config, args []string) error { } // Parse allow rules - allowRules, err := rules.ParseAllowSpecs(config.AllowStrings) + allowRules, err := rulesengine.ParseAllowSpecs(config.AllowStrings) if err != nil { logger.Error("Failed to parse allow rules", "error", err) return fmt.Errorf("failed to parse allow rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(allowRules, logger) + ruleEngine := rulesengine.NewRuleEngine(allowRules, logger) // Create auditor auditor := audit.NewLogAuditor(logger) diff --git a/proxy/proxy.go b/proxy/proxy.go index 9655256..fa47aae 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,12 +15,12 @@ import ( "sync/atomic" "github.com/coder/boundary/audit" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" ) // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rules.Engine + ruleEngine rulesengine.Engine auditor audit.Auditor logger *slog.Logger tlsConfig *tls.Config @@ -33,7 +33,7 @@ type Server struct { // Config holds configuration for the proxy server type Config struct { HTTPPort int - RuleEngine rules.Engine + RuleEngine rulesengine.Engine Auditor audit.Auditor Logger *slog.Logger TLSConfig *tls.Config diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 8b14f02..2128198 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/boundary/audit" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" ) // mockAuditor is a simple mock auditor for testing @@ -35,13 +35,13 @@ func TestProxyServerBasicHTTP(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"method=*"}) + testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(testRules, logger) + ruleEngine := rulesengine.NewRuleEngine(testRules, logger) // Create mock auditor auditor := &mockAuditor{} @@ -116,13 +116,13 @@ func TestProxyServerBasicHTTPS(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"method=*"}) + testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(testRules, logger) + ruleEngine := rulesengine.NewRuleEngine(testRules, logger) // Create mock auditor auditor := &mockAuditor{} @@ -210,13 +210,13 @@ func TestProxyServerCONNECT(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"method=*"}) + testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(testRules, logger) + ruleEngine := rulesengine.NewRuleEngine(testRules, logger) // Create mock auditor auditor := &mockAuditor{} diff --git a/rulesengine/engine.go b/rulesengine/engine.go new file mode 100644 index 0000000..a9c9c58 --- /dev/null +++ b/rulesengine/engine.go @@ -0,0 +1,122 @@ +package rulesengine + +import ( + "log/slog" + neturl "net/url" + "strings" +) + +// Engine evaluates HTTP requests against a set of rules. +type Engine struct { + rules []Rule + logger *slog.Logger +} + +// NewRuleEngine creates a new rule engine +func NewRuleEngine(rules []Rule, logger *slog.Logger) Engine { + return Engine{ + rules: rules, + logger: logger, + } +} + +// Result contains the result of rule evaluation +type Result struct { + Allowed bool + Rule string // The rule that matched (if any) +} + +// Evaluate evaluates a request and returns both result and matching rule +func (re *Engine) Evaluate(method, url string) Result { + // Check if any allow rule matches + for _, rule := range re.rules { + if re.matches(rule, method, url) { + return Result{ + Allowed: true, + Rule: rule.Raw, + } + } + } + + // Default deny if no allow rules match + return Result{ + Allowed: false, + Rule: "", + } +} + +// Matches checks if the rule matches the given method and URL using wildcard patterns +func (re *Engine) matches(r Rule, method, url string) bool { + + // Check method patterns if they exist + if r.MethodPatterns != nil { + methodMatches := false + for mp := range r.MethodPatterns { + if string(mp) == method || string(mp) == "*" { + methodMatches = true + break + } + } + if !methodMatches { + re.logger.Debug("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url) + return false + } + } + + parsedUrl, err := neturl.Parse(url) + if err != nil { + re.logger.Debug("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err) + return false + } + + if r.HostPattern != nil { + // For a host pattern to match, every label has to match or be an `*`. + // Subdomains also match automatically, meaning if the pattern is "example.com" + // and the real is "api.example.com", it should match. We check this by comparing + // from the end of the actual hostname with the pattern (which is in normal order). + + labels := strings.Split(parsedUrl.Hostname(), ".") + + // If the host pattern is longer than the actual host, it's definitely not a match + if len(r.HostPattern) > len(labels) { + re.logger.Debug("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels)) + return false + } + + // Since host patterns cannot end with asterisk, we only need to handle: + // "example.com" or "*.example.com" - match from the end (allowing subdomains) + for i, lp := range r.HostPattern { + labelIndex := len(labels) - len(r.HostPattern) + i + if string(lp) != labels[labelIndex] && lp != "*" { + re.logger.Debug("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex]) + return false + } + } + } + + if r.PathPattern != nil { + segments := strings.Split(parsedUrl.Path, "/") + + // Skip the first empty segment if the path starts with "/" + if len(segments) > 0 && segments[0] == "" { + segments = segments[1:] + } + + // If the path pattern is longer than the actual path, definitely not a match + if len(r.PathPattern) > len(segments) { + re.logger.Debug("rule does not match", "reason", "path pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.PathPattern), "path_segments", len(segments)) + return false + } + + // Each segment in the pattern must be either as asterisk or match the actual path segment + for i, sp := range r.PathPattern { + if string(sp) != segments[i] && sp != "*" { + re.logger.Debug("rule does not match", "reason", "path pattern segment mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(sp), "actual", segments[i]) + return false + } + } + } + + re.logger.Debug("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url) + return true +} diff --git a/rulesengine/engine_test.go b/rulesengine/engine_test.go new file mode 100644 index 0000000..b4d3c2e --- /dev/null +++ b/rulesengine/engine_test.go @@ -0,0 +1,271 @@ +package rulesengine + +import ( + "log/slog" + "testing" +) + +func TestEngineMatches(t *testing.T) { + logger := slog.Default() + engine := NewRuleEngine(nil, logger) + + tests := []struct { + name string + rule Rule + method string + url string + expected bool + }{ + // Method pattern tests + { + name: "method matches exact", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("GET"): {}}, + }, + method: "GET", + url: "https://example.com/api", + expected: true, + }, + { + name: "method does not match", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + }, + method: "GET", + url: "https://example.com/api", + expected: false, + }, + { + name: "method wildcard matches any", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + }, + method: "PUT", + url: "https://example.com/api", + expected: true, + }, + { + name: "no method pattern allows all methods", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + }, + method: "DELETE", + url: "https://example.com/api", + expected: true, + }, + + // Host pattern tests + { + name: "host matches exact", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + }, + method: "GET", + url: "https://example.com/api", + expected: true, + }, + { + name: "host does not match", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("example"), labelPattern("org")}, + }, + method: "GET", + url: "https://example.com/api", + expected: false, + }, + { + name: "subdomain matches", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: true, + }, + { + name: "host pattern too long", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("v1"), labelPattern("api"), labelPattern("example"), labelPattern("com")}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: false, + }, + { + name: "host wildcard matches", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("*"), labelPattern("com")}, + }, + method: "GET", + url: "https://test.com/api", + expected: true, + }, + { + name: "multiple host wildcards", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: true, + }, + + // Path pattern tests + { + name: "path matches exact", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("users")}, + }, + method: "GET", + url: "https://example.com/api/users", + expected: true, + }, + { + name: "path does not match", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("posts")}, + }, + method: "GET", + url: "https://example.com/api/users", + expected: false, + }, + { + name: "subpath matches", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern("api")}, + }, + method: "GET", + url: "https://example.com/api/users/123", + expected: true, + }, + { + name: "path pattern too long", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("v1"), segmentPattern("users"), segmentPattern("profile")}, + }, + method: "GET", + url: "https://example.com/api/v1/users", + expected: false, + }, + { + name: "path wildcard matches", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("*"), segmentPattern("profile")}, + }, + method: "GET", + url: "https://example.com/api/users/profile", + expected: true, + }, + { + name: "multiple path wildcards", + rule: Rule{ + PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, + }, + method: "GET", + url: "https://example.com/api/users/123", + expected: true, + }, + + // Combined pattern tests + { + name: "all patterns match", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, + PathPattern: []segmentPattern{segmentPattern("users")}, + }, + method: "POST", + url: "https://api.com/users", + expected: true, + }, + { + name: "method fails combined test", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, + PathPattern: []segmentPattern{segmentPattern("users")}, + }, + method: "GET", + url: "https://api.com/users", + expected: false, + }, + { + name: "host fails combined test", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("org")}, + PathPattern: []segmentPattern{segmentPattern("users")}, + }, + method: "POST", + url: "https://api.com/users", + expected: false, + }, + { + name: "path fails combined test", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, + PathPattern: []segmentPattern{segmentPattern("posts")}, + }, + method: "POST", + url: "https://api.com/users", + expected: false, + }, + { + name: "all wildcards match", + rule: Rule{ + MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, + PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, + }, + method: "PATCH", + url: "https://test.example.com/api/users/123", + expected: true, + }, + + // Edge cases + { + name: "empty rule matches everything", + rule: Rule{}, + method: "GET", + url: "https://example.com/api/users", + expected: true, + }, + { + name: "invalid URL", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + }, + method: "GET", + url: "not-a-valid-url", + expected: false, + }, + { + name: "root path", + rule: Rule{ + PathPattern: []segmentPattern{}, + }, + method: "GET", + url: "https://example.com/", + expected: true, + }, + { + name: "localhost host", + rule: Rule{ + HostPattern: []labelPattern{labelPattern("localhost")}, + }, + method: "GET", + url: "http://localhost:8080/api", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := engine.matches(tt.rule, tt.method, tt.url) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/rules/rules.go b/rulesengine/rules.go similarity index 71% rename from rules/rules.go rename to rulesengine/rules.go index a486b4a..2ec0be4 100644 --- a/rules/rules.go +++ b/rulesengine/rules.go @@ -1,10 +1,8 @@ -package rules +package rulesengine import ( "errors" "fmt" - "log/slog" - neturl "net/url" "strings" ) @@ -391,118 +389,3 @@ func ParseAllowSpecs(allowStrings []string) ([]Rule, error) { } return out, nil } - -// Engine evaluates HTTP requests against a set of rules -type Engine struct { - rules []Rule - logger *slog.Logger -} - -// NewRuleEngine creates a new rule engine -func NewRuleEngine(rules []Rule, logger *slog.Logger) Engine { - return Engine{ - rules: rules, - logger: logger, - } -} - -// Result contains the result of rule evaluation -type Result struct { - Allowed bool - Rule string // The rule that matched (if any) -} - -// Evaluate evaluates a request and returns both result and matching rule -func (re *Engine) Evaluate(method, url string) Result { - // Check if any allow rule matches - for _, rule := range re.rules { - if re.matches(rule, method, url) { - return Result{ - Allowed: true, - Rule: rule.Raw, - } - } - } - - // Default deny if no allow rules match - return Result{ - Allowed: false, - Rule: "", - } -} - -// Matches checks if the rule matches the given method and URL using wildcard patterns -func (re *Engine) matches(r Rule, method, url string) bool { - - // Check method patterns if they exist - if r.MethodPatterns != nil { - methodMatches := false - for mp := range r.MethodPatterns { - if string(mp) == method || string(mp) == "*" { - methodMatches = true - break - } - } - if !methodMatches { - re.logger.Debug("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url) - return false - } - } - - parsedUrl, err := neturl.Parse(url) - if err != nil { - re.logger.Debug("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err) - return false - } - - if r.HostPattern != nil { - // For a host pattern to match, every label has to match or be an `*`. - // Subdomains also match automatically, meaning if the pattern is "example.com" - // and the real is "api.example.com", it should match. We check this by comparing - // from the end of the actual hostname with the pattern (which is in normal order). - - labels := strings.Split(parsedUrl.Hostname(), ".") - - // If the host pattern is longer than the actual host, it's definitely not a match - if len(r.HostPattern) > len(labels) { - re.logger.Debug("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels)) - return false - } - - // Since host patterns cannot end with asterisk, we only need to handle: - // "example.com" or "*.example.com" - match from the end (allowing subdomains) - for i, lp := range r.HostPattern { - labelIndex := len(labels) - len(r.HostPattern) + i - if string(lp) != labels[labelIndex] && lp != "*" { - re.logger.Debug("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex]) - return false - } - } - } - - if r.PathPattern != nil { - segments := strings.Split(parsedUrl.Path, "/") - - // Skip the first empty segment if the path starts with "/" - if len(segments) > 0 && segments[0] == "" { - segments = segments[1:] - } - - // If the path pattern is longer than the actual path, definitely not a match - if len(r.PathPattern) > len(segments) { - re.logger.Debug("rule does not match", "reason", "path pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.PathPattern), "path_segments", len(segments)) - return false - } - - // Each segment in the pattern must be either as asterisk or match the actual path segment - for i, sp := range r.PathPattern { - if string(sp) != segments[i] && sp != "*" { - re.logger.Debug("rule does not match", "reason", "path pattern segment mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(sp), "actual", segments[i]) - return false - } - } - } - - re.logger.Debug("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url) - return true -} diff --git a/rules/rules_test.go b/rulesengine/rules_test.go similarity index 82% rename from rules/rules_test.go rename to rulesengine/rules_test.go index dc5bb70..20cddf7 100644 --- a/rules/rules_test.go +++ b/rulesengine/rules_test.go @@ -1,4 +1,4 @@ -package rules +package rulesengine import ( "fmt" @@ -1065,271 +1065,6 @@ func TestParseAllowRule(t *testing.T) { } } -func TestEngineMatches(t *testing.T) { - logger := slog.Default() - engine := NewRuleEngine(nil, logger) - - tests := []struct { - name string - rule Rule - method string - url string - expected bool - }{ - // Method pattern tests - { - name: "method matches exact", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("GET"): {}}, - }, - method: "GET", - url: "https://example.com/api", - expected: true, - }, - { - name: "method does not match", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - }, - method: "GET", - url: "https://example.com/api", - expected: false, - }, - { - name: "method wildcard matches any", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, - }, - method: "PUT", - url: "https://example.com/api", - expected: true, - }, - { - name: "no method pattern allows all methods", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, - }, - method: "DELETE", - url: "https://example.com/api", - expected: true, - }, - - // Host pattern tests - { - name: "host matches exact", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, - }, - method: "GET", - url: "https://example.com/api", - expected: true, - }, - { - name: "host does not match", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("org")}, - }, - method: "GET", - url: "https://example.com/api", - expected: false, - }, - { - name: "subdomain matches", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, - }, - method: "GET", - url: "https://api.example.com/users", - expected: true, - }, - { - name: "host pattern too long", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("v1"), labelPattern("api"), labelPattern("example"), labelPattern("com")}, - }, - method: "GET", - url: "https://api.example.com/users", - expected: false, - }, - { - name: "host wildcard matches", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("*"), labelPattern("com")}, - }, - method: "GET", - url: "https://test.com/api", - expected: true, - }, - { - name: "multiple host wildcards", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, - }, - method: "GET", - url: "https://api.example.com/users", - expected: true, - }, - - // Path pattern tests - { - name: "path matches exact", - rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("users")}, - }, - method: "GET", - url: "https://example.com/api/users", - expected: true, - }, - { - name: "path does not match", - rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("posts")}, - }, - method: "GET", - url: "https://example.com/api/users", - expected: false, - }, - { - name: "subpath matches", - rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api")}, - }, - method: "GET", - url: "https://example.com/api/users/123", - expected: true, - }, - { - name: "path pattern too long", - rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("v1"), segmentPattern("users"), segmentPattern("profile")}, - }, - method: "GET", - url: "https://example.com/api/v1/users", - expected: false, - }, - { - name: "path wildcard matches", - rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("*"), segmentPattern("profile")}, - }, - method: "GET", - url: "https://example.com/api/users/profile", - expected: true, - }, - { - name: "multiple path wildcards", - rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, - }, - method: "GET", - url: "https://example.com/api/users/123", - expected: true, - }, - - // Combined pattern tests - { - name: "all patterns match", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern("users")}, - }, - method: "POST", - url: "https://api.com/users", - expected: true, - }, - { - name: "method fails combined test", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern("users")}, - }, - method: "GET", - url: "https://api.com/users", - expected: false, - }, - { - name: "host fails combined test", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("org")}, - PathPattern: []segmentPattern{segmentPattern("users")}, - }, - method: "POST", - url: "https://api.com/users", - expected: false, - }, - { - name: "path fails combined test", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern("posts")}, - }, - method: "POST", - url: "https://api.com/users", - expected: false, - }, - { - name: "all wildcards match", - rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, - HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, - PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, - }, - method: "PATCH", - url: "https://test.example.com/api/users/123", - expected: true, - }, - - // Edge cases - { - name: "empty rule matches everything", - rule: Rule{}, - method: "GET", - url: "https://example.com/api/users", - expected: true, - }, - { - name: "invalid URL", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, - }, - method: "GET", - url: "not-a-valid-url", - expected: false, - }, - { - name: "root path", - rule: Rule{ - PathPattern: []segmentPattern{}, - }, - method: "GET", - url: "https://example.com/", - expected: true, - }, - { - name: "localhost host", - rule: Rule{ - HostPattern: []labelPattern{labelPattern("localhost")}, - }, - method: "GET", - url: "http://localhost:8080/api", - expected: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := engine.matches(tt.rule, tt.method, tt.url) - if result != tt.expected { - t.Errorf("expected %v, got %v", tt.expected, result) - } - }) - } -} - func TestReadmeExamples(t *testing.T) { logger := slog.Default() From 959e4bdf83b8800556495a91bcfa566c0d628f48 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Thu, 25 Sep 2025 10:48:56 -0500 Subject: [PATCH 18/21] mutate rest in parse allow rule to make pattern clearer --- rulesengine/rules.go | 207 ++++++++++++++++++++++--------------------- 1 file changed, 107 insertions(+), 100 deletions(-) diff --git a/rulesengine/rules.go b/rulesengine/rules.go index 2ec0be4..5eac544 100644 --- a/rulesengine/rules.go +++ b/rulesengine/rules.go @@ -6,29 +6,125 @@ import ( "strings" ) -// Rule represents an allow rule with optional HTTP method restrictions +// Rule represents an allow rule passed to the cli with --allow or read from the config file. +// Rules have a specific grammar that we need to parse carefully. +// Example: --allow="method=GET,PATCH domain=wibble.wobble.com, path=/posts/*" type Rule struct { - // The path segments of the url - // nil means all paths allowed - // a path segment of `*` acts as a wild card. - // sub paths automatically match + // The path segments of the url. + // - nil means all paths allowed + // - a path segment of `*` acts as a wild card. + // - sub paths automatically match PathPattern []segmentPattern - // The labels of the host, i.e. ["google", "com"] - // nil means all hosts allowed - // A label of `*` acts as a wild card. - // subdomains automatically match + // The labels of the host, i.e. ["google", "com"]. + // - nil means all hosts allowed + // - A label of `*` acts as a wild card. + // - subdomains automatically match HostPattern []labelPattern - // The allowed http methods - // nil means all methods allowed + // The allowed http methods. + // - nil means all methods allowed MethodPatterns map[methodPattern]struct{} // Raw rule string for logging Raw string } +// ParseAllowSpecs parses a slice of --allow specs into allow Rules. +func ParseAllowSpecs(allowStrings []string) ([]Rule, error) { + var out []Rule + for _, s := range allowStrings { + r, err := parseAllowRule(s) + if err != nil { + return nil, fmt.Errorf("failed to parse allow '%s': %v", s, err) + } + out = append(out, r) + } + return out, nil +} + +// parseAllowRule takes an allow rule string and tries to parse it as a rule. +func parseAllowRule(ruleStr string) (Rule, error) { + rule := Rule{ + Raw: ruleStr, + } + + // Functions called by this function used a really common pattern: recursive descent parsing. + // All the helper functions for parsing an allow rule will be called like `thing, rest, err := parseThing(rest)`. + // What's going on here is that we try to parse some expected text from the front of the string. + // If we succeed, we get back the thing we parsed and the remaining text. If we fail, we get back a non nil error. + rest := ruleStr + var key string + var err error + + // Ann allow rule can have as many key=value pairs as needed, we go until there's no more text in the rule. + for rest != "" { + // Parse the key + key, rest, err = parseKey(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse key: %v", err) + } + + // Parse the value based on the key type + switch key { + case "method": + // Initialize Methods map if needed + if rule.MethodPatterns == nil { + rule.MethodPatterns = make(map[methodPattern]struct{}) + } + + var method methodPattern + for { + method, rest, err = parseMethodPattern(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse method: %v", err) + } + + rule.MethodPatterns[method] = struct{}{} + + // Check if there's a comma for more methods + if rest != "" && rest[0] == ',' { + rest = rest[1:] // Skip the comma + continue + } + + break + } + + case "domain": + var host []labelPattern + host, rest, err = parseHostPattern(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse domain: %v", err) + } + + // Convert labels to strings + rule.HostPattern = append(rule.HostPattern, host...) + + case "path": + var segments []segmentPattern + segments, rest, err = parsePathPattern(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse path: %v", err) + } + + // Convert segments to strings + rule.PathPattern = append(rule.PathPattern, segments...) + + default: + return Rule{}, fmt.Errorf("unknown key: %s", key) + } + + // Skip whitespace or comma separators + for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == ',') { + rest = rest[1:] + } + } + + return rule, nil +} + type methodPattern string // Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by @@ -300,92 +396,3 @@ func parseKey(rule string) (string, string, error) { return "", "", errors.New("expected key") } - -func parseAllowRule(ruleStr string) (Rule, error) { - rule := Rule{ - Raw: ruleStr, - } - - rest := ruleStr - - for rest != "" { - // Parse the key - key, valueRest, err := parseKey(rest) - if err != nil { - return Rule{}, fmt.Errorf("failed to parse key: %v", err) - } - - // Parse the value based on the key type - switch key { - case "method": - // Handle comma-separated methods - methodsRest := valueRest - - // Initialize Methods map if needed - if rule.MethodPatterns == nil { - rule.MethodPatterns = make(map[methodPattern]struct{}) - } - - for { - token, remaining, err := parseMethodPattern(methodsRest) - if err != nil { - return Rule{}, fmt.Errorf("failed to parse method: %v", err) - } - - rule.MethodPatterns[token] = struct{}{} - - // Check if there's a comma for more methods - if remaining != "" && remaining[0] == ',' { - methodsRest = remaining[1:] // Skip the comma - continue - } - - rest = remaining - break - } - - case "domain": - hostLabels, remaining, err := parseHostPattern(valueRest) - if err != nil { - return Rule{}, fmt.Errorf("failed to parse domain: %v", err) - } - - // Convert labels to strings - rule.HostPattern = append(rule.HostPattern, hostLabels...) - rest = remaining - - case "path": - segments, remaining, err := parsePathPattern(valueRest) - if err != nil { - return Rule{}, fmt.Errorf("failed to parse path: %v", err) - } - - // Convert segments to strings - rule.PathPattern = append(rule.PathPattern, segments...) - rest = remaining - - default: - return Rule{}, fmt.Errorf("unknown key: %s", key) - } - - // Skip whitespace or comma separators - for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == ',') { - rest = rest[1:] - } - } - - return rule, nil -} - -// ParseAllowSpecs parses a slice of --allow specs into allow Rules. -func ParseAllowSpecs(allowStrings []string) ([]Rule, error) { - var out []Rule - for _, s := range allowStrings { - r, err := parseAllowRule(s) - if err != nil { - return nil, fmt.Errorf("failed to parse allow '%s': %v", s, err) - } - out = append(out, r) - } - return out, nil -} From bd82e3d7cbe5fee90800ad7f3e43f890acbe42ec Mon Sep 17 00:00:00 2001 From: Benjamin Date: Thu, 25 Sep 2025 10:52:59 -0500 Subject: [PATCH 19/21] mutate rest throughout rules --- rulesengine/rules.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/rulesengine/rules.go b/rulesengine/rules.go index 5eac544..9b057b0 100644 --- a/rulesengine/rules.go +++ b/rulesengine/rules.go @@ -178,15 +178,17 @@ func isHTTPTokenChar(c byte) bool { // Represents a valid host. // https://datatracker.ietf.org/doc/html/rfc952 // https://datatracker.ietf.org/doc/html/rfc1123#page-13 -func parseHostPattern(input string) (host []labelPattern, rest string, err error) { - rest = input - var label labelPattern +func parseHostPattern(input string) ([]labelPattern, string, error) { + rest := input + var host []labelPattern + var err error if input == "" { return nil, "", errors.New("expected host, got empty string") } // There should be at least one label. + var label labelPattern label, rest, err = parseLabelPattern(rest) if err != nil { return nil, "", err @@ -271,8 +273,9 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) { return nil, "", nil } - var segments []segmentPattern rest := input + var segments []segmentPattern + var err error // If the path doesn't start with '/', it's not a valid absolute path // But we'll be flexible and parse relative paths too @@ -288,19 +291,19 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) { } // Parse the next segment - seg, remaining, err := parsePathSegmentPattern(rest) + var segment segmentPattern + segment, rest, err = parsePathSegmentPattern(rest) if err != nil { return nil, "", err } // If we got an empty segment and there's still input, // it means we hit an invalid character - if seg == "" && remaining != "" { + if segment == "" && rest != "" { break } - segments = append(segments, seg) - rest = remaining + segments = append(segments, segment) // If there's no slash after the segment, we're done parsing the path if rest == "" || rest[0] != '/' { From 812b5f8a5eb9b8a0953dfb6d39ecd521f0fbba91 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Thu, 25 Sep 2025 10:58:18 -0500 Subject: [PATCH 20/21] don't allow comma separated rules --- rulesengine/rules.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rulesengine/rules.go b/rulesengine/rules.go index 9b057b0..22ca4fb 100644 --- a/rulesengine/rules.go +++ b/rulesengine/rules.go @@ -116,8 +116,8 @@ func parseAllowRule(ruleStr string) (Rule, error) { return Rule{}, fmt.Errorf("unknown key: %s", key) } - // Skip whitespace or comma separators - for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == ',') { + // Skip whitespace separators (only support mac and linux so \r\n shouldn't be a thing) + for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == '\n') { rest = rest[1:] } } From c5ecf9829315954ce4a573f2dcd4cb936c8937d3 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Fri, 26 Sep 2025 10:55:18 -0500 Subject: [PATCH 21/21] remove custom parsed types in favor of stringly typed api --- rulesengine/engine_test.go | 68 ++++++++++++------------- rulesengine/rules.go | 52 ++++++++----------- rulesengine/rules_test.go | 102 ++++++++++++++++++------------------- 3 files changed, 107 insertions(+), 115 deletions(-) diff --git a/rulesengine/engine_test.go b/rulesengine/engine_test.go index b4d3c2e..2a5137f 100644 --- a/rulesengine/engine_test.go +++ b/rulesengine/engine_test.go @@ -20,7 +20,7 @@ func TestEngineMatches(t *testing.T) { { name: "method matches exact", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("GET"): {}}, + MethodPatterns: map[string]struct{}{"GET": {}}, }, method: "GET", url: "https://example.com/api", @@ -29,7 +29,7 @@ func TestEngineMatches(t *testing.T) { { name: "method does not match", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, + MethodPatterns: map[string]struct{}{"POST": {}}, }, method: "GET", url: "https://example.com/api", @@ -38,7 +38,7 @@ func TestEngineMatches(t *testing.T) { { name: "method wildcard matches any", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + MethodPatterns: map[string]struct{}{"*": {}}, }, method: "PUT", url: "https://example.com/api", @@ -47,7 +47,7 @@ func TestEngineMatches(t *testing.T) { { name: "no method pattern allows all methods", rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + HostPattern: []string{"example", "com"}, }, method: "DELETE", url: "https://example.com/api", @@ -58,7 +58,7 @@ func TestEngineMatches(t *testing.T) { { name: "host matches exact", rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + HostPattern: []string{"example", "com"}, }, method: "GET", url: "https://example.com/api", @@ -67,7 +67,7 @@ func TestEngineMatches(t *testing.T) { { name: "host does not match", rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("org")}, + HostPattern: []string{"example", "org"}, }, method: "GET", url: "https://example.com/api", @@ -76,7 +76,7 @@ func TestEngineMatches(t *testing.T) { { name: "subdomain matches", rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + HostPattern: []string{"example", "com"}, }, method: "GET", url: "https://api.example.com/users", @@ -85,7 +85,7 @@ func TestEngineMatches(t *testing.T) { { name: "host pattern too long", rule: Rule{ - HostPattern: []labelPattern{labelPattern("v1"), labelPattern("api"), labelPattern("example"), labelPattern("com")}, + HostPattern: []string{"v1", "api", "example", "com"}, }, method: "GET", url: "https://api.example.com/users", @@ -94,7 +94,7 @@ func TestEngineMatches(t *testing.T) { { name: "host wildcard matches", rule: Rule{ - HostPattern: []labelPattern{labelPattern("*"), labelPattern("com")}, + HostPattern: []string{"*", "com"}, }, method: "GET", url: "https://test.com/api", @@ -103,7 +103,7 @@ func TestEngineMatches(t *testing.T) { { name: "multiple host wildcards", rule: Rule{ - HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, + HostPattern: []string{"*", "*"}, }, method: "GET", url: "https://api.example.com/users", @@ -114,7 +114,7 @@ func TestEngineMatches(t *testing.T) { { name: "path matches exact", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("users")}, + PathPattern: []string{"api", "users"}, }, method: "GET", url: "https://example.com/api/users", @@ -123,7 +123,7 @@ func TestEngineMatches(t *testing.T) { { name: "path does not match", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("posts")}, + PathPattern: []string{"api", "posts"}, }, method: "GET", url: "https://example.com/api/users", @@ -132,7 +132,7 @@ func TestEngineMatches(t *testing.T) { { name: "subpath matches", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api")}, + PathPattern: []string{"api"}, }, method: "GET", url: "https://example.com/api/users/123", @@ -141,7 +141,7 @@ func TestEngineMatches(t *testing.T) { { name: "path pattern too long", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("v1"), segmentPattern("users"), segmentPattern("profile")}, + PathPattern: []string{"api", "v1", "users", "profile"}, }, method: "GET", url: "https://example.com/api/v1/users", @@ -150,7 +150,7 @@ func TestEngineMatches(t *testing.T) { { name: "path wildcard matches", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("*"), segmentPattern("profile")}, + PathPattern: []string{"api", "*", "profile"}, }, method: "GET", url: "https://example.com/api/users/profile", @@ -159,7 +159,7 @@ func TestEngineMatches(t *testing.T) { { name: "multiple path wildcards", rule: Rule{ - PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, + PathPattern: []string{"*", "*"}, }, method: "GET", url: "https://example.com/api/users/123", @@ -170,9 +170,9 @@ func TestEngineMatches(t *testing.T) { { name: "all patterns match", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern("users")}, + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "com"}, + PathPattern: []string{"users"}, }, method: "POST", url: "https://api.com/users", @@ -181,9 +181,9 @@ func TestEngineMatches(t *testing.T) { { name: "method fails combined test", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern("users")}, + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "com"}, + PathPattern: []string{"users"}, }, method: "GET", url: "https://api.com/users", @@ -192,9 +192,9 @@ func TestEngineMatches(t *testing.T) { { name: "host fails combined test", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("org")}, - PathPattern: []segmentPattern{segmentPattern("users")}, + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "org"}, + PathPattern: []string{"users"}, }, method: "POST", url: "https://api.com/users", @@ -203,9 +203,9 @@ func TestEngineMatches(t *testing.T) { { name: "path fails combined test", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern("posts")}, + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "com"}, + PathPattern: []string{"posts"}, }, method: "POST", url: "https://api.com/users", @@ -214,9 +214,9 @@ func TestEngineMatches(t *testing.T) { { name: "all wildcards match", rule: Rule{ - MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, - HostPattern: []labelPattern{labelPattern("*"), labelPattern("*")}, - PathPattern: []segmentPattern{segmentPattern("*"), segmentPattern("*")}, + MethodPatterns: map[string]struct{}{"*": {}}, + HostPattern: []string{"*", "*"}, + PathPattern: []string{"*", "*"}, }, method: "PATCH", url: "https://test.example.com/api/users/123", @@ -234,7 +234,7 @@ func TestEngineMatches(t *testing.T) { { name: "invalid URL", rule: Rule{ - HostPattern: []labelPattern{labelPattern("example"), labelPattern("com")}, + HostPattern: []string{"example", "com"}, }, method: "GET", url: "not-a-valid-url", @@ -243,7 +243,7 @@ func TestEngineMatches(t *testing.T) { { name: "root path", rule: Rule{ - PathPattern: []segmentPattern{}, + PathPattern: []string{}, }, method: "GET", url: "https://example.com/", @@ -252,7 +252,7 @@ func TestEngineMatches(t *testing.T) { { name: "localhost host", rule: Rule{ - HostPattern: []labelPattern{labelPattern("localhost")}, + HostPattern: []string{"localhost"}, }, method: "GET", url: "http://localhost:8080/api", diff --git a/rulesengine/rules.go b/rulesengine/rules.go index 22ca4fb..6a038ec 100644 --- a/rulesengine/rules.go +++ b/rulesengine/rules.go @@ -15,17 +15,17 @@ type Rule struct { // - nil means all paths allowed // - a path segment of `*` acts as a wild card. // - sub paths automatically match - PathPattern []segmentPattern + PathPattern []string // The labels of the host, i.e. ["google", "com"]. // - nil means all hosts allowed // - A label of `*` acts as a wild card. // - subdomains automatically match - HostPattern []labelPattern + HostPattern []string // The allowed http methods. // - nil means all methods allowed - MethodPatterns map[methodPattern]struct{} + MethodPatterns map[string]struct{} // Raw rule string for logging Raw string @@ -71,10 +71,10 @@ func parseAllowRule(ruleStr string) (Rule, error) { case "method": // Initialize Methods map if needed if rule.MethodPatterns == nil { - rule.MethodPatterns = make(map[methodPattern]struct{}) + rule.MethodPatterns = make(map[string]struct{}) } - var method methodPattern + var method string for { method, rest, err = parseMethodPattern(rest) if err != nil { @@ -93,7 +93,7 @@ func parseAllowRule(ruleStr string) (Rule, error) { } case "domain": - var host []labelPattern + var host []string host, rest, err = parseHostPattern(rest) if err != nil { return Rule{}, fmt.Errorf("failed to parse domain: %v", err) @@ -103,7 +103,7 @@ func parseAllowRule(ruleStr string) (Rule, error) { rule.HostPattern = append(rule.HostPattern, host...) case "path": - var segments []segmentPattern + var segments []string segments, rest, err = parsePathPattern(rest) if err != nil { return Rule{}, fmt.Errorf("failed to parse path: %v", err) @@ -125,27 +125,25 @@ func parseAllowRule(ruleStr string) (Rule, error) { return rule, nil } -type methodPattern string - // Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by // various systems. // https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6 -func parseMethodPattern(token string) (methodPattern, string, error) { +func parseMethodPattern(token string) (string, string, error) { if token == "" { return "", "", errors.New("expected http token, got empty string") } return doParseMethodPattern(token, nil) } -func doParseMethodPattern(token string, acc []byte) (methodPattern, string, error) { +func doParseMethodPattern(token string, acc []byte) (string, string, error) { // BASE CASE: if the token passed in is empty, we're done parsing if token == "" { - return methodPattern(acc), "", nil + return string(acc), "", nil } // If the next byte in the string is not a valid http token character, we're done parsing. if !isHTTPTokenChar(token[0]) { - return methodPattern(acc), token, nil + return string(acc), token, nil } // The next character is valid, so the http token continues @@ -178,9 +176,9 @@ func isHTTPTokenChar(c byte) bool { // Represents a valid host. // https://datatracker.ietf.org/doc/html/rfc952 // https://datatracker.ietf.org/doc/html/rfc1123#page-13 -func parseHostPattern(input string) ([]labelPattern, string, error) { +func parseHostPattern(input string) ([]string, string, error) { rest := input - var host []labelPattern + var host []string var err error if input == "" { @@ -188,7 +186,7 @@ func parseHostPattern(input string) ([]labelPattern, string, error) { } // There should be at least one label. - var label labelPattern + var label string label, rest, err = parseLabelPattern(rest) if err != nil { return nil, "", err @@ -218,10 +216,7 @@ func parseHostPattern(input string) ([]labelPattern, string, error) { return host, rest, nil } -// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`. -type labelPattern string - -func parseLabelPattern(rest string) (labelPattern, string, error) { +func parseLabelPattern(rest string) (string, string, error) { if rest == "" { return "", "", errors.New("expected label, got empty string") } @@ -246,7 +241,7 @@ func parseLabelPattern(rest string) (labelPattern, string, error) { return "", "", fmt.Errorf("invalid label: %s", rest[:i]) } - return labelPattern(rest[:i]), rest[i:], nil + return rest[:i], rest[i:], nil } func isValidLabelChar(c byte) bool { @@ -268,13 +263,13 @@ func isValidLabelChar(c byte) bool { } } -func parsePathPattern(input string) ([]segmentPattern, string, error) { +func parsePathPattern(input string) ([]string, string, error) { if input == "" { return nil, "", nil } rest := input - var segments []segmentPattern + var segments []string var err error // If the path doesn't start with '/', it's not a valid absolute path @@ -291,7 +286,7 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) { } // Parse the next segment - var segment segmentPattern + var segment string segment, rest, err = parsePathSegmentPattern(rest) if err != nil { return nil, "", err @@ -314,10 +309,7 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) { return segments, rest, nil } -// Represents a valid url path segmentPattern. -type segmentPattern string - -func parsePathSegmentPattern(input string) (segmentPattern, string, error) { +func parsePathSegmentPattern(input string) (string, string, error) { if input == "" { return "", "", nil } @@ -327,7 +319,7 @@ func parsePathSegmentPattern(input string) (segmentPattern, string, error) { return "", "", fmt.Errorf("path segment wildcards must be for the entire segment, got: %s", input) } - return segmentPattern(input[0]), input[1:], nil + return "*", input[1:], nil } var i int @@ -349,7 +341,7 @@ func parsePathSegmentPattern(input string) (segmentPattern, string, error) { } } - return segmentPattern(input[:i]), input[i:], nil + return input[:i], input[i:], nil } // isUnreserved returns true if the character is unreserved per RFC 3986 diff --git a/rulesengine/rules_test.go b/rulesengine/rules_test.go index 20cddf7..8ccf94e 100644 --- a/rulesengine/rules_test.go +++ b/rulesengine/rules_test.go @@ -10,7 +10,7 @@ func TestParseHTTPToken(t *testing.T) { tests := []struct { name string input string - expectedToken methodPattern + expectedToken string expectedRemain string expectError bool }{ @@ -138,7 +138,7 @@ func TestParseHost(t *testing.T) { tests := []struct { name string input string - expectedHost []labelPattern + expectedHost []string expectedRest string expectError bool }{ @@ -152,56 +152,56 @@ func TestParseHost(t *testing.T) { { name: "simple domain", input: "google.com", - expectedHost: []labelPattern{labelPattern("google"), labelPattern("com")}, + expectedHost: []string{"google", "com"}, expectedRest: "", expectError: false, }, { name: "subdomain", input: "api.google.com", - expectedHost: []labelPattern{labelPattern("api"), labelPattern("google"), labelPattern("com")}, + expectedHost: []string{"api", "google", "com"}, expectedRest: "", expectError: false, }, { name: "single label", input: "localhost", - expectedHost: []labelPattern{labelPattern("localhost")}, + expectedHost: []string{"localhost"}, expectedRest: "", expectError: false, }, { name: "domain with trailing content", input: "example.org/path", - expectedHost: []labelPattern{labelPattern("example"), labelPattern("org")}, + expectedHost: []string{"example", "org"}, expectedRest: "/path", expectError: false, }, { name: "domain with port", input: "localhost:8080", - expectedHost: []labelPattern{labelPattern("localhost")}, + expectedHost: []string{"localhost"}, expectedRest: ":8080", expectError: false, }, { name: "numeric labels", input: "192.168.1.1", - expectedHost: []labelPattern{labelPattern("192"), labelPattern("168"), labelPattern("1"), labelPattern("1")}, + expectedHost: []string{"192", "168", "1", "1"}, expectedRest: "", expectError: false, }, { name: "hyphenated domain", input: "my-site.example-domain.co.uk", - expectedHost: []labelPattern{labelPattern("my-site"), labelPattern("example-domain"), labelPattern("co"), labelPattern("uk")}, + expectedHost: []string{"my-site", "example-domain", "co", "uk"}, expectedRest: "", expectError: false, }, { name: "alphanumeric labels", input: "a1b2c3.test123.com", - expectedHost: []labelPattern{labelPattern("a1b2c3"), labelPattern("test123"), labelPattern("com")}, + expectedHost: []string{"a1b2c3", "test123", "com"}, expectedRest: "", expectError: false, }, @@ -229,7 +229,7 @@ func TestParseHost(t *testing.T) { { name: "invalid character", input: "test@example.com", - expectedHost: []labelPattern{labelPattern("test")}, + expectedHost: []string{"test"}, expectedRest: "@example.com", expectError: false, }, @@ -250,21 +250,21 @@ func TestParseHost(t *testing.T) { { name: "single character labels", input: "a.b.c", - expectedHost: []labelPattern{labelPattern("a"), labelPattern("b"), labelPattern("c")}, + expectedHost: []string{"a", "b", "c"}, expectedRest: "", expectError: false, }, { name: "mixed case", input: "Example.COM", - expectedHost: []labelPattern{labelPattern("Example"), labelPattern("COM")}, + expectedHost: []string{"Example", "COM"}, expectedRest: "", expectError: false, }, { name: "wildcard subdomain", input: "*.example.com", - expectedHost: []labelPattern{labelPattern("*"), labelPattern("example"), labelPattern("com")}, + expectedHost: []string{"*", "example", "com"}, expectedRest: "", expectError: false, }, @@ -278,14 +278,14 @@ func TestParseHost(t *testing.T) { { name: "multiple wildcards", input: "*.*.com", - expectedHost: []labelPattern{labelPattern("*"), labelPattern("*"), labelPattern("com")}, + expectedHost: []string{"*", "*", "com"}, expectedRest: "", expectError: false, }, { name: "wildcard with trailing content", input: "*.example.com/path", - expectedHost: []labelPattern{labelPattern("*"), labelPattern("example"), labelPattern("com")}, + expectedHost: []string{"*", "example", "com"}, expectedRest: "/path", expectError: false, }, @@ -336,7 +336,7 @@ func TestParseLabel(t *testing.T) { tests := []struct { name string input string - expectedLabel labelPattern + expectedLabel string expectedRest string expectError bool }{ @@ -492,7 +492,7 @@ func TestParsePathSegment(t *testing.T) { tests := []struct { name string input string - expectedSegment segmentPattern + expectedSegment string expectedRest string expectError bool }{ @@ -676,7 +676,7 @@ func TestParsePath(t *testing.T) { tests := []struct { name string input string - expectedSegments []segmentPattern + expectedSegments []string expectedRest string expectError bool }{ @@ -690,56 +690,56 @@ func TestParsePath(t *testing.T) { { name: "single segment", input: "/api", - expectedSegments: []segmentPattern{"api"}, + expectedSegments: []string{"api"}, expectedRest: "", expectError: false, }, { name: "multiple segments", input: "/api/v1/users", - expectedSegments: []segmentPattern{"api", "v1", "users"}, + expectedSegments: []string{"api", "v1", "users"}, expectedRest: "", expectError: false, }, { name: "relative path", input: "api/users", - expectedSegments: []segmentPattern{"api", "users"}, + expectedSegments: []string{"api", "users"}, expectedRest: "", expectError: false, }, { name: "path with trailing slash", input: "/api/users/", - expectedSegments: []segmentPattern{"api", "users"}, + expectedSegments: []string{"api", "users"}, expectedRest: "", expectError: false, }, { name: "path with query string", input: "/api/users?limit=10", - expectedSegments: []segmentPattern{"api", "users"}, + expectedSegments: []string{"api", "users"}, expectedRest: "?limit=10", expectError: false, }, { name: "path with fragment", input: "/docs/api#authentication", - expectedSegments: []segmentPattern{"docs", "api"}, + expectedSegments: []string{"docs", "api"}, expectedRest: "#authentication", expectError: false, }, { name: "path with encoded segments", input: "/api/hello%20world/test", - expectedSegments: []segmentPattern{"api", "hello%20world", "test"}, + expectedSegments: []string{"api", "hello%20world", "test"}, expectedRest: "", expectError: false, }, { name: "path with special chars", input: "/api/filter='test'&sort=name/results", - expectedSegments: []segmentPattern{"api", "filter='test'&sort=name", "results"}, + expectedSegments: []string{"api", "filter='test'&sort=name", "results"}, expectedRest: "", expectError: false, }, @@ -753,91 +753,91 @@ func TestParsePath(t *testing.T) { { name: "empty segments", input: "/api//users", - expectedSegments: []segmentPattern{"api"}, + expectedSegments: []string{"api"}, expectedRest: "/users", expectError: false, }, { name: "path with port-like segment", input: "/host:8080/status", - expectedSegments: []segmentPattern{"host:8080", "status"}, + expectedSegments: []string{"host:8080", "status"}, expectedRest: "", expectError: false, }, { name: "path stops at space", input: "/api/test hello", - expectedSegments: []segmentPattern{"api", "test"}, + expectedSegments: []string{"api", "test"}, expectedRest: " hello", expectError: false, }, { name: "path with hyphens and underscores", input: "/my-api/user_data/file-name.txt", - expectedSegments: []segmentPattern{"my-api", "user_data", "file-name.txt"}, + expectedSegments: []string{"my-api", "user_data", "file-name.txt"}, expectedRest: "", expectError: false, }, { name: "path with tildes", input: "/api/~user/docs~backup", - expectedSegments: []segmentPattern{"api", "~user", "docs~backup"}, + expectedSegments: []string{"api", "~user", "docs~backup"}, expectedRest: "", expectError: false, }, { name: "numeric segments", input: "/api/v2/users/12345", - expectedSegments: []segmentPattern{"api", "v2", "users", "12345"}, + expectedSegments: []string{"api", "v2", "users", "12345"}, expectedRest: "", expectError: false, }, { name: "single character segments", input: "/a/b/c", - expectedSegments: []segmentPattern{"a", "b", "c"}, + expectedSegments: []string{"a", "b", "c"}, expectedRest: "", expectError: false, }, { name: "path with at symbol", input: "/user@domain.com/profile", - expectedSegments: []segmentPattern{"user@domain.com", "profile"}, + expectedSegments: []string{"user@domain.com", "profile"}, expectedRest: "", expectError: false, }, { name: "path with wildcard segment", input: "/api/*/users", - expectedSegments: []segmentPattern{"api", "*", "users"}, + expectedSegments: []string{"api", "*", "users"}, expectedRest: "", expectError: false, }, { name: "path with multiple wildcards", input: "/*/v1/*/profile", - expectedSegments: []segmentPattern{"*", "v1", "*", "profile"}, + expectedSegments: []string{"*", "v1", "*", "profile"}, expectedRest: "", expectError: false, }, { name: "path ending with wildcard", input: "/api/users/*", - expectedSegments: []segmentPattern{"api", "users", "*"}, + expectedSegments: []string{"api", "users", "*"}, expectedRest: "", expectError: false, }, { name: "path starting with wildcard", input: "/*/users", - expectedSegments: []segmentPattern{"*", "users"}, + expectedSegments: []string{"*", "users"}, expectedRest: "", expectError: false, }, { name: "path with wildcard and query", input: "/api/*/users?limit=10", - expectedSegments: []segmentPattern{"api", "*", "users"}, + expectedSegments: []string{"api", "*", "users"}, expectedRest: "?limit=10", expectError: false, }, @@ -897,7 +897,7 @@ func TestParseAllowRule(t *testing.T) { input: "method=GET", expectedRule: Rule{ Raw: "method=GET", - MethodPatterns: map[methodPattern]struct{}{methodPattern("GET"): {}}, + MethodPatterns: map[string]struct{}{"GET": {}}, }, expectError: false, }, @@ -906,7 +906,7 @@ func TestParseAllowRule(t *testing.T) { input: "domain=google.com", expectedRule: Rule{ Raw: "domain=google.com", - HostPattern: []labelPattern{labelPattern("google"), labelPattern("com")}, + HostPattern: []string{"google", "com"}, }, expectError: false, }, @@ -915,7 +915,7 @@ func TestParseAllowRule(t *testing.T) { input: "path=/api/v1", expectedRule: Rule{ Raw: "path=/api/v1", - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("v1")}, + PathPattern: []string{"api", "v1"}, }, expectError: false, }, @@ -924,8 +924,8 @@ func TestParseAllowRule(t *testing.T) { input: "method=POST domain=api.example.com", expectedRule: Rule{ Raw: "method=POST domain=api.example.com", - MethodPatterns: map[methodPattern]struct{}{methodPattern("POST"): {}}, - HostPattern: []labelPattern{labelPattern("api"), labelPattern("example"), labelPattern("com")}, + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "example", "com"}, }, expectError: false, }, @@ -934,9 +934,9 @@ func TestParseAllowRule(t *testing.T) { input: "method=DELETE domain=test.com path=/resources/456", expectedRule: Rule{ Raw: "method=DELETE domain=test.com path=/resources/456", - MethodPatterns: map[methodPattern]struct{}{methodPattern("DELETE"): {}}, - HostPattern: []labelPattern{labelPattern("test"), labelPattern("com")}, - PathPattern: []segmentPattern{segmentPattern("resources"), segmentPattern("456")}, + MethodPatterns: map[string]struct{}{"DELETE": {}}, + HostPattern: []string{"test", "com"}, + PathPattern: []string{"resources", "456"}, }, expectError: false, }, @@ -945,7 +945,7 @@ func TestParseAllowRule(t *testing.T) { input: "domain=*.example.com", expectedRule: Rule{ Raw: "domain=*.example.com", - HostPattern: []labelPattern{labelPattern("*"), labelPattern("example"), labelPattern("com")}, + HostPattern: []string{"*", "example", "com"}, }, expectError: false, }, @@ -954,7 +954,7 @@ func TestParseAllowRule(t *testing.T) { input: "path=/api/*/users", expectedRule: Rule{ Raw: "path=/api/*/users", - PathPattern: []segmentPattern{segmentPattern("api"), segmentPattern("*"), segmentPattern("users")}, + PathPattern: []string{"api", "*", "users"}, }, expectError: false, }, @@ -963,7 +963,7 @@ func TestParseAllowRule(t *testing.T) { input: "method=*", expectedRule: Rule{ Raw: "method=*", - MethodPatterns: map[methodPattern]struct{}{methodPattern("*"): {}}, + MethodPatterns: map[string]struct{}{"*": {}}, }, expectError: false, },