From 6b73b588ba659228656463a96c4f152c89a32f2f Mon Sep 17 00:00:00 2001 From: Chris Rhodes Date: Sat, 16 Jan 2016 00:30:05 -0800 Subject: [PATCH] Message tracking. --- state.go | 52 ++++++++++++++++++++++++++++++++++++---------------- wsapi.go | 17 ++++++++++------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/state.go b/state.go index a82cd26..a2a209c 100644 --- a/state.go +++ b/state.go @@ -50,6 +50,11 @@ func (s *State) GuildAdd(guild *Guild) error { // If the guild exists, replace it. for i, g := range s.Guilds { if g.ID == guild.ID { + // Don't stomp on properties that don't come in updates. + guild.Members = g.Members + guild.Presences = g.Presences + guild.Channels = g.Channels + guild.VoiceStates = g.VoiceStates s.Guilds[i] = guild return nil } @@ -105,14 +110,15 @@ func (s *State) MemberAdd(member *Member) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { return err } + s.Lock() + defer s.Unlock() + for i, m := range guild.Members { if m.User.ID == member.User.ID { guild.Members[i] = member @@ -129,14 +135,15 @@ func (s *State) MemberRemove(member *Member) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { return err } + s.Lock() + defer s.Unlock() + for i, m := range guild.Members { if m.User.ID == member.User.ID { guild.Members = append(guild.Members[:i], guild.Members[i+1:]...) @@ -152,14 +159,15 @@ func (s *State) Member(guildID, userID string) (*Member, error) { if s == nil { return nil, nilError } - s.RLock() - defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { return nil, err } + s.RLock() + defer s.RUnlock() + for _, m := range guild.Members { if m.User.ID == userID { return m, nil @@ -177,13 +185,16 @@ func (s *State) ChannelAdd(channel *Channel) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() if channel.IsPrivate { + s.Lock() + defer s.Unlock() + // If the channel exists, replace it. for i, c := range s.PrivateChannels { if c.ID == channel.ID { + // Don't stomp on messages. + channel.Messages = c.Messages s.PrivateChannels[i] = channel return nil } @@ -196,9 +207,14 @@ func (s *State) ChannelAdd(channel *Channel) error { return err } + s.Lock() + defer s.Unlock() + // If the channel exists, replace it. for i, c := range guild.Channels { if c.ID == channel.ID { + // Don't stomp on messages. + channel.Messages = c.Messages guild.Channels[i] = channel return nil } @@ -215,8 +231,6 @@ func (s *State) ChannelRemove(channel *Channel) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() if channel.IsPrivate { for i, c := range s.PrivateChannels { @@ -231,6 +245,9 @@ func (s *State) ChannelRemove(channel *Channel) error { return err } + s.Lock() + defer s.Unlock() + for i, c := range guild.Channels { if c.ID == channel.ID { guild.Channels = append(guild.Channels[:i], guild.Channels[i+1:]...) @@ -247,14 +264,15 @@ func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) { if s == nil { return nil, nilError } - s.RLock() - defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { return nil, err } + s.RLock() + defer s.RUnlock() + for _, c := range guild.Channels { if c.ID == channelID { return c, nil @@ -307,14 +325,15 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) { if s == nil { return nil, nilError } - s.RLock() - defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { return nil, err } + s.RLock() + defer s.RUnlock() + for _, e := range guild.Emojis { if e.ID == emojiID { return e, nil @@ -329,14 +348,15 @@ func (s *State) EmojiAdd(guildID string, emoji *Emoji) error { if s == nil { return nilError } - s.Lock() - defer s.Unlock() guild, err := s.Guild(guildID) if err != nil { return err } + s.Lock() + defer s.Unlock() + for i, e := range guild.Emojis { if e.ID == emoji.ID { guild.Emojis[i] = emoji diff --git a/wsapi.go b/wsapi.go index 4b63813..3f8bafc 100644 --- a/wsapi.go +++ b/wsapi.go @@ -251,13 +251,14 @@ func (s *Session) event(messageType int, message []byte) (err error) { } */ case "MESSAGE_CREATE": - if !s.StateEnabled && s.OnMessageCreate == nil { + stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0 + if !stateEnabled && s.OnMessageCreate == nil { break } var st *Message if err = unmarshalEvent(e, &st); err == nil { - if s.StateEnabled { - fmt.Println(s.State.MessageAdd(st)) + if stateEnabled { + s.State.MessageAdd(st) } if s.OnMessageCreate != nil { s.OnMessageCreate(s, st) @@ -267,12 +268,13 @@ func (s *Session) event(messageType int, message []byte) (err error) { return } case "MESSAGE_UPDATE": - if !s.StateEnabled && s.OnMessageUpdate == nil { + stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0 + if !stateEnabled && s.OnMessageUpdate == nil { break } var st *Message if err = unmarshalEvent(e, &st); err == nil { - if s.StateEnabled { + if stateEnabled { s.State.MessageAdd(st) } if s.OnMessageUpdate != nil { @@ -284,12 +286,13 @@ func (s *Session) event(messageType int, message []byte) (err error) { return } case "MESSAGE_DELETE": - if !s.StateEnabled && s.OnMessageDelete == nil { + stateEnabled := s.StateEnabled && s.State.MaxMessageCount > 0 + if !stateEnabled && s.OnMessageDelete == nil { break } var st *Message if err = unmarshalEvent(e, &st); err == nil { - if s.StateEnabled { + if stateEnabled { s.State.MessageRemove(st) } if s.OnMessageDelete != nil {