diff --git a/.travis.yml b/.travis.yml index 440159d..92beb2e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,5 +10,5 @@ install: script: - diff <(gofmt -d .) <(echo -n) - go vet -x ./... - - golint -set_exit_status ./... + - golint ./... - go test -v -race ./... diff --git a/README.md b/README.md index ac9dbcd..eb9f14f 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,6 @@ that information in a nice format. - [![GoDoc](https://godoc.org/github.com/bwmarrin/discordgo?status.svg)](https://godoc.org/github.com/bwmarrin/discordgo) - [![Go Walker](http://gowalker.org/api/v1/badge)](https://gowalker.org/github.com/bwmarrin/discordgo) -- [Unofficial Discord API Documentation](https://discordapi.readthedocs.org/en/latest/) - Hand crafted documentation coming eventually. diff --git a/discord.go b/discord.go index ef5bf28..04d4719 100644 --- a/discord.go +++ b/discord.go @@ -20,8 +20,8 @@ import ( "time" ) -// VERSION of Discordgo, follows Symantic Versioning. (http://semver.org/) -const VERSION = "0.16.0-dev" +// VERSION of DiscordGo, follows Semantic Versioning. (http://semver.org/) +const VERSION = "0.17.0-dev" // ErrMFA will be risen by New when the user has 2FA. var ErrMFA = errors.New("account has 2FA enabled") @@ -59,6 +59,7 @@ func New(args ...interface{}) (s *Session, err error) { MaxRestRetries: 3, Client: &http.Client{Timeout: (20 * time.Second)}, sequence: new(int64), + LastHeartbeatAck: time.Now().UTC(), } // If no arguments are passed return the empty Session interface. diff --git a/discord_test.go b/discord_test.go index afac0bc..ca4472a 100644 --- a/discord_test.go +++ b/discord_test.go @@ -11,9 +11,11 @@ import ( ////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////// VARS NEEDED FOR TESTING var ( - dg *Session // Stores global discordgo session + dg *Session // Stores a global discordgo user session + dgBot *Session // Stores a global discordgo bot session - envToken = os.Getenv("DG_TOKEN") // Token to use when authenticating + envToken = os.Getenv("DG_TOKEN") // Token to use when authenticating the user account + envBotToken = os.Getenv("DGB_TOKEN") // Token to use when authenticating the bot account envEmail = os.Getenv("DG_EMAIL") // Email to use when authenticating envPassword = os.Getenv("DG_PASSWORD") // Password to use when authenticating envGuild = os.Getenv("DG_GUILD") // Guild ID to use for tests @@ -23,6 +25,12 @@ var ( ) func init() { + if envBotToken != "" { + if d, err := New(envBotToken); err == nil { + dgBot = d + } + } + if envEmail == "" || envPassword == "" || envToken == "" { return } diff --git a/endpoints.go b/endpoints.go index 96bcf28..0ecdf0b 100644 --- a/endpoints.go +++ b/endpoints.go @@ -18,13 +18,14 @@ var ( EndpointSmActive = EndpointSm + "active.json" EndpointSmUpcoming = EndpointSm + "upcoming.json" - EndpointDiscord = "https://discordapp.com/" - EndpointAPI = EndpointDiscord + "api/" - EndpointGuilds = EndpointAPI + "guilds/" - EndpointChannels = EndpointAPI + "channels/" - EndpointUsers = EndpointAPI + "users/" - EndpointGateway = EndpointAPI + "gateway" - EndpointWebhooks = EndpointAPI + "webhooks/" + EndpointDiscord = "https://discordapp.com/" + EndpointAPI = EndpointDiscord + "api/" + EndpointGuilds = EndpointAPI + "guilds/" + EndpointChannels = EndpointAPI + "channels/" + EndpointUsers = EndpointAPI + "users/" + EndpointGateway = EndpointAPI + "gateway" + EndpointGatewayBot = EndpointGateway + "/bot" + EndpointWebhooks = EndpointAPI + "webhooks/" EndpointCDN = "https://cdn.discordapp.com/" EndpointCDNAttachments = EndpointCDN + "attachments/" @@ -54,16 +55,17 @@ var ( EndpointReport = EndpointAPI + "report" EndpointIntegrations = EndpointAPI + "integrations" - EndpointUser = func(uID string) string { return EndpointUsers + uID } - EndpointUserAvatar = func(uID, aID string) string { return EndpointCDNAvatars + uID + "/" + aID + ".png" } - EndpointUserSettings = func(uID string) string { return EndpointUsers + uID + "/settings" } - EndpointUserGuilds = func(uID string) string { return EndpointUsers + uID + "/guilds" } - EndpointUserGuild = func(uID, gID string) string { return EndpointUsers + uID + "/guilds/" + gID } - EndpointUserGuildSettings = func(uID, gID string) string { return EndpointUsers + uID + "/guilds/" + gID + "/settings" } - EndpointUserChannels = func(uID string) string { return EndpointUsers + uID + "/channels" } - EndpointUserDevices = func(uID string) string { return EndpointUsers + uID + "/devices" } - EndpointUserConnections = func(uID string) string { return EndpointUsers + uID + "/connections" } - EndpointUserNotes = func(uID string) string { return EndpointUsers + "@me/notes/" + uID } + EndpointUser = func(uID string) string { return EndpointUsers + uID } + EndpointUserAvatar = func(uID, aID string) string { return EndpointCDNAvatars + uID + "/" + aID + ".png" } + EndpointUserAvatarAnimated = func(uID, aID string) string { return EndpointCDNAvatars + uID + "/" + aID + ".gif" } + EndpointUserSettings = func(uID string) string { return EndpointUsers + uID + "/settings" } + EndpointUserGuilds = func(uID string) string { return EndpointUsers + uID + "/guilds" } + EndpointUserGuild = func(uID, gID string) string { return EndpointUsers + uID + "/guilds/" + gID } + EndpointUserGuildSettings = func(uID, gID string) string { return EndpointUsers + uID + "/guilds/" + gID + "/settings" } + EndpointUserChannels = func(uID string) string { return EndpointUsers + uID + "/channels" } + EndpointUserDevices = func(uID string) string { return EndpointUsers + uID + "/devices" } + EndpointUserConnections = func(uID string) string { return EndpointUsers + uID + "/connections" } + EndpointUserNotes = func(uID string) string { return EndpointUsers + "@me/notes/" + uID } EndpointGuild = func(gID string) string { return EndpointGuilds + gID } EndpointGuildInivtes = func(gID string) string { return EndpointGuilds + gID + "/invites" } diff --git a/message.go b/message.go index b743327..d46f3f3 100644 --- a/message.go +++ b/message.go @@ -12,6 +12,7 @@ package discordgo import ( "io" "regexp" + "strings" ) // A Message stores all data related to a specific Discord message. @@ -33,8 +34,9 @@ type Message struct { // File stores info about files you e.g. send in messages. type File struct { - Name string - Reader io.Reader + Name string + ContentType string + Reader io.Reader } // MessageSend stores all parameters you can send with ChannelMessageSendComplex. @@ -42,7 +44,10 @@ type MessageSend struct { Content string `json:"content,omitempty"` Embed *MessageEmbed `json:"embed,omitempty"` Tts bool `json:"tts"` - File *File `json:"file"` + Files []*File `json:"-"` + + // TODO: Remove this when compatibility is not required. + File *File `json:"-"` } // MessageEdit is used to chain parameters via ChannelMessageEditComplex, which @@ -167,13 +172,65 @@ type MessageReactions struct { // ContentWithMentionsReplaced will replace all @ mentions with the // username of the mention. -func (m *Message) ContentWithMentionsReplaced() string { - if m.Mentions == nil { - return m.Content - } - content := m.Content +func (m *Message) ContentWithMentionsReplaced() (content string) { + content = m.Content + for _, user := range m.Mentions { - content = regexp.MustCompile("<@!?("+regexp.QuoteMeta(user.ID)+")>").ReplaceAllString(content, "@"+user.Username) + content = strings.NewReplacer( + "<@"+user.ID+">", "@"+user.Username, + "<@!"+user.ID+">", "@"+user.Username, + ).Replace(content) } - return content + return +} + +var patternChannels = regexp.MustCompile("<#[^>]*>") + +// ContentWithMoreMentionsReplaced will replace all @ mentions with the +// username of the mention, but also role IDs and more. +func (m *Message) ContentWithMoreMentionsReplaced(s *Session) (content string, err error) { + content = m.Content + + if !s.StateEnabled { + content = m.ContentWithMentionsReplaced() + return + } + + channel, err := s.State.Channel(m.ChannelID) + if err != nil { + content = m.ContentWithMentionsReplaced() + return + } + + for _, user := range m.Mentions { + nick := user.Username + + member, err := s.State.Member(channel.GuildID, user.ID) + if err == nil && member.Nick != "" { + nick = member.Nick + } + + content = strings.NewReplacer( + "<@"+user.ID+">", "@"+user.Username, + "<@!"+user.ID+">", "@"+nick, + ).Replace(content) + } + for _, roleID := range m.MentionRoles { + role, err := s.State.Role(channel.GuildID, roleID) + if err != nil || !role.Mentionable { + continue + } + + content = strings.Replace(content, "<&"+role.ID+">", "@"+role.Name, -1) + } + + content = patternChannels.ReplaceAllStringFunc(content, func(mention string) string { + channel, err := s.State.Channel(mention[2 : len(mention)-1]) + if err != nil || channel.Type == "voice" { + return mention + } + + return "#" + channel.Name + }) + return } diff --git a/message_test.go b/message_test.go new file mode 100644 index 0000000..fd2f133 --- /dev/null +++ b/message_test.go @@ -0,0 +1,41 @@ +package discordgo + +import ( + "testing" +) + +func TestContentWithMoreMentionsReplaced(t *testing.T) { + s := &Session{StateEnabled: true, State: NewState()} + + user := &User{ + ID: "user", + Username: "User Name", + } + + s.StateEnabled = true + s.State.GuildAdd(&Guild{ID: "guild"}) + s.State.RoleAdd("guild", &Role{ + ID: "role", + Name: "Role Name", + Mentionable: true, + }) + s.State.MemberAdd(&Member{ + User: user, + Nick: "User Nick", + GuildID: "guild", + }) + s.State.ChannelAdd(&Channel{ + Name: "Channel Name", + GuildID: "guild", + ID: "channel", + }) + m := &Message{ + Content: "<&role> <@!user> <@user> <#channel>", + ChannelID: "channel", + MentionRoles: []string{"role"}, + Mentions: []*User{user}, + } + if result, _ := m.ContentWithMoreMentionsReplaced(s); result != "@Role Name @User Nick @User Name #Channel Name" { + t.Error(result) + } +} diff --git a/restapi.go b/restapi.go index 7c9fd81..bd944f2 100644 --- a/restapi.go +++ b/restapi.go @@ -23,14 +23,22 @@ import ( "log" "mime/multipart" "net/http" + "net/textproto" "net/url" "strconv" "strings" "time" ) -// ErrJSONUnmarshal is returned for JSON Unmarshall errors. -var ErrJSONUnmarshal = errors.New("json unmarshal") +// All error constants +var ( + ErrJSONUnmarshal = errors.New("json unmarshal") + ErrStatusOffline = errors.New("You can't set your Status to offline") + ErrVerificationLevelBounds = errors.New("VerificationLevel out of bounds, should be between 0 and 3") + ErrPruneDaysBounds = errors.New("the number of days should be more than or equal to 1") + ErrGuildNoIcon = errors.New("guild does not have an icon set") + ErrGuildNoSplash = errors.New("guild does not have a splash set") +) // Request is the same as RequestWithBucketID but the bucket id is the same as the urlStr func (s *Session) Request(method, urlStr string, data interface{}) (response []byte, err error) { @@ -302,8 +310,8 @@ func (s *Session) UserUpdate(email, password, username, avatar, newPassword stri // If left blank, avatar will be set to null/blank data := struct { - Email string `json:"email"` - Password string `json:"password"` + Email string `json:"email,omitempty"` + Password string `json:"password,omitempty"` Username string `json:"username,omitempty"` Avatar string `json:"avatar,omitempty"` NewPassword string `json:"new_password,omitempty"` @@ -334,7 +342,7 @@ func (s *Session) UserSettings() (st *Settings, err error) { // status : The new status (Actual valid status are 'online','idle','dnd','invisible') func (s *Session) UserUpdateStatus(status Status) (st *Settings, err error) { if status == StatusOffline { - err = errors.New("You can't set your Status to offline") + err = ErrStatusOffline return } @@ -595,7 +603,7 @@ func (s *Session) GuildEdit(guildID string, g GuildParams) (st *Guild, err error if g.VerificationLevel != nil { val := *g.VerificationLevel if val < 0 || val > 3 { - err = errors.New("VerificationLevel out of bounds, should be between 0 and 3") + err = ErrVerificationLevelBounds return } } @@ -756,7 +764,21 @@ func (s *Session) GuildMember(guildID, userID string) (st *Member, err error) { // userID : The ID of a User func (s *Session) GuildMemberDelete(guildID, userID string) (err error) { - _, err = s.RequestWithBucketID("DELETE", EndpointGuildMember(guildID, userID), nil, EndpointGuildMember(guildID, "")) + return s.GuildMemberDeleteWithReason(guildID, userID, "") +} + +// GuildMemberDelete removes the given user from the given guild. +// guildID : The ID of a Guild. +// userID : The ID of a User +// reason : The reason for the kick +func (s *Session) GuildMemberDeleteWithReason(guildID, userID, reason string) (err error) { + + uri := EndpointGuildMember(guildID, userID) + if reason != "" { + uri += "?reason=" + url.QueryEscape(reason) + } + + _, err = s.RequestWithBucketID("DELETE", uri, nil, EndpointGuildMember(guildID, "")) return } @@ -988,7 +1010,7 @@ func (s *Session) GuildPruneCount(guildID string, days uint32) (count uint32, er count = 0 if days <= 0 { - err = errors.New("the number of days should be more than or equal to 1") + err = ErrPruneDaysBounds return } @@ -1018,7 +1040,7 @@ func (s *Session) GuildPrune(guildID string, days uint32) (count uint32, err err count = 0 if days <= 0 { - err = errors.New("the number of days should be more than or equal to 1") + err = ErrPruneDaysBounds return } @@ -1120,7 +1142,7 @@ func (s *Session) GuildIcon(guildID string) (img image.Image, err error) { } if g.Icon == "" { - err = errors.New("guild does not have an icon set") + err = ErrGuildNoIcon return } @@ -1142,7 +1164,7 @@ func (s *Session) GuildSplash(guildID string) (img image.Image, err error) { } if g.Splash == "" { - err = errors.New("guild does not have a splash set") + err = ErrGuildNoSplash return } @@ -1309,6 +1331,8 @@ func (s *Session) ChannelMessageSend(channelID string, content string) (*Message }) } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + // ChannelMessageSendComplex sends a message to the given channel. // channelID : The ID of a Channel. // data : The message struct to send. @@ -1319,48 +1343,62 @@ func (s *Session) ChannelMessageSendComplex(channelID string, data *MessageSend) endpoint := EndpointChannelMessages(channelID) - var response []byte + // TODO: Remove this when compatibility is not required. + files := data.Files if data.File != nil { + if files == nil { + files = []*File{data.File} + } else { + err = fmt.Errorf("cannot specify both File and Files") + return + } + } + + var response []byte + if len(files) > 0 { body := &bytes.Buffer{} bodywriter := multipart.NewWriter(body) - // What's a better way of doing this? Reflect? Generator? I'm open to suggestions - - if data.Content != "" { - if err = bodywriter.WriteField("content", data.Content); err != nil { - return - } - } - - if data.Embed != nil { - var embed []byte - embed, err = json.Marshal(data.Embed) - if err != nil { - return - } - err = bodywriter.WriteField("embed", string(embed)) - if err != nil { - return - } - } - - if data.Tts { - if err = bodywriter.WriteField("tts", "true"); err != nil { - return - } - } - - var writer io.Writer - writer, err = bodywriter.CreateFormFile("file", data.File.Name) + var payload []byte + payload, err = json.Marshal(data) if err != nil { return } - _, err = io.Copy(writer, data.File.Reader) + var p io.Writer + + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="payload_json"`) + h.Set("Content-Type", "application/json") + + p, err = bodywriter.CreatePart(h) if err != nil { return } + if _, err = p.Write(payload); err != nil { + return + } + + for i, file := range files { + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file%d"; filename="%s"`, i, quoteEscaper.Replace(file.Name))) + contentType := file.ContentType + if contentType == "" { + contentType = "application/octet-stream" + } + h.Set("Content-Type", contentType) + + p, err = bodywriter.CreatePart(h) + if err != nil { + return + } + + if _, err = io.Copy(p, file.Reader); err != nil { + return + } + } + err = bodywriter.Close() if err != nil { return @@ -1678,6 +1716,28 @@ func (s *Session) Gateway() (gateway string, err error) { return } +// GatewayBot returns the websocket Gateway address and the recommended number of shards +func (s *Session) GatewayBot() (st *GatewayBotResponse, err error) { + + response, err := s.RequestWithBucketID("GET", EndpointGatewayBot, nil, EndpointGatewayBot) + if err != nil { + return + } + + err = unmarshal(response, &st) + if err != nil { + return + } + + // Ensure the gateway always has a trailing slash. + // MacOS will fail to connect if we add query params without a trailing slash on the base domain. + if !strings.HasSuffix(st.URL, "/") { + st.URL += "/" + } + + return +} + // Functions specific to Webhooks // WebhookCreate returns a new Webhook. diff --git a/restapi_test.go b/restapi_test.go index a5d326b..7aa4e60 100644 --- a/restapi_test.go +++ b/restapi_test.go @@ -166,6 +166,17 @@ func TestGateway(t *testing.T) { } } +func TestGatewayBot(t *testing.T) { + + if dgBot == nil { + t.Skip("Skipping, dgBot not set.") + } + _, err := dgBot.GatewayBot() + if err != nil { + t.Errorf("GatewayBot() returned error: %+v", err) + } +} + func TestVoiceICE(t *testing.T) { if dg == nil { diff --git a/state.go b/state.go index c2c5519..7400ef6 100644 --- a/state.go +++ b/state.go @@ -21,6 +21,10 @@ import ( // ErrNilState is returned when the state is nil. var ErrNilState = errors.New("state not instantiated, please use discordgo.New() or assign Session.State") +// ErrStateNotFound is returned when the state cache +// requested is not found +var ErrStateNotFound = errors.New("state cache not found") + // A State contains the current known state. // As discord sends this in a READY blob, it seems reasonable to simply // use that struct as the data store. @@ -146,7 +150,7 @@ func (s *State) Guild(guildID string) (*Guild, error) { return g, nil } - return nil, errors.New("guild not found") + return nil, ErrStateNotFound } // PresenceAdd adds a presence to the current world state, or @@ -227,7 +231,7 @@ func (s *State) PresenceRemove(guildID string, presence *Presence) error { } } - return errors.New("presence not found") + return ErrStateNotFound } // Presence gets a presence by ID from a guild. @@ -247,7 +251,7 @@ func (s *State) Presence(guildID, userID string) (*Presence, error) { } } - return nil, errors.New("presence not found") + return nil, ErrStateNotFound } // TODO: Consider moving Guild state update methods onto *Guild. @@ -299,7 +303,7 @@ func (s *State) MemberRemove(member *Member) error { } } - return errors.New("member not found") + return ErrStateNotFound } // Member gets a member by ID from a guild. @@ -322,7 +326,7 @@ func (s *State) Member(guildID, userID string) (*Member, error) { } } - return nil, errors.New("member not found") + return nil, ErrStateNotFound } // RoleAdd adds a role to the current world state, or @@ -372,7 +376,7 @@ func (s *State) RoleRemove(guildID, roleID string) error { } } - return errors.New("role not found") + return ErrStateNotFound } // Role gets a role by ID from a guild. @@ -395,7 +399,7 @@ func (s *State) Role(guildID, roleID string) (*Role, error) { } } - return nil, errors.New("role not found") + return nil, ErrStateNotFound } // ChannelAdd adds a channel to the current world state, or @@ -428,7 +432,7 @@ func (s *State) ChannelAdd(channel *Channel) error { } else { guild, ok := s.guildMap[channel.GuildID] if !ok { - return errors.New("guild for channel not found") + return ErrStateNotFound } guild.Channels = append(guild.Channels, channel) @@ -507,7 +511,7 @@ func (s *State) Channel(channelID string) (*Channel, error) { return c, nil } - return nil, errors.New("channel not found") + return nil, ErrStateNotFound } // Emoji returns an emoji for a guild and emoji id. @@ -530,7 +534,7 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) { } } - return nil, errors.New("emoji not found") + return nil, ErrStateNotFound } // EmojiAdd adds an emoji to the current world state. @@ -647,7 +651,7 @@ func (s *State) messageRemoveByID(channelID, messageID string) error { } } - return errors.New("message not found") + return ErrStateNotFound } func (s *State) voiceStateUpdate(update *VoiceStateUpdate) error { @@ -701,7 +705,7 @@ func (s *State) Message(channelID, messageID string) (*Message, error) { } } - return nil, errors.New("message not found") + return nil, ErrStateNotFound } // OnReady takes a Ready event and updates all internal state. diff --git a/structs.go b/structs.go index 3a6ec05..9697fa5 100644 --- a/structs.go +++ b/structs.go @@ -78,6 +78,9 @@ type Session struct { // The http client used for REST requests Client *http.Client + // Stores the last HeartbeatAck that was recieved (in UTC) + LastHeartbeatAck time.Time + // Event handlers handlersMu sync.RWMutex handlers map[string][]*eventHandlerInstance @@ -304,7 +307,7 @@ type Game struct { // UnmarshalJSON unmarshals json to Game struct func (g *Game) UnmarshalJSON(bytes []byte) error { temp := &struct { - Name string `json:"name"` + Name json.Number `json:"name"` Type json.RawMessage `json:"type"` URL string `json:"url"` }{} @@ -312,8 +315,8 @@ func (g *Game) UnmarshalJSON(bytes []byte) error { if err != nil { return err } - g.Name = temp.Name g.URL = temp.URL + g.Name = temp.Name.String() if temp.Type != nil { err = json.Unmarshal(temp.Type, &g.Type) @@ -509,6 +512,12 @@ type MessageReaction struct { ChannelID string `json:"channel_id"` } +// GatewayBotResponse stores the data for the gateway/bot response +type GatewayBotResponse struct { + URL string `json:"url"` + Shards int `json:"shards"` +} + // Constants for the different bit offsets of text channel permissions const ( PermissionReadMessages = 1 << (iota + 10) @@ -549,6 +558,8 @@ const ( PermissionAdministrator PermissionManageChannels PermissionManageServer + PermissionAddReactions + PermissionViewAuditLogs PermissionAllText = PermissionReadMessages | PermissionSendMessages | @@ -568,7 +579,9 @@ const ( PermissionAllVoice | PermissionCreateInstantInvite | PermissionManageRoles | - PermissionManageChannels + PermissionManageChannels | + PermissionAddReactions | + PermissionViewAuditLogs PermissionAll = PermissionAllChannel | PermissionKickMembers | PermissionBanMembers | diff --git a/user.go b/user.go index cdc8a85..8abffb0 100644 --- a/user.go +++ b/user.go @@ -1,5 +1,7 @@ package discordgo +import "strings" + // A User stores all data for an individual Discord user. type User struct { ID string `json:"id"` @@ -22,3 +24,16 @@ func (u *User) String() string { func (u *User) Mention() string { return "<@" + u.ID + ">" } + +// AvatarURL returns a URL to the user's avatar. +// size: The size of the user's avatar as a power of two +func (u *User) AvatarURL(size string) string { + var URL string + if strings.HasPrefix(u.Avatar, "a_") { + URL = EndpointUserAvatarAnimated(u.ID, u.Avatar) + } else { + URL = EndpointUserAvatar(u.ID, u.Avatar) + } + + return URL + "?size=" + size +} diff --git a/voice.go b/voice.go index 8b566f4..5bbd0ad 100644 --- a/voice.go +++ b/voice.go @@ -814,7 +814,11 @@ func (v *VoiceConnection) opusReceiver(udpConn *net.UDPConn, close <-chan struct p.Opus, _ = secretbox.Open(nil, recvbuf[12:rlen], &nonce, &v.op4.SecretKey) if c != nil { - c <- &p + select { + case c <- &p: + case <-close: + return + } } } } diff --git a/wsapi.go b/wsapi.go index adab402..213ea72 100644 --- a/wsapi.go +++ b/wsapi.go @@ -25,6 +25,18 @@ import ( "github.com/gorilla/websocket" ) +// ErrWSAlreadyOpen is thrown when you attempt to open +// a websocket that already is open. +var ErrWSAlreadyOpen = errors.New("web socket already opened") + +// ErrWSNotFound is thrown when you attempt to use a websocket +// that doesn't exist +var ErrWSNotFound = errors.New("no websocket connection exists") + +// ErrWSShardBounds is thrown when you try to use a shard ID that is +// less than the total shard count +var ErrWSShardBounds = errors.New("ShardID must be less than ShardCount") + type resumePacket struct { Op int `json:"op"` Data struct { @@ -58,7 +70,7 @@ func (s *Session) Open() (err error) { } if s.wsConn != nil { - err = errors.New("web socket already opened") + err = ErrWSAlreadyOpen return } @@ -119,6 +131,7 @@ func (s *Session) Open() (err error) { // lock. s.listening = make(chan interface{}) go s.listen(s.wsConn, s.listening) + s.LastHeartbeatAck = time.Now().UTC() s.Unlock() @@ -187,10 +200,13 @@ type helloOp struct { Trace []string `json:"_trace"` } +// Number of heartbeat intervals to wait until forcing a connection restart. +const FailedHeartbeatAcks time.Duration = 5 * time.Millisecond + // heartbeat sends regular heartbeats to Discord so it knows the client // is still connected. If you do not send these heartbeats Discord will // disconnect the websocket connection after a few seconds. -func (s *Session) heartbeat(wsConn *websocket.Conn, listening <-chan interface{}, i time.Duration) { +func (s *Session) heartbeat(wsConn *websocket.Conn, listening <-chan interface{}, heartbeatIntervalMsec time.Duration) { s.log(LogInformational, "called") @@ -199,20 +215,26 @@ func (s *Session) heartbeat(wsConn *websocket.Conn, listening <-chan interface{} } var err error - ticker := time.NewTicker(i * time.Millisecond) + ticker := time.NewTicker(heartbeatIntervalMsec * time.Millisecond) defer ticker.Stop() for { + s.RLock() + last := s.LastHeartbeatAck + s.RUnlock() sequence := atomic.LoadInt64(s.sequence) s.log(LogInformational, "sending gateway websocket heartbeat seq %d", sequence) s.wsMutex.Lock() err = wsConn.WriteJSON(heartbeatOp{1, sequence}) s.wsMutex.Unlock() - if err != nil { - s.log(LogError, "error sending heartbeat to gateway %s, %s", s.gateway, err) - s.Lock() - s.DataReady = false - s.Unlock() + if err != nil || time.Now().UTC().Sub(last) > (heartbeatIntervalMsec*FailedHeartbeatAcks) { + if err != nil { + s.log(LogError, "error sending heartbeat to gateway %s, %s", s.gateway, err) + } else { + s.log(LogError, "haven't gotten a heartbeat ACK in %v, triggering a reconnection", time.Now().UTC().Sub(last)) + } + s.Close() + s.reconnect() return } s.Lock() @@ -250,7 +272,7 @@ func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err s.RLock() defer s.RUnlock() if s.wsConn == nil { - return errors.New("no websocket connection exists") + return ErrWSNotFound } var usd updateStatusData @@ -307,7 +329,7 @@ func (s *Session) RequestGuildMembers(guildID, query string, limit int) (err err s.RLock() defer s.RUnlock() if s.wsConn == nil { - return errors.New("no websocket connection exists") + return ErrWSNotFound } data := requestGuildMembersData{ @@ -386,7 +408,10 @@ func (s *Session) onEvent(messageType int, message []byte) { // Reconnect // Must immediately disconnect from gateway and reconnect to new gateway. if e.Operation == 7 { - // TODO + s.log(LogInformational, "Closing and reconnecting in response to Op7") + s.Close() + s.reconnect() + return } // Invalid Session @@ -414,6 +439,14 @@ func (s *Session) onEvent(messageType int, message []byte) { return } + if e.Operation == 11 { + s.Lock() + s.LastHeartbeatAck = time.Now().UTC() + s.Unlock() + s.log(LogInformational, "got heartbeat ACK") + return + } + // Do not try to Dispatch a non-Dispatch Message if e.Operation != 0 { // But we probably should be doing something with them. @@ -621,7 +654,7 @@ func (s *Session) identify() error { if s.ShardCount > 1 { if s.ShardID >= s.ShardCount { - return errors.New("ShardID must be less than ShardCount") + return ErrWSShardBounds } data.Shard = &[2]int{s.ShardID, s.ShardCount} @@ -676,6 +709,13 @@ func (s *Session) reconnect() { return } + // Certain race conditions can call reconnect() twice. If this happens, we + // just break out of the reconnect loop + if err == ErrWSAlreadyOpen { + s.log(LogInformational, "Websocket already exists, no need to reconnect") + return + } + s.log(LogError, "error reconnecting to gateway, %s", err) <-time.After(wait * time.Second)