diff --git a/state.go b/state.go index e3c8908..00cd83e 100644 --- a/state.go +++ b/state.go @@ -26,7 +26,13 @@ var ErrNilState = errors.New("State not instantiated, please use discordgo.New() type State struct { sync.RWMutex Ready + MaxMessageCount int + TrackChannels bool + TrackEmojis bool + TrackMembers bool + TrackRoles bool + TrackVoice bool guildMap map[string]*Guild channelMap map[string]*Channel @@ -39,8 +45,13 @@ func NewState() *State { PrivateChannels: []*Channel{}, Guilds: []*Guild{}, }, - guildMap: make(map[string]*Guild), - channelMap: make(map[string]*Channel), + TrackChannels: true, + TrackEmojis: true, + TrackMembers: true, + TrackRoles: true, + TrackVoice: true, + guildMap: make(map[string]*Guild), + channelMap: make(map[string]*Channel), } } @@ -227,6 +238,79 @@ func (s *State) Member(guildID, userID string) (*Member, error) { return nil, errors.New("Member not found.") } +// RoleAdd adds a role to the current world state, or +// updates it if it already exists. +func (s *State) RoleAdd(guildID string, role *Role) error { + if s == nil { + return ErrNilState + } + + guild, err := s.Guild(guildID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + + for i, r := range guild.Roles { + if r.ID == role.ID { + guild.Roles[i] = role + return nil + } + } + + guild.Roles = append(guild.Roles, role) + return nil +} + +// RoleRemove removes a role from current world state by ID. +func (s *State) RoleRemove(guildID, roleID string) error { + if s == nil { + return ErrNilState + } + + guild, err := s.Guild(guildID) + if err != nil { + return err + } + + s.Lock() + defer s.Unlock() + + for i, r := range guild.Roles { + if r.ID == roleID { + guild.Roles = append(guild.Roles[:i], guild.Roles[i+1:]...) + return nil + } + } + + return errors.New("Role not found.") +} + +// Role gets a role by ID from a guild. +func (s *State) Role(guildID, roleID string) (*Role, error) { + if s == nil { + return nil, ErrNilState + } + + guild, err := s.Guild(guildID) + if err != nil { + return nil, err + } + + s.RLock() + defer s.RUnlock() + + for _, r := range guild.Roles { + if r.ID == roleID { + return r, nil + } + } + + return nil, errors.New("Role not found.") +} + // ChannelAdd adds a guild to the current world state, or // updates it if it already exists. // Channels may exist either as PrivateChannels or inside @@ -324,7 +408,7 @@ func (s *State) Channel(channelID string) (*Channel, error) { if s == nil { return nil, ErrNilState } - + s.RLock() defer s.RUnlock() @@ -401,10 +485,6 @@ func (s *State) MessageAdd(message *Message) error { return ErrNilState } - if s.MaxMessageCount == 0 { - return nil - } - c, err := s.Channel(message.ChannelID) if err != nil { return err @@ -435,10 +515,6 @@ func (s *State) MessageRemove(message *Message) error { return ErrNilState } - if s.MaxMessageCount == 0 { - return nil - } - c, err := s.Channel(message.ChannelID) if err != nil { return err @@ -530,27 +606,61 @@ func (s *State) onInterface(se *Session, i interface{}) (err error) { case *GuildDelete: err = s.GuildRemove(t.Guild) case *GuildMemberAdd: - err = s.MemberAdd(t.Member) + if s.TrackMembers { + err = s.MemberAdd(t.Member) + } case *GuildMemberUpdate: - err = s.MemberAdd(t.Member) + if s.TrackMembers { + err = s.MemberAdd(t.Member) + } case *GuildMemberRemove: - err = s.MemberRemove(t.Member) + if s.TrackMembers { + err = s.MemberRemove(t.Member) + } + case *GuildRoleCreate: + if s.TrackRoles { + err = s.RoleAdd(t.GuildID, t.Role) + } + case *GuildRoleUpdate: + if s.TrackRoles { + err = s.RoleAdd(t.GuildID, t.Role) + } + case *GuildRoleDelete: + if s.TrackRoles { + err = s.RoleRemove(t.GuildID, t.RoleID) + } case *GuildEmojisUpdate: - err = s.EmojisAdd(t.GuildID, t.Emojis) + if s.TrackEmojis { + err = s.EmojisAdd(t.GuildID, t.Emojis) + } case *ChannelCreate: - err = s.ChannelAdd(t.Channel) + if s.TrackChannels { + err = s.ChannelAdd(t.Channel) + } case *ChannelUpdate: - err = s.ChannelAdd(t.Channel) + if s.TrackChannels { + err = s.ChannelAdd(t.Channel) + } case *ChannelDelete: - err = s.ChannelRemove(t.Channel) + if s.TrackChannels { + err = s.ChannelRemove(t.Channel) + } case *MessageCreate: - err = s.MessageAdd(t.Message) + if s.MaxMessageCount != 0 { + err = s.MessageAdd(t.Message) + } case *MessageUpdate: - err = s.MessageAdd(t.Message) + if s.MaxMessageCount != 0 { + err = s.MessageAdd(t.Message) + } case *MessageDelete: - err = s.MessageRemove(t.Message) + if s.MaxMessageCount != 0 { + err = s.MessageRemove(t.Message) + } case *VoiceStateUpdate: - err = s.voiceStateUpdate(t) + if s.TrackVoice { + err = s.voiceStateUpdate(t) + } } return