diff --git a/discord.go b/discord.go index 1404a83..776f143 100644 --- a/discord.go +++ b/discord.go @@ -42,6 +42,7 @@ func New(args ...interface{}) (s *Session, err error) { ShouldReconnectOnError: true, ShardID: 0, ShardCount: 1, + MaxRestRetries: 3, } // If no arguments are passed return the empty Session interface. diff --git a/restapi.go b/restapi.go index d4c612c..26070f2 100644 --- a/restapi.go +++ b/restapi.go @@ -45,11 +45,13 @@ func (s *Session) Request(method, urlStr string, data interface{}) (response []b } } - return s.request(method, urlStr, "application/json", body) + return s.request(method, urlStr, "application/json", body, 0) } // request makes a (GET/POST/...) Requests to Discord REST API. -func (s *Session) request(method, urlStr, contentType string, b []byte) (response []byte, err error) { +// Sequence is the sequence number, if it fails with a 502 it will +// retry with sequence+1 until it either succeeds or sequence >= session.MaxRestRetries +func (s *Session) request(method, urlStr, contentType string, b []byte, sequence int) (response []byte, err error) { // rate limit mutex for this url // TODO: review for performance improvements @@ -135,6 +137,16 @@ func (s *Session) request(method, urlStr, contentType string, b []byte) (respons // TODO check for 401 response, invalidate token if we get one. + case http.StatusBadGateway: + // Retry sending request if possible + if sequence < s.MaxRestRetries { + + s.log(LogInformational, "%s Failed (%s), Retrying...", urlStr, resp.Status) + response, err = s.request(method, urlStr, contentType, b, sequence+1) + } else { + err = fmt.Errorf("Exceeded Max retries HTTP %s, %s", resp.Status, response) + } + case 429: // TOO MANY REQUESTS - Rate limiting mu.Lock() // lock URL ratelimit mutex @@ -151,10 +163,10 @@ func (s *Session) request(method, urlStr, contentType string, b []byte) (respons time.Sleep(rl.RetryAfter * time.Millisecond) // we can make the above smarter - // this method can cause longer delays then required + // this method can cause longer delays than required mu.Unlock() // we have to unlock here - response, err = s.request(method, urlStr, contentType, b) + response, err = s.request(method, urlStr, contentType, b, sequence) default: // Error condition err = fmt.Errorf("HTTP %s, %s", resp.Status, response) @@ -718,7 +730,7 @@ func (s *Session) GuildMemberNickname(guildID, userID, nickname string) (err err // guildID : The ID of a Guild. func (s *Session) GuildChannels(guildID string) (st []*Channel, err error) { - body, err := s.request("GET", EndpointGuildChannels(guildID), "", nil) + body, err := s.request("GET", EndpointGuildChannels(guildID), "", nil, 0) if err != nil { return } @@ -1096,7 +1108,7 @@ func (s *Session) ChannelMessage(channelID, messageID string) (st *Message, err // messageID : the ID of a Message func (s *Session) ChannelMessageAck(channelID, messageID string) (err error) { - _, err = s.request("POST", EndpointChannelMessageAck(channelID, messageID), "", nil) + _, err = s.request("POST", EndpointChannelMessageAck(channelID, messageID), "", nil, 0) return } @@ -1264,7 +1276,7 @@ func (s *Session) ChannelFileSendWithMessage(channelID, content string, name str return } - response, err := s.request("POST", EndpointChannelMessages(channelID), bodywriter.FormDataContentType(), body.Bytes()) + response, err := s.request("POST", EndpointChannelMessages(channelID), bodywriter.FormDataContentType(), body.Bytes(), 0) if err != nil { return } diff --git a/structs.go b/structs.go index be554b7..c375fef 100644 --- a/structs.go +++ b/structs.go @@ -53,6 +53,9 @@ type Session struct { // Whether the Data Websocket is ready DataReady bool // NOTE: Maye be deprecated soon + // Max number of REST API retries + MaxRestRetries int + // Status stores the currect status of the websocket connection // this is being tested, may stay, may go away. status int32