Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions parser_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ func WithExpirationRequired() ParserOption {
}
}

// WithNotBeforeRequired returns the ParserOption to make nbf claim required.
// By default nbf claim is optional.
func WithNotBeforeRequired() ParserOption {
return func(p *Parser) {
p.validator.requireNbf = true
}
}

// WithAudience configures the validator to require any of the specified
// audiences in the `aud` claim. Validation will fail if the audience is not
// listed in the token or the `aud` claim is missing.
Expand Down
8 changes: 6 additions & 2 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type Validator struct {
// requireExp specifies whether the exp claim is required
requireExp bool

// requireExp specifies whether the exp claim is required
requireNbf bool

// verifyIat specifies whether the iat (Issued At) claim will be verified.
// According to https://www.rfc-editor.org/rfc/rfc7519#section-4.1.6 this
// only specifies the age of the token, but no validation check is
Expand Down Expand Up @@ -111,8 +114,9 @@ func (v *Validator) Validate(claims Claims) error {
}

// We always need to check not-before, but usage of the claim itself is
// OPTIONAL.
if err = v.verifyNotBefore(claims, now, false); err != nil {
// OPTIONAL by default. requireNbf overrides this behavior and makes
// the nbf claim mandatory.
if err = v.verifyNotBefore(claims, now, v.requireNbf); err != nil {
errs = append(errs, err)
}

Expand Down
220 changes: 199 additions & 21 deletions validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,49 +89,205 @@ func Test_Validator_Validate(t *testing.T) {
}

func Test_Validator_verifyExpiresAt(t *testing.T) {
times, err := test_Validator_CreateStaticTimes(t)
if err != nil {
t.Fatal(err)
}

type fields struct {
leeway time.Duration
timeFunc func() time.Time
leeway time.Duration
}

type args struct {
claims Claims
cmp time.Time
required bool
}

tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "good claim",
fields: fields{timeFunc: time.Now},
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(time.Now().Add(10 * time.Minute))}},
name: "required claim present and valid",
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.AfterNow)}, required: true},
wantErr: nil,
},
{
name: "claims with invalid type",
fields: fields{},
args: args{claims: MapClaims{"exp": "string"}},
name: "required claim present and expired",
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.BeforeNow)}, required: true},
wantErr: ErrTokenExpired,
},
{
name: "required claim present and expired with leeway",
fields: fields{leeway: time.Hour * 2},
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.BeforeNow)}, required: true},
wantErr: nil,
},
{
name: "required claim present and expired past leeway",
fields: fields{leeway: time.Minute * 1},
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.BeforeNow)}, required: true},
wantErr: ErrTokenExpired,
},
{
name: "required claim not provided",
args: args{claims: RegisteredClaims{}, required: true},
wantErr: ErrTokenRequiredClaimMissing,
},
{
name: "required claim present with invalid type",
args: args{claims: MapClaims{"exp": "string"}, required: true},
wantErr: ErrInvalidType,
},

{
name: "not required claim present and valid",
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.AfterNow)}, required: false},
wantErr: nil,
},
{
name: "not required claim present and expired",
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.BeforeNow)}, required: false},
wantErr: ErrTokenExpired,
},
{
name: "not required claim present and expired with leeway",
fields: fields{leeway: time.Hour * 2},
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.BeforeNow)}, required: false},
wantErr: nil,
},
{
name: "not required claim present and expired past leeway",
fields: fields{leeway: time.Minute * 1},
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(times.BeforeNow)}, required: false},
wantErr: ErrTokenExpired,
},
{
name: "not required claim not provided",
args: args{claims: RegisteredClaims{}, required: false},
wantErr: nil,
},
{
name: "not required claim present with invalid type",
args: args{claims: MapClaims{"exp": "string"}, required: false},
wantErr: ErrInvalidType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &Validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
leeway: tt.fields.leeway,
}

err := v.verifyExpiresAt(tt.args.claims, tt.args.cmp, tt.args.required)
if (err != nil) && !errors.Is(err, tt.wantErr) {
err := v.verifyExpiresAt(tt.args.claims, times.Now, tt.args.required)
if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyExpiresAt() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func Test_Validator_verifyNotBefore(t *testing.T) {
times, err := test_Validator_CreateStaticTimes(t)
if err != nil {
t.Fatal(err)
}

type fields struct {
leeway time.Duration
}
type args struct {
claims Claims
required bool
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "required claim present and valid",
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.BeforeNow)}, required: true},
wantErr: nil,
},
{
name: "required claim present and in future",
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.AfterNow)}, required: true},
wantErr: ErrTokenNotValidYet,
},
{
name: "required claim present and in future with leeway",
fields: fields{leeway: time.Hour * 2},
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.AfterNow)}, required: true},
wantErr: nil,
},
{
name: "required claim present and in future past leeway",
fields: fields{leeway: time.Minute * 1},
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.AfterNow)}, required: true},
wantErr: ErrTokenNotValidYet,
},
{
name: "required claim not provided",
args: args{claims: RegisteredClaims{}, required: true},
wantErr: ErrTokenRequiredClaimMissing,
},
{
name: "required claim present with invalid type",
args: args{claims: MapClaims{"nbf": "string"}, required: true},
wantErr: ErrInvalidType,
},

{
name: "not required claim present and valid",
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.BeforeNow)}, required: false},
wantErr: nil,
},
{
name: "not required claim present and in future",
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.AfterNow)}, required: false},
wantErr: ErrTokenNotValidYet,
},
{
name: "not required claim present and in future with leeway",
fields: fields{leeway: time.Hour * 2},
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.AfterNow)}, required: false},
wantErr: nil,
},
{
name: "not required claim present and in future past leeway",
fields: fields{leeway: time.Minute * 1},
args: args{claims: RegisteredClaims{NotBefore: NewNumericDate(times.AfterNow)}, required: false},
wantErr: ErrTokenNotValidYet,
},
{
name: "not required claim not provided",
args: args{claims: RegisteredClaims{}, required: false},
wantErr: nil,
},
{
name: "not required claim present with invalid type",
args: args{claims: MapClaims{"nbf": "string"}, required: false},
wantErr: ErrInvalidType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &Validator{
leeway: tt.fields.leeway,
}

err := v.verifyNotBefore(tt.args.claims, times.Now, tt.args.required)
if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyNotBefore() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func Test_Validator_verifyIssuer(t *testing.T) {
type fields struct {
expectedIss string
Expand Down Expand Up @@ -216,15 +372,15 @@ func Test_Validator_verifySubject(t *testing.T) {

func Test_Validator_verifyIssuedAt(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
verifyIat bool
}

type args struct {
claims Claims
cmp time.Time
required bool
}

tests := []struct {
name string
fields fields
Expand All @@ -250,12 +406,8 @@ func Test_Validator_verifyIssuedAt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &Validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
verifyIat: tt.fields.verifyIat,
}
if err := v.verifyIssuedAt(tt.args.claims, tt.args.cmp, tt.args.required); (err != nil) && !errors.Is(err, tt.wantErr) {
v := &Validator{}
if err := v.verifyIssuedAt(tt.args.claims, tt.args.cmp, tt.args.required); (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyIssuedAt() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -374,3 +526,29 @@ func Test_Validator_verifyAudience(t *testing.T) {
})
}
}

// testStaticTimes is a struct that contains 3 timestamps, that are intended to be used to validate functionality
// of registered claim validation.
type testStaticTimes struct {
BeforeNow time.Time
Now time.Time
AfterNow time.Time
}

// test_Validator_CreateStaticTimes returns a set of timestamps that can be used to validate functionality
// without requiring the use of "time.Now()", which can cause "flakey" tests if there is a delay in when the tests
// run vs when they were started.
func test_Validator_CreateStaticTimes(t *testing.T) (testStaticTimes, error) {
t.Helper()

staticNow, err := time.Parse(time.RFC3339, "2025-01-02T15:04:05Z")
if err != nil {
return testStaticTimes{}, err
}

return testStaticTimes{
BeforeNow: staticNow.Add(time.Hour * -1),
Now: staticNow,
AfterNow: staticNow.Add(time.Hour * 1),
}, nil
}