From 0f38b22ca1d6b4233c1bb2e8fde6b64b9fb0600a Mon Sep 17 00:00:00 2001 From: Chris Rhodes Date: Fri, 15 Jan 2016 22:01:56 -0800 Subject: [PATCH 1/3] Message state tracking. --- message.go | 3 ++ state.go | 120 +++++++++++++++++++++++++++++++++++++++++++++++++---- structs.go | 18 ++++---- wsapi.go | 44 ++++++++++++++++---- 4 files changed, 158 insertions(+), 27 deletions(-) diff --git a/message.go b/message.go index 59d0297..101a17c 100644 --- a/message.go +++ b/message.go @@ -70,6 +70,9 @@ type Embed struct { // ContentWithMentionsReplaced will replace all @ mentions with the // username of the mention. func (m *Message) ContentWithMentionsReplaced() string { + if m.Mentions == nil { + return m.Content + } content := m.Content for _, user := range m.Mentions { content = strings.Replace(content, fmt.Sprintf("<@%s>", user.ID), diff --git a/state.go b/state.go index 23731b6..a82cd26 100644 --- a/state.go +++ b/state.go @@ -31,6 +31,8 @@ func (s *State) OnReady(r *Ready) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() s.Ready = *r return nil @@ -42,16 +44,13 @@ func (s *State) GuildAdd(guild *Guild) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() - for _, g := range s.Guilds { + // If the guild exists, replace it. + for i, g := range s.Guilds { if g.ID == guild.ID { - // This could be a little faster ;) - for _, m := range guild.Members { - s.MemberAdd(m) - } - for _, c := range guild.Channels { - s.ChannelAdd(c) - } + s.Guilds[i] = guild return nil } } @@ -65,6 +64,8 @@ func (s *State) GuildRemove(guild *Guild) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() for i, g := range s.Guilds { if g.ID == guild.ID { @@ -84,6 +85,8 @@ func (s *State) Guild(guildID string) (*Guild, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() for _, g := range s.Guilds { if g.ID == guildID { @@ -102,6 +105,8 @@ func (s *State) MemberAdd(member *Member) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { @@ -124,6 +129,8 @@ func (s *State) MemberRemove(member *Member) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { @@ -145,6 +152,8 @@ func (s *State) Member(guildID, userID string) (*Member, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { @@ -168,8 +177,11 @@ func (s *State) ChannelAdd(channel *Channel) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() if channel.IsPrivate { + // If the channel exists, replace it. for i, c := range s.PrivateChannels { if c.ID == channel.ID { s.PrivateChannels[i] = channel @@ -184,6 +196,7 @@ func (s *State) ChannelAdd(channel *Channel) error { return err } + // If the channel exists, replace it. for i, c := range guild.Channels { if c.ID == channel.ID { guild.Channels[i] = channel @@ -202,6 +215,8 @@ func (s *State) ChannelRemove(channel *Channel) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() if channel.IsPrivate { for i, c := range s.PrivateChannels { @@ -232,6 +247,8 @@ func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { @@ -252,6 +269,8 @@ func (s *State) PrivateChannel(channelID string) (*Channel, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() for _, c := range s.PrivateChannels { if c.ID == channelID { @@ -288,6 +307,8 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { @@ -308,6 +329,8 @@ func (s *State) EmojiAdd(guildID string, emoji *Emoji) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() guild, err := s.Guild(guildID) if err != nil { @@ -334,3 +357,84 @@ func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error { } return nil } + +// MessageAdd adds a message to the current world state, or updates it if it exists. +// If the channel cannot be found, the message is discarded. +// Messages are kept in state up to s.MaxMessageCount +func (s *State) MessageAdd(message *Message) error { + if s == nil { + return nilError + } + + c, err := s.Channel(message.ChannelID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + + // If the message exists, replace it. + for i, m := range c.Messages { + if m.ID == message.ID { + c.Messages[i] = message + return nil + } + } + + c.Messages = append(c.Messages, message) + + if len(c.Messages) > s.MaxMessageCount { + s.Unlock() + for len(c.Messages) > s.MaxMessageCount { + s.MessageRemove(c.Messages[0]) + } + s.Lock() + } + return nil +} + +// MessageRemove removes a message from the world state. +func (s *State) MessageRemove(message *Message) error { + if s == nil { + return nilError + } + c, err := s.Channel(message.ChannelID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + + for i, m := range c.Messages { + if m.ID == message.ID { + c.Messages = append(c.Messages[:i], c.Messages[i+1:]...) + return nil + } + } + + return errors.New("Message not found.") +} + +// Message gets a message by channel and message ID. +func (s *State) Message(channelID, messageID string) (*Message, error) { + if s == nil { + return nil, nilError + } + c, err := s.Channel(channelID) + if err != nil { + return nil, err + } + + s.RLock() + defer s.RUnlock() + + for _, m := range c.Messages { + if m.ID == messageID { + return m, nil + } + } + + return nil, errors.New("Message not found.") +} diff --git a/structs.go b/structs.go index ff63a24..091eda2 100644 --- a/structs.go +++ b/structs.go @@ -33,7 +33,7 @@ type Session struct { OnTypingStart func(*Session, *TypingStart) OnMessageCreate func(*Session, *Message) OnMessageUpdate func(*Session, *Message) - OnMessageDelete func(*Session, *MessageDelete) + OnMessageDelete func(*Session, *Message) OnMessageAck func(*Session, *MessageAck) OnUserUpdate func(*Session, *User) OnPresenceUpdate func(*Session, *PresenceUpdate) @@ -46,7 +46,7 @@ type Session struct { OnGuildDelete func(*Session, *Guild) OnGuildMemberAdd func(*Session, *Member) OnGuildMemberRemove func(*Session, *Member) - OnGuildMemberDelete func(*Session, *Member) // which is it? + OnGuildMemberDelete func(*Session, *Member) OnGuildMemberUpdate func(*Session, *Member) OnGuildRoleCreate func(*Session, *GuildRole) OnGuildRoleUpdate func(*Session, *GuildRole) @@ -77,8 +77,9 @@ type Session struct { Voice *Voice // Stores all details related to voice connections // Managed state object, updated with events. - State *State - StateEnabled bool + State *State + StateEnabled bool + StateMaxMessageCount int // Mutex/Bools for locks that prevent accidents. // TODO: Add channels. @@ -138,6 +139,7 @@ type Channel struct { IsPrivate bool `json:"is_private"` LastMessageID string `json:"last_message_id"` Recipient *User `json:"recipient"` + Messages []*Message `json:"-"` } // A PermissionOverwrite holds permission overwrite data for a Channel @@ -309,12 +311,6 @@ type MessageAck struct { ChannelID string `json:"channel_id"` } -// A MessageDelete stores data for the message delete websocket event. -type MessageDelete struct { - ID string `json:"id"` - ChannelID string `json:"channel_id"` -} // so much like MessageAck.. - // A GuildIntegrationsUpdate stores data for the guild integrations update // websocket event. type GuildIntegrationsUpdate struct { @@ -349,5 +345,7 @@ type GuildEmojisUpdate struct { // As discord sends this in a READY blob, it seems reasonable to simply // use that struct as the data store. type State struct { + sync.RWMutex Ready + MaxMessageCount int } diff --git a/wsapi.go b/wsapi.go index 4dceddf..4b63813 100644 --- a/wsapi.go +++ b/wsapi.go @@ -251,27 +251,53 @@ func (s *Session) event(messageType int, message []byte) (err error) { } */ case "MESSAGE_CREATE": - if s.OnMessageCreate != nil { - var st *Message - if err = unmarshalEvent(e, &st); err == nil { + if !s.StateEnabled && s.OnMessageCreate == nil { + break + } + var st *Message + if err = unmarshalEvent(e, &st); err == nil { + if s.StateEnabled { + fmt.Println(s.State.MessageAdd(st)) + } + if s.OnMessageCreate != nil { s.OnMessageCreate(s, st) } + } + if s.OnMessageCreate != nil { return } case "MESSAGE_UPDATE": - if s.OnMessageUpdate != nil { - var st *Message - if err = unmarshalEvent(e, &st); err == nil { + if !s.StateEnabled && s.OnMessageUpdate == nil { + break + } + var st *Message + if err = unmarshalEvent(e, &st); err == nil { + if s.StateEnabled { + s.State.MessageAdd(st) + } + if s.OnMessageUpdate != nil { s.OnMessageUpdate(s, st) } + } + return + if s.OnMessageUpdate != nil { return } case "MESSAGE_DELETE": - if s.OnMessageDelete != nil { - var st *MessageDelete - if err = unmarshalEvent(e, &st); err == nil { + if !s.StateEnabled && s.OnMessageDelete == nil { + break + } + var st *Message + if err = unmarshalEvent(e, &st); err == nil { + if s.StateEnabled { + s.State.MessageRemove(st) + } + if s.OnMessageDelete != nil { s.OnMessageDelete(s, st) } + } + return + if s.OnMessageDelete != nil { return } case "MESSAGE_ACK": From 6b73b588ba659228656463a96c4f152c89a32f2f Mon Sep 17 00:00:00 2001 From: Chris Rhodes Date: Sat, 16 Jan 2016 00:30:05 -0800 Subject: [PATCH 2/3] Message tracking. --- state.go | 52 ++++++++++++++++++++++++++++++++++++---------------- wsapi.go | 17 ++++++++++------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/state.go b/state.go index a82cd26..a2a209c 100644 --- a/state.go +++ b/state.go @@ -50,6 +50,11 @@ func (s *State) GuildAdd(guild *Guild) error { // If the guild exists, replace it. for i, g := range s.Guilds { if g.ID == guild.ID { + // Don't stomp on properties that don't come in updates. + guild.Members = g.Members + guild.Presences = g.Presences + guild.Channels = g.Channels + guild.VoiceStates = g.VoiceStates s.Guilds[i] = guild return nil } @@ -105,14 +110,15 @@ func (s *State) MemberAdd(member *Member) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { return err } + s.Lock() + defer s.Unlock() + for i, m := range guild.Members { if m.User.ID == member.User.ID { guild.Members[i] = member @@ -129,14 +135,15 @@ func (s *State) MemberRemove(member *Member) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { return err } + s.Lock() + defer s.Unlock() + for i, m := range guild.Members { if m.User.ID == member.User.ID { guild.Members = append(guild.Members[:i], guild.Members[i+1:]...) @@ -152,14 +159,15 @@ func (s *State) Member(guildID, userID string) (*Member, error) { if s == nil { return nil, nilError } - s.RLock() - defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { return nil, err } + s.RLock() + defer s.RUnlock() + for _, m := range guild.Members { if m.User.ID == userID { return m, nil @@ -177,13 +185,16 @@ func (s *State) ChannelAdd(channel *Channel) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() if channel.IsPrivate { + s.Lock() + defer s.Unlock() + // If the channel exists, replace it. for i, c := range s.PrivateChannels { if c.ID == channel.ID { + // Don't stomp on messages. + channel.Messages = c.Messages s.PrivateChannels[i] = channel return nil } @@ -196,9 +207,14 @@ func (s *State) ChannelAdd(channel *Channel) error { return err } + s.Lock() + defer s.Unlock() + // If the channel exists, replace it. for i, c := range guild.Channels { if c.ID == channel.ID { + // Don't stomp on messages. + channel.Messages = c.Messages guild.Channels[i] = channel return nil } @@ -215,8 +231,6 @@ func (s *State) ChannelRemove(channel *Channel) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() if channel.IsPrivate { for i, c := range s.PrivateChannels { @@ -231,6 +245,9 @@ func (s *State) ChannelRemove(channel *Channel) error { return err } + s.Lock() + defer s.Unlock() + for i, c := range guild.Channels { if c.ID == channel.ID { guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...) @@ -247,14 +264,15 @@ func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) { if s == nil { return nil, nilError } - s.RLock() - defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { return nil, err } + s.RLock() + defer s.RUnlock() + for _, c := range guild.Channels { if c.ID == channelID { return c, nil @@ -307,14 +325,15 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) { if s == nil { return nil, nilError } - s.RLock() - defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { return nil, err } + s.RLock() + defer s.RUnlock() + for _, e := range guild.Emojis { if e.ID == emojiID { return e, nil @@ -329,14 +348,15 @@ func (s *State) EmojiAdd(guildID string, emoji *Emoji) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() guild, err := s.Guild(guildID) if err != nil { return err } + s.Lock() + defer s.Unlock() + for i, e := range guild.Emojis { if e.ID == emoji.ID { guild.Emojis[i] = emoji diff --git a/wsapi.go b/wsapi.go index 4b63813..3f8bafc 100644 --- a/wsapi.go +++ b/wsapi.go @@ -251,13 +251,14 @@ func (s *Session) event(messageType int, message []byte) (err error) { } */ case "MESSAGE_CREATE": - if !s.StateEnabled && s.OnMessageCreate == nil { + stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0 + if !stateEnabled && s.OnMessageCreate == nil { break } var st *Message if err = unmarshalEvent(e, &st); err == nil { - if s.StateEnabled { - fmt.Println(s.State.MessageAdd(st)) + if stateEnabled { + s.State.MessageAdd(st) } if s.OnMessageCreate != nil { s.OnMessageCreate(s, st) @@ -267,12 +268,13 @@ func (s *Session) event(messageType int, message []byte) (err error) { return } case "MESSAGE_UPDATE": - if !s.StateEnabled && s.OnMessageUpdate == nil { + stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0 + if !stateEnabled && s.OnMessageUpdate == nil { break } var st *Message if err = unmarshalEvent(e, &st); err == nil { - if s.StateEnabled { + if stateEnabled { s.State.MessageAdd(st) } if s.OnMessageUpdate != nil { @@ -284,12 +286,13 @@ func (s *Session) event(messageType int, message []byte) (err error) { return } case "MESSAGE_DELETE": - if !s.StateEnabled && s.OnMessageDelete == nil { + stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0 + if !stateEnabled && s.OnMessageDelete == nil { break } var st *Message if err = unmarshalEvent(e, &st); err == nil { - if s.StateEnabled { + if stateEnabled { s.State.MessageRemove(st) } if s.OnMessageDelete != nil { From ccb504b129f52cfcab1cf7be91c83d786dba68f2 Mon Sep 17 00:00:00 2001 From: Chris Rhodes Date: Sat, 16 Jan 2016 01:25:05 -0800 Subject: [PATCH 3/3] Fix missing lock --- state.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/state.go b/state.go index a2a209c..e91aede 100644 --- a/state.go +++ b/state.go @@ -233,6 +233,9 @@ func (s *State) ChannelRemove(channel *Channel) error { } if channel.IsPrivate { + s.Lock() + defer s.Unlock() + for i, c := range s.PrivateChannels { if c.ID == channel.ID { s.PrivateChannels = append(s.PrivateChannels[:i], s.PrivateChannels[i+1:]...)