diff --git a/restapi.go b/restapi.go index b7d323e..fe4d2fe 100644 --- a/restapi.go +++ b/restapi.go @@ -223,6 +223,7 @@ func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b for _, opt := range options { opt(cfg) } + req = cfg.Request if s.Debug { for k, v := range req.Header { diff --git a/restapi_test.go b/restapi_test.go index 2e23a58..56da575 100644 --- a/restapi_test.go +++ b/restapi_test.go @@ -1,7 +1,9 @@ package discordgo import ( + "context" "errors" + "net/http" "testing" ) @@ -235,3 +237,41 @@ func Test_unmarshal(t *testing.T) { t.Errorf("Unexpected error type: %T", err) } } + +func TestWithContext(t *testing.T) { + // Set up a test context. + type key struct{} + ctx := context.WithValue(context.Background(), key{}, "value") + + // Set up a test client. + session, err := New("") + if err != nil { + t.Fatal(err) + } + + testErr := errors.New("test") + + // Intercept the request to assert the context. + session.Client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + val, _ := r.Context().Value(key{}).(string) + if val != "value" { + t.Errorf("missing value in context (got %q, wanted %q)", val, "value") + } + return nil, testErr + }) + + // Run any client method using WithContext. + _, err = session.User("", WithContext(ctx)) + + // Verify that the assertion code was actually run. + if !errors.Is(err, testErr) { + t.Errorf("unexpected error %v returned from client", err) + } +} + +// roundTripperFunc implements http.RoundTripper. +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}