Merge pull request #196 from iopred/stateroles

Add Role tracking. Support disabling individual features in state. Closes #195.
This commit is contained in:
Bruce 2016-05-08 19:34:26 -04:00
commit 1d2d93a7b5

126
state.go
View file

@ -26,7 +26,13 @@ var ErrNilState = errors.New("State not instantiated, please use discordgo.New()
type State struct { type State struct {
sync.RWMutex sync.RWMutex
Ready Ready
MaxMessageCount int MaxMessageCount int
TrackChannels bool
TrackEmojis bool
TrackMembers bool
TrackRoles bool
TrackVoice bool
guildMap map[string]*Guild guildMap map[string]*Guild
channelMap map[string]*Channel channelMap map[string]*Channel
@ -39,6 +45,11 @@ func NewState() *State {
PrivateChannels: []*Channel{}, PrivateChannels: []*Channel{},
Guilds: []*Guild{}, Guilds: []*Guild{},
}, },
TrackChannels: true,
TrackEmojis: true,
TrackMembers: true,
TrackRoles: true,
TrackVoice: true,
guildMap: make(map[string]*Guild), guildMap: make(map[string]*Guild),
channelMap: make(map[string]*Channel), channelMap: make(map[string]*Channel),
} }
@ -227,6 +238,79 @@ func (s *State) Member(guildID, userID string) (*Member, error) {
return nil, errors.New("Member not found.") return nil, errors.New("Member not found.")
} }
// RoleAdd adds a role to the current world state, or
// updates it if it already exists.
func (s *State) RoleAdd(guildID string, role *Role) error {
if s == nil {
return ErrNilState
}
guild, err := s.Guild(guildID)
if err != nil {
return err
}
s.Lock()
defer s.Unlock()
for i, r := range guild.Roles {
if r.ID == role.ID {
guild.Roles[i] = role
return nil
}
}
guild.Roles = append(guild.Roles, role)
return nil
}
// RoleRemove removes a role from current world state by ID.
func (s *State) RoleRemove(guildID, roleID string) error {
if s == nil {
return ErrNilState
}
guild, err := s.Guild(guildID)
if err != nil {
return err
}
s.Lock()
defer s.Unlock()
for i, r := range guild.Roles {
if r.ID == roleID {
guild.Roles = append(guild.Roles[:i], guild.Roles[i+1:]...)
return nil
}
}
return errors.New("Role not found.")
}
// Role gets a role by ID from a guild.
func (s *State) Role(guildID, roleID string) (*Role, error) {
if s == nil {
return nil, ErrNilState
}
guild, err := s.Guild(guildID)
if err != nil {
return nil, err
}
s.RLock()
defer s.RUnlock()
for _, r := range guild.Roles {
if r.ID == roleID {
return r, nil
}
}
return nil, errors.New("Role not found.")
}
// ChannelAdd adds a guild to the current world state, or // ChannelAdd adds a guild to the current world state, or
// updates it if it already exists. // updates it if it already exists.
// Channels may exist either as PrivateChannels or inside // Channels may exist either as PrivateChannels or inside
@ -401,10 +485,6 @@ func (s *State) MessageAdd(message *Message) error {
return ErrNilState return ErrNilState
} }
if s.MaxMessageCount == 0 {
return nil
}
c, err := s.Channel(message.ChannelID) c, err := s.Channel(message.ChannelID)
if err != nil { if err != nil {
return err return err
@ -450,10 +530,6 @@ func (s *State) MessageRemove(message *Message) error {
return ErrNilState return ErrNilState
} }
if s.MaxMessageCount == 0 {
return nil
}
c, err := s.Channel(message.ChannelID) c, err := s.Channel(message.ChannelID)
if err != nil { if err != nil {
return err return err
@ -545,28 +621,62 @@ func (s *State) onInterface(se *Session, i interface{}) (err error) {
case *GuildDelete: case *GuildDelete:
err = s.GuildRemove(t.Guild) err = s.GuildRemove(t.Guild)
case *GuildMemberAdd: case *GuildMemberAdd:
if s.TrackMembers {
err = s.MemberAdd(t.Member) err = s.MemberAdd(t.Member)
}
case *GuildMemberUpdate: case *GuildMemberUpdate:
if s.TrackMembers {
err = s.MemberAdd(t.Member) err = s.MemberAdd(t.Member)
}
case *GuildMemberRemove: case *GuildMemberRemove:
if s.TrackMembers {
err = s.MemberRemove(t.Member) err = s.MemberRemove(t.Member)
}
case *GuildRoleCreate:
if s.TrackRoles {
err = s.RoleAdd(t.GuildID, t.Role)
}
case *GuildRoleUpdate:
if s.TrackRoles {
err = s.RoleAdd(t.GuildID, t.Role)
}
case *GuildRoleDelete:
if s.TrackRoles {
err = s.RoleRemove(t.GuildID, t.RoleID)
}
case *GuildEmojisUpdate: case *GuildEmojisUpdate:
if s.TrackEmojis {
err = s.EmojisAdd(t.GuildID, t.Emojis) err = s.EmojisAdd(t.GuildID, t.Emojis)
}
case *ChannelCreate: case *ChannelCreate:
if s.TrackChannels {
err = s.ChannelAdd(t.Channel) err = s.ChannelAdd(t.Channel)
}
case *ChannelUpdate: case *ChannelUpdate:
if s.TrackChannels {
err = s.ChannelAdd(t.Channel) err = s.ChannelAdd(t.Channel)
}
case *ChannelDelete: case *ChannelDelete:
if s.TrackChannels {
err = s.ChannelRemove(t.Channel) err = s.ChannelRemove(t.Channel)
}
case *MessageCreate: case *MessageCreate:
if s.MaxMessageCount != 0 {
err = s.MessageAdd(t.Message) err = s.MessageAdd(t.Message)
}
case *MessageUpdate: case *MessageUpdate:
if s.MaxMessageCount != 0 {
err = s.MessageAdd(t.Message) err = s.MessageAdd(t.Message)
}
case *MessageDelete: case *MessageDelete:
if s.MaxMessageCount != 0 {
err = s.MessageRemove(t.Message) err = s.MessageRemove(t.Message)
}
case *VoiceStateUpdate: case *VoiceStateUpdate:
if s.TrackVoice {
err = s.voiceStateUpdate(t) err = s.voiceStateUpdate(t)
} }
}
return return
} }