diff --git a/dmfr/feed_fetch.go b/dmfr/feed_fetch.go index 80dd9965..9b705fe9 100644 --- a/dmfr/feed_fetch.go +++ b/dmfr/feed_fetch.go @@ -12,6 +12,7 @@ type FeedFetch struct { ResponseSize tl.Int ResponseCode tl.Int ResponseSHA1 tl.String + ResponseETag tl.String FeedVersionID tl.Int // optional field, don't use tl.FeedVersionEntity tl.Timestamps tl.DatabaseEntity diff --git a/tl/request/http.go b/tl/request/http.go index 5a1c0024..1fdc616d 100644 --- a/tl/request/http.go +++ b/tl/request/http.go @@ -15,14 +15,19 @@ import ( type Http struct{} func (r Http) Download(ctx context.Context, ustr string, secret tl.Secret, auth tl.FeedAuthorization) (io.ReadCloser, int, error) { + rc, code, _, err := r.ETagDownload(ctx, ustr, "", secret, auth) + return rc, code, err +} + +func (r Http) ETagDownload(ctx context.Context, ustr string, etag string, secret tl.Secret, auth tl.FeedAuthorization) (io.ReadCloser, int, string, error) { u, err := url.Parse(ustr) if err != nil { - return nil, 0, errors.New("could not parse url") + return nil, 0, "", errors.New("could not parse url") } if auth.Type == "query_param" { v, err := url.ParseQuery(u.RawQuery) if err != nil { - return nil, 0, errors.New("could not parse query string") + return nil, 0, "", errors.New("could not parse query string") } v.Set(auth.ParamName, secret.Key) u.RawQuery = v.Encode() @@ -31,7 +36,7 @@ func (r Http) Download(ctx context.Context, ustr string, secret tl.Secret, auth } else if auth.Type == "replace_url" { u, err = url.Parse(secret.ReplaceUrl) if err != nil { - return nil, 0, errors.New("could not parse replacement query string") + return nil, 0, "", errors.New("could not parse replacement query string") } } ustr = u.String() @@ -39,7 +44,7 @@ func (r Http) Download(ctx context.Context, ustr string, secret tl.Secret, auth // Prepare HTTP request req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil) if err != nil { - return nil, 0, errors.New("invalid request") + return nil, 0, "", errors.New("invalid request") } // Set basic auth, if used @@ -56,11 +61,11 @@ func (r Http) Download(ctx context.Context, ustr string, secret tl.Secret, auth resp, err := client.Do(req) if err != nil { // return error directly - return nil, 0, err + return nil, 0, "", err } if resp.StatusCode >= 400 { resp.Body.Close() - return nil, resp.StatusCode, fmt.Errorf("response status code: %d", resp.StatusCode) + return nil, resp.StatusCode, "", fmt.Errorf("response status code: %d", resp.StatusCode) } - return resp.Body, resp.StatusCode, nil + return resp.Body, resp.StatusCode, resp.Header.Get("ETag"), nil } diff --git a/tl/request/request.go b/tl/request/request.go index 121322a2..df23d491 100644 --- a/tl/request/request.go +++ b/tl/request/request.go @@ -21,6 +21,10 @@ type Downloader interface { Download(context.Context, string, tl.Secret, tl.FeedAuthorization) (io.ReadCloser, int, error) } +type ETagDownloader interface { + ETagDownload(context.Context, string, string, tl.Secret, tl.FeedAuthorization) (io.ReadCloser, int, string, error) +} + type Uploader interface { Upload(context.Context, string, tl.Secret, io.Reader) error } @@ -29,11 +33,22 @@ type Presigner interface { CreateSignedUrl(context.Context, string, tl.Secret) (string, error) } +type FetchResponse struct { + Filename string + Data []byte + ResponseSize int + ResponseCode int + ResponseSHA1 string + ResponseETag string + FetchError error +} + type Request struct { URL string AllowFTP bool AllowLocal bool AllowS3 bool + CheckETag string MaxSize uint64 Secret tl.Secret Auth tl.FeedAuthorization @@ -126,15 +141,6 @@ func WithAuth(secret tl.Secret, auth tl.FeedAuthorization) func(req *Request) { } } -type FetchResponse struct { - Filename string - Data []byte - ResponseSize int - ResponseCode int - ResponseSHA1 string - FetchError error -} - // AuthenticatedRequestDownload is similar to AuthenticatedRequest but writes to a temporary file. // Fatal errors will be returned as the error; non-fatal errors as FetchResponse.FetchError func AuthenticatedRequestDownload(address string, opts ...RequestOption) (FetchResponse, error) { @@ -182,7 +188,7 @@ func authenticatedRequest(out io.Writer, address string, opts ...RequestOption) fr := FetchResponse{} req := NewRequest(address, opts...) var r io.ReadCloser - r, fr.ResponseCode, fr.FetchError = req.Request(ctx) + r, fr.ResponseCode, fr.ResponseETag, fr.FetchError = req.Request(ctx) if fr.FetchError != nil { return fr, nil } diff --git a/tl/request/request_test.go b/tl/request/request_test.go index 35977370..c7278db2 100644 --- a/tl/request/request_test.go +++ b/tl/request/request_test.go @@ -1,6 +1,8 @@ package request import ( + "crypto/md5" + "encoding/hex" "encoding/json" "net/http" "net/http/httptest" @@ -25,6 +27,10 @@ func TestAuthorizedRequest(t *testing.T) { http.Error(w, err.Error(), 400) return } + h := md5.New() // So it's distinct from checksha1 + h.Write(a) + etag := hex.EncodeToString(h.Sum(nil)) + w.Header().Add("ETag", etag) w.Header().Add("Status-Code", "200") w.Write(a) })) @@ -38,6 +44,7 @@ func TestAuthorizedRequest(t *testing.T) { checksize int checkcode int checksha1 string + checketag string expectError bool secret tl.Secret }{ @@ -50,6 +57,7 @@ func TestAuthorizedRequest(t *testing.T) { checksize: 29, checkcode: 200, checksha1: "66621b979e91314ea163d94be8e7486bdcfe07c9", + checketag: "f72666cd1f7a71508c6e81ac007c90bb", }, { name: "query_param", @@ -70,7 +78,8 @@ func TestAuthorizedRequest(t *testing.T) { checkvalue: "/anything/abcd/ok", checksize: 0, checkcode: 200, - checksha1: "", + checksha1: "22c2326192e08c5659ce99dcee454c86ed3fa09e", + checketag: "b820d771af0771269aae463658d4992a", secret: tl.Secret{Key: "abcd"}, }, { @@ -145,6 +154,9 @@ func TestAuthorizedRequest(t *testing.T) { if tc.checksha1 != "" { assert.Equal(t, tc.checksha1, fr.ResponseSHA1, "did not match expected sha1") } + if tc.checketag != "" { + assert.Equal(t, tc.checketag, fr.ResponseETag, "did not match expected etag") + } if tc.checkkey != "" { a, ok := result[tc.checkkey].(string) if !ok {