Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ go get nhooyr.io/websocket

For a production quality example that shows off the full API, see the [echo example on the godoc](https://godoc.org/nhooyr.io/websocket#example-package--Echo). On github, the example is at [example_echo_test.go](./example_echo_test.go).

Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). See the [CloseError godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseError).
Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError).
There is also [websocket.CloseStatus](https://godoc.org/nhooyr.io/websocket#CloseStatus) to quickly grab the close status code out of a [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError).
See the [CloseError godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseError).

### Server

Expand Down
68 changes: 6 additions & 62 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,19 @@ package websocket_test

import (
"context"
"fmt"
"math/rand"
"reflect"
"strings"
"time"

"github.com/google/go-cmp/cmp"

"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/assert"
"nhooyr.io/websocket/wsjson"
)

func init() {
rand.Seed(time.Now().UnixNano())
}

// https://github.com/google/go-cmp/issues/40#issuecomment-328615283
func cmpDiff(exp, act interface{}) string {
return cmp.Diff(exp, act, deepAllowUnexported(exp, act))
}

func deepAllowUnexported(vs ...interface{}) cmp.Option {
m := make(map[reflect.Type]struct{})
for _, v := range vs {
structTypes(reflect.ValueOf(v), m)
}
var typs []interface{}
for t := range m {
typs = append(typs, reflect.New(t).Elem().Interface())
}
return cmp.AllowUnexported(typs...)
}

func structTypes(v reflect.Value, m map[reflect.Type]struct{}) {
if !v.IsValid() {
return
}
switch v.Kind() {
case reflect.Ptr:
if !v.IsNil() {
structTypes(v.Elem(), m)
}
case reflect.Interface:
if !v.IsNil() {
structTypes(v.Elem(), m)
}
case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
structTypes(v.Index(i), m)
}
case reflect.Map:
for _, k := range v.MapKeys() {
structTypes(v.MapIndex(k), m)
}
case reflect.Struct:
m[v.Type()] = struct{}{}
for i := 0; i < v.NumField(); i++ {
structTypes(v.Field(i), m)
}
}
}

func assertEqualf(exp, act interface{}, f string, v ...interface{}) error {
if diff := cmpDiff(exp, act); diff != "" {
return fmt.Errorf(f+": %v", append(v, diff)...)
}
return nil
}

func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error {
exp := randString(n)
err := wsjson.Write(ctx, c, exp)
Expand All @@ -84,7 +28,7 @@ func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error {
return err
}

return assertEqualf(exp, act, "unexpected JSON")
return assert.Equalf(exp, act, "unexpected JSON")
}

func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
Expand All @@ -94,7 +38,7 @@ func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) err
return err
}

return assertEqualf(exp, act, "unexpected JSON")
return assert.Equalf(exp, act, "unexpected JSON")
}

func randBytes(n int) []byte {
Expand Down Expand Up @@ -126,13 +70,13 @@ func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageTyp
if err != nil {
return err
}
err = assertEqualf(typ, typ2, "unexpected data type")
err = assert.Equalf(typ, typ2, "unexpected data type")
if err != nil {
return err
}
return assertEqualf(p, p2, "unexpected payload")
return assert.Equalf(p, p2, "unexpected payload")
}

func assertSubprotocol(c *websocket.Conn, exp string) error {
return assertEqualf(exp, c.Subprotocol(), "unexpected subprotocol")
return assert.Equalf(exp, c.Subprotocol(), "unexpected subprotocol")
}
43 changes: 22 additions & 21 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"go.uber.org/multierr"

"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/assert"
"nhooyr.io/websocket/internal/wsecho"
"nhooyr.io/websocket/wsjson"
"nhooyr.io/websocket/wspb"
Expand Down Expand Up @@ -127,7 +128,7 @@ func TestHandshake(t *testing.T) {
if err != nil {
return fmt.Errorf("request is missing mycookie: %w", err)
}
err = assertEqualf("myvalue", cookie.Value, "unexpected cookie value")
err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value")
if err != nil {
return err
}
Expand Down Expand Up @@ -219,7 +220,7 @@ func TestConn(t *testing.T) {
}
for h, exp := range headers {
value := resp.Header.Get(h)
err := assertEqualf(exp, value, "unexpected value for header %v", h)
err := assert.Equalf(exp, value, "unexpected value for header %v", h)
if err != nil {
return err
}
Expand Down Expand Up @@ -276,11 +277,11 @@ func TestConn(t *testing.T) {
time.Sleep(1)
nc.SetWriteDeadline(time.Now().Add(time.Second * 15))

err := assertEqualf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr")
err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr")
if err != nil {
return err
}
err = assertEqualf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr")
err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr")
if err != nil {
return err
}
Expand Down Expand Up @@ -310,13 +311,13 @@ func TestConn(t *testing.T) {

// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
err2 := assertNetConnRead(nc, "hello")
err := assertEqualf(io.EOF, err2, "unexpected error")
err := assert.Equalf(io.EOF, err2, "unexpected error")
if err != nil {
return err
}

err2 = assertNetConnRead(nc, "hello")
return assertEqualf(io.EOF, err2, "unexpected error")
return assert.Equalf(io.EOF, err2, "unexpected error")
},
},
{
Expand Down Expand Up @@ -585,8 +586,8 @@ func TestConn(t *testing.T) {
return err
}
_, _, err = c.Read(ctx)
cerr := &websocket.CloseError{}
if !errors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError {
var cerr websocket.CloseError
if !errors.As(err, &cerr) || cerr.Code != websocket.StatusProtocolError {
return fmt.Errorf("expected close error with StatusProtocolError: %+v", err)
}
return nil
Expand Down Expand Up @@ -772,15 +773,15 @@ func TestConn(t *testing.T) {
if err != nil {
return err
}
err = assertEqualf("hi", v, "unexpected JSON")
err = assert.Equalf("hi", v, "unexpected JSON")
if err != nil {
return err
}
_, b, err := c.Read(ctx)
if err != nil {
return err
}
return assertEqualf("hi", string(b), "unexpected JSON")
return assert.Equalf("hi", string(b), "unexpected JSON")
},
client: func(ctx context.Context, c *websocket.Conn) error {
err := wsjson.Write(ctx, c, "hi")
Expand Down Expand Up @@ -1079,11 +1080,11 @@ func TestAutobahn(t *testing.T) {
if err != nil {
return err
}
err = assertEqualf(typ, actTyp, "unexpected message type")
err = assert.Equalf(typ, actTyp, "unexpected message type")
if err != nil {
return err
}
return assertEqualf(p, p2, "unexpected message")
return assert.Equalf(p, p2, "unexpected message")
})
}
}
Expand Down Expand Up @@ -1859,7 +1860,7 @@ func assertCloseStatus(err error, code websocket.StatusCode) error {
if !errors.As(err, &cerr) {
return fmt.Errorf("no websocket close error in error chain: %+v", err)
}
return assertEqualf(code, cerr.Code, "unexpected status code")
return assert.Equalf(code, cerr.Code, "unexpected status code")
}

func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
Expand All @@ -1871,7 +1872,7 @@ func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{})
return err
}

return assertEqualf(exp, act, "unexpected protobuf")
return assert.Equalf(exp, act, "unexpected protobuf")
}

func assertNetConnRead(r io.Reader, exp string) error {
Expand All @@ -1880,7 +1881,7 @@ func assertNetConnRead(r io.Reader, exp string) error {
if err != nil {
return err
}
return assertEqualf(exp, string(act), "unexpected net conn read")
return assert.Equalf(exp, string(act), "unexpected net conn read")
}

func assertErrorContains(err error, exp string) error {
Expand All @@ -1902,27 +1903,27 @@ func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.Op
if err != nil {
return err
}
err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
if err != nil {
return err
}
return assertEqualf(p, actP, "unexpected frame %v payload", opcode)
return assert.Equalf(p, actP, "unexpected frame %v payload", opcode)
}

func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error {
actOpcode, actP, err := c.ReadFrame(ctx)
if err != nil {
return err
}
err = assertEqualf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
if err != nil {
return err
}
ce, err := websocket.ParseClosePayload(actP)
if err != nil {
return fmt.Errorf("failed to parse close frame payload: %w", err)
}
return assertEqualf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
}

func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket.StatusCode, reason string) error {
Expand Down Expand Up @@ -1960,11 +1961,11 @@ func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.Mes
if err != nil {
return err
}
err = assertEqualf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
if err != nil {
return err
}
return assertEqualf(p, actP, "unexpected frame %v payload", actTyp)
return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp)
}

func BenchmarkConn(b *testing.B) {
Expand Down
1 change: 1 addition & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// comparison with existing implementations.
//
// Use the errors.As function new in Go 1.13 to check for websocket.CloseError.
// Or use the CloseStatus function to grab the StatusCode out of a websocket.CloseError
// See the CloseError example.
//
// Wasm
Expand Down
4 changes: 1 addition & 3 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package websocket_test

import (
"context"
"errors"
"log"
"net/http"
"time"
Expand Down Expand Up @@ -76,8 +75,7 @@ func ExampleCloseError() {
defer c.Close(websocket.StatusInternalError, "the sky is falling")

_, _, err = c.Reader(ctx)
var cerr websocket.CloseError
if !errors.As(err, &cerr) || cerr.Code != websocket.StatusNormalClosure {
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %+v", err)
return
}
Expand Down
12 changes: 12 additions & 0 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package websocket

import (
"encoding/binary"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -252,6 +253,17 @@ func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}

// CloseStatus is a convenience wrapper around errors.As to grab
// the status code from a *CloseError. If the passed error is nil
// or not a *CloseError, the returned StatusCode will be -1.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}

Copy link
Contributor Author

@nhooyr nhooyr Oct 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main addition in this PR, the rest is just some refactoring to allow the use of assert.Equalf in the tests for this function.

func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Expand Down
42 changes: 42 additions & 0 deletions frame_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"time"

"github.com/google/go-cmp/cmp"

"nhooyr.io/websocket/internal/assert"
)

func init() {
Expand Down Expand Up @@ -376,3 +378,43 @@ func BenchmarkXOR(b *testing.B) {
})
}
}

func TestCloseStatus(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
in error
exp StatusCode
}{
{
name: "nil",
in: nil,
exp: -1,
},
{
name: "io.EOF",
in: io.EOF,
exp: -1,
},
{
name: "StatusInternalError",
in: CloseError{
Code: StatusInternalError,
},
exp: StatusInternalError,
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

err := assert.Equalf(tc.exp, CloseStatus(tc.in), "unexpected close status")
if err != nil {
t.Fatal(err)
}
})
}
}
Loading