diff --git a/message.go b/message.go index 59d0297..101a17c 100644 --- a/message.go +++ b/message.go @@ -70,6 +70,9 @@ type Embed struct { // ContentWithMentionsReplaced will replace all @ mentions with the // username of the mention. func (m *Message) ContentWithMentionsReplaced() string { + if m.Mentions == nil { + return m.Content + } content := m.Content for _, user := range m.Mentions { content = strings.Replace(content, fmt.Sprintf("<@%s>", user.ID), diff --git a/state.go b/state.go index 23731b6..a82cd26 100644 --- a/state.go +++ b/state.go @@ -31,6 +31,8 @@ func (s *State) OnReady(r *Ready) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() s.Ready = *r return nil @@ -42,16 +44,13 @@ func (s *State) GuildAdd(guild *Guild) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() - for _, g := range s.Guilds { + // If the guild exists, replace it. + for i, g := range s.Guilds { if g.ID == guild.ID { - // This could be a little faster ;) - for _, m := range guild.Members { - s.MemberAdd(m) - } - for _, c := range guild.Channels { - s.ChannelAdd(c) - } + s.Guilds[i] = guild return nil } } @@ -65,6 +64,8 @@ func (s *State) GuildRemove(guild *Guild) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() for i, g := range s.Guilds { if g.ID == guild.ID { @@ -84,6 +85,8 @@ func (s *State) Guild(guildID string) (*Guild, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() for _, g := range s.Guilds { if g.ID == guildID { @@ -102,6 +105,8 @@ func (s *State) MemberAdd(member *Member) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { @@ -124,6 +129,8 @@ func (s *State) MemberRemove(member *Member) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() guild, err := s.Guild(member.GuildID) if err != nil { @@ -145,6 +152,8 @@ func (s *State) Member(guildID, userID string) (*Member, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { @@ -168,8 +177,11 @@ func (s *State) ChannelAdd(channel *Channel) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() if channel.IsPrivate { + // If the channel exists, replace it. for i, c := range s.PrivateChannels { if c.ID == channel.ID { s.PrivateChannels[i] = channel @@ -184,6 +196,7 @@ func (s *State) ChannelAdd(channel *Channel) error { return err } + // If the channel exists, replace it. for i, c := range guild.Channels { if c.ID == channel.ID { guild.Channels[i] = channel @@ -202,6 +215,8 @@ func (s *State) ChannelRemove(channel *Channel) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() if channel.IsPrivate { for i, c := range s.PrivateChannels { @@ -232,6 +247,8 @@ func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { @@ -252,6 +269,8 @@ func (s *State) PrivateChannel(channelID string) (*Channel, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() for _, c := range s.PrivateChannels { if c.ID == channelID { @@ -288,6 +307,8 @@ func (s *State) Emoji(guildID, emojiID string) (*Emoji, error) { if s == nil { return nil, nilError } + s.RLock() + defer s.RUnlock() guild, err := s.Guild(guildID) if err != nil { @@ -308,6 +329,8 @@ func (s *State) EmojiAdd(guildID string, emoji *Emoji) error { if s == nil { return nilError } + s.Lock() + defer s.Unlock() guild, err := s.Guild(guildID) if err != nil { @@ -334,3 +357,84 @@ func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error { } return nil } + +// MessageAdd adds a message to the current world state, or updates it if it exists. +// If the channel cannot be found, the message is discarded. +// Messages are kept in state up to s.MaxMessageCount +func (s *State) MessageAdd(message *Message) error { + if s == nil { + return nilError + } + + c, err := s.Channel(message.ChannelID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + + // If the message exists, replace it. + for i, m := range c.Messages { + if m.ID == message.ID { + c.Messages[i] = message + return nil + } + } + + c.Messages = append(c.Messages, message) + + if len(c.Messages) > s.MaxMessageCount { + s.Unlock() + for len(c.Messages) > s.MaxMessageCount { + s.MessageRemove(c.Messages[0]) + } + s.Lock() + } + return nil +} + +// MessageRemove removes a message from the world state. +func (s *State) MessageRemove(message *Message) error { + if s == nil { + return nilError + } + c, err := s.Channel(message.ChannelID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + + for i, m := range c.Messages { + if m.ID == message.ID { + c.Messages = append(c.Messages[:i], c.Messages[i+1:]...) + return nil + } + } + + return errors.New("Message not found.") +} + +// Message gets a message by channel and message ID. +func (s *State) Message(channelID, messageID string) (*Message, error) { + if s == nil { + return nil, nilError + } + c, err := s.Channel(channelID) + if err != nil { + return nil, err + } + + s.RLock() + defer s.RUnlock() + + for _, m := range c.Messages { + if m.ID == messageID { + return m, nil + } + } + + return nil, errors.New("Message not found.") +} diff --git a/structs.go b/structs.go index ff63a24..091eda2 100644 --- a/structs.go +++ b/structs.go @@ -33,7 +33,7 @@ type Session struct { OnTypingStart func(*Session, *TypingStart) OnMessageCreate func(*Session, *Message) OnMessageUpdate func(*Session, *Message) - OnMessageDelete func(*Session, *MessageDelete) + OnMessageDelete func(*Session, *Message) OnMessageAck func(*Session, *MessageAck) OnUserUpdate func(*Session, *User) OnPresenceUpdate func(*Session, *PresenceUpdate) @@ -46,7 +46,7 @@ type Session struct { OnGuildDelete func(*Session, *Guild) OnGuildMemberAdd func(*Session, *Member) OnGuildMemberRemove func(*Session, *Member) - OnGuildMemberDelete func(*Session, *Member) // which is it? + OnGuildMemberDelete func(*Session, *Member) OnGuildMemberUpdate func(*Session, *Member) OnGuildRoleCreate func(*Session, *GuildRole) OnGuildRoleUpdate func(*Session, *GuildRole) @@ -77,8 +77,9 @@ type Session struct { Voice *Voice // Stores all details related to voice connections // Managed state object, updated with events. - State *State - StateEnabled bool + State *State + StateEnabled bool + StateMaxMessageCount int // Mutex/Bools for locks that prevent accidents. // TODO: Add channels. @@ -138,6 +139,7 @@ type Channel struct { IsPrivate bool `json:"is_private"` LastMessageID string `json:"last_message_id"` Recipient *User `json:"recipient"` + Messages []*Message `json:"-"` } // A PermissionOverwrite holds permission overwrite data for a Channel @@ -309,12 +311,6 @@ type MessageAck struct { ChannelID string `json:"channel_id"` } -// A MessageDelete stores data for the message delete websocket event. -type MessageDelete struct { - ID string `json:"id"` - ChannelID string `json:"channel_id"` -} // so much like MessageAck.. - // A GuildIntegrationsUpdate stores data for the guild integrations update // websocket event. type GuildIntegrationsUpdate struct { @@ -349,5 +345,7 @@ type GuildEmojisUpdate struct { // As discord sends this in a READY blob, it seems reasonable to simply // use that struct as the data store. type State struct { + sync.RWMutex Ready + MaxMessageCount int } diff --git a/wsapi.go b/wsapi.go index 4dceddf..4b63813 100644 --- a/wsapi.go +++ b/wsapi.go @@ -251,27 +251,53 @@ func (s *Session) event(messageType int, message []byte) (err error) { } */ case "MESSAGE_CREATE": - if s.OnMessageCreate != nil { - var st *Message - if err = unmarshalEvent(e, &st); err == nil { + if !s.StateEnabled && s.OnMessageCreate == nil { + break + } + var st *Message + if err = unmarshalEvent(e, &st); err == nil { + if s.StateEnabled { + fmt.Println(s.State.MessageAdd(st)) + } + if s.OnMessageCreate != nil { s.OnMessageCreate(s, st) } + } + if s.OnMessageCreate != nil { return } case "MESSAGE_UPDATE": - if s.OnMessageUpdate != nil { - var st *Message - if err = unmarshalEvent(e, &st); err == nil { + if !s.StateEnabled && s.OnMessageUpdate == nil { + break + } + var st *Message + if err = unmarshalEvent(e, &st); err == nil { + if s.StateEnabled { + s.State.MessageAdd(st) + } + if s.OnMessageUpdate != nil { s.OnMessageUpdate(s, st) } + } + return + if s.OnMessageUpdate != nil { return } case "MESSAGE_DELETE": - if s.OnMessageDelete != nil { - var st *MessageDelete - if err = unmarshalEvent(e, &st); err == nil { + if !s.StateEnabled && s.OnMessageDelete == nil { + break + } + var st *Message + if err = unmarshalEvent(e, &st); err == nil { + if s.StateEnabled { + s.State.MessageRemove(st) + } + if s.OnMessageDelete != nil { s.OnMessageDelete(s, st) } + } + return + if s.OnMessageDelete != nil { return } case "MESSAGE_ACK":