diff --git a/discord.go b/discord.go index 04d4719..6c098fd 100644 --- a/discord.go +++ b/discord.go @@ -50,7 +50,7 @@ func New(args ...interface{}) (s *Session, err error) { // Create an empty Session interface. s = &Session{ State: NewState(), - ratelimiter: NewRatelimiter(), + Ratelimiter: NewRatelimiter(), StateEnabled: true, Compress: true, ShouldReconnectOnError: true, diff --git a/ratelimit.go b/ratelimit.go index 223c0d0..c3d553c 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -41,8 +41,8 @@ func NewRatelimiter() *RateLimiter { } } -// getBucket retrieves or creates a bucket -func (r *RateLimiter) getBucket(key string) *Bucket { +// GetBucket retrieves or creates a bucket +func (r *RateLimiter) GetBucket(key string) *Bucket { r.Lock() defer r.Unlock() @@ -51,7 +51,7 @@ func (r *RateLimiter) getBucket(key string) *Bucket { } b := &Bucket{ - remaining: 1, + Remaining: 1, Key: key, global: r.global, } @@ -67,28 +67,36 @@ func (r *RateLimiter) getBucket(key string) *Bucket { r.buckets[key] = b return b } - -// LockBucket Locks until a request can be made -func (r *RateLimiter) LockBucket(bucketID string) *Bucket { - - b := r.getBucket(bucketID) - - b.Lock() - +func (r *RateLimiter) GetWaitTime(b *Bucket, minRemaining int) time.Duration { // If we ran out of calls and the reset time is still ahead of us // then we need to take it easy and relax a little - if b.remaining < 1 && b.reset.After(time.Now()) { - time.Sleep(b.reset.Sub(time.Now())) - + if b.Remaining < minRemaining && b.reset.After(time.Now()) { + return b.reset.Sub(time.Now()) } // Check for global ratelimits sleepTo := time.Unix(0, atomic.LoadInt64(r.global)) if now := time.Now(); now.Before(sleepTo) { - time.Sleep(sleepTo.Sub(now)) + return sleepTo.Sub(now) } - b.remaining-- + return 0 +} + +// LockBucket Locks until a request can be made +func (r *RateLimiter) LockBucket(bucketID string) *Bucket { + return r.LockBucketObject(r.GetBucket(bucketID)) +} + +// LockBucketObject Locks an already resolved bucket until a request can be made +func (r *RateLimiter) LockBucketObject(b *Bucket) *Bucket { + b.Lock() + + if wait := r.GetWaitTime(b, 1); wait > 0 { + time.Sleep(wait) + } + + b.Remaining-- return b } @@ -96,13 +104,14 @@ func (r *RateLimiter) LockBucket(bucketID string) *Bucket { type Bucket struct { sync.Mutex Key string - remaining int + Remaining int limit int reset time.Time global *int64 lastReset time.Time customRateLimit *customRateLimit + Userdata interface{} } // Release unlocks the bucket and reads the headers to update the buckets ratelimit info @@ -113,10 +122,10 @@ func (b *Bucket) Release(headers http.Header) error { // Check if the bucket uses a custom ratelimiter if rl := b.customRateLimit; rl != nil { if time.Now().Sub(b.lastReset) >= rl.reset { - b.remaining = rl.requests - 1 + b.Remaining = rl.requests - 1 b.lastReset = time.Now() } - if b.remaining < 1 { + if b.Remaining < 1 { b.reset = time.Now().Add(rl.reset) } return nil @@ -176,7 +185,7 @@ func (b *Bucket) Release(headers http.Header) error { if err != nil { return err } - b.remaining = int(parsedRemaining) + b.Remaining = int(parsedRemaining) } return nil diff --git a/restapi.go b/restapi.go index 1a682fa..9ba59c7 100644 --- a/restapi.go +++ b/restapi.go @@ -65,9 +65,11 @@ func (s *Session) request(method, urlStr, contentType string, b []byte, bucketID if bucketID == "" { bucketID = strings.SplitN(urlStr, "?", 2)[0] } + return s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucket(bucketID), sequence) +} - bucket := s.ratelimiter.LockBucket(bucketID) - +// RequestWithLockedBucket makes a request using a bucket that's already been locked +func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b []byte, bucket *Bucket, sequence int) (response []byte, err error) { if s.Debug { log.Printf("API REQUEST %8s :: %s\n", method, urlStr) log.Printf("API REQUEST PAYLOAD :: [%s]\n", string(b)) @@ -139,7 +141,7 @@ func (s *Session) request(method, urlStr, contentType string, b []byte, bucketID if sequence < s.MaxRestRetries { s.log(LogInformational, "%s Failed (%s), Retrying...", urlStr, resp.Status) - response, err = s.request(method, urlStr, contentType, b, bucketID, sequence+1) + response, err = s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucketObject(bucket), sequence+1) } else { err = fmt.Errorf("Exceeded Max retries HTTP %s, %s", resp.Status, response) } @@ -158,7 +160,7 @@ func (s *Session) request(method, urlStr, contentType string, b []byte, bucketID // we can make the above smarter // this method can cause longer delays than required - response, err = s.request(method, urlStr, contentType, b, bucketID, sequence) + response, err = s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucketObject(bucket), sequence) default: // Error condition err = newRestError(req, resp, response) diff --git a/structs.go b/structs.go index 80928a8..c5dbdae 100644 --- a/structs.go +++ b/structs.go @@ -83,6 +83,9 @@ type Session struct { // Stores the last HeartbeatAck that was recieved (in UTC) LastHeartbeatAck time.Time + // used to deal with rate limits + Ratelimiter *RateLimiter + // Event handlers handlersMu sync.RWMutex handlers map[string][]*eventHandlerInstance @@ -94,9 +97,6 @@ type Session struct { // When nil, the session is not listening. listening chan interface{} - // used to deal with rate limits - ratelimiter *RateLimiter - // sequence tracks the current gateway api websocket sequence number sequence *int64