Add fast lookups for guilds and channels in state.

This commit is contained in:
Chris Rhodes 2016-04-10 16:17:05 -07:00
parent 6101abeb27
commit fa63c00fa2
2 changed files with 60 additions and 86 deletions

123
state.go
View file

@ -24,6 +24,8 @@ func NewState() *State {
PrivateChannels: []*Channel{},
Guilds: []*Guild{},
},
guildMap: make(map[string]*Guild),
channelMap: make(map[string]*Channel),
}
}
@ -42,6 +44,8 @@ func (s *State) OnReady(r *Ready) error {
for _, c := range g.Channels {
c.GuildID = g.ID
}
s.guildMap[g.ID] = g
}
return nil
@ -54,30 +58,31 @@ func (s *State) GuildAdd(guild *Guild) error {
return ErrNilState
}
s.Lock()
defer s.Unlock()
// Otherwise, update the channels to point to the right guild
for _, c := range guild.Channels {
c.GuildID = guild.ID
}
// If the guild exists, replace it.
for i, g := range s.Guilds {
if g.ID == guild.ID {
// If this guild already exists with data, don't stomp on props
if g, err := s.Guild(guild.ID); err == nil {
s.Lock()
defer s.Unlock()
if !g.Unavailable {
guild.Members = g.Members
guild.Presences = g.Presences
guild.Channels = g.Channels
guild.VoiceStates = g.VoiceStates
}
s.Guilds[i] = guild
*g = *guild
return nil
}
}
s.Lock()
defer s.Unlock()
s.Guilds = append(s.Guilds, guild)
s.guildMap[guild.ID] = guild
return nil
}
@ -87,6 +92,11 @@ func (s *State) GuildRemove(guild *Guild) error {
return ErrNilState
}
_, err := s.Guild(guild.ID)
if err != nil {
return err
}
s.Lock()
defer s.Unlock()
@ -97,7 +107,9 @@ func (s *State) GuildRemove(guild *Guild) error {
}
}
return errors.New("Guild not found.")
delete(s.guildMap, guild.ID)
return nil
}
// Guild gets a guild by ID.
@ -112,11 +124,9 @@ func (s *State) Guild(guildID string) (*Guild, error) {
s.RLock()
defer s.RUnlock()
for _, g := range s.Guilds {
if g.ID == guildID {
if g, ok := s.guildMap[guildID]; ok {
return g, nil
}
}
return nil, errors.New("Guild not found.")
}
@ -205,20 +215,22 @@ func (s *State) ChannelAdd(channel *Channel) error {
return ErrNilState
}
if channel.IsPrivate {
if c, err := s.Channel(channel.ID); err == nil {
s.Lock()
defer s.Unlock()
// If the channel exists, replace it.
for i, c := range s.PrivateChannels {
if c.ID == channel.ID {
// Don't stomp on messages.
channel.Messages = c.Messages
s.PrivateChannels[i] = channel
channel.PermissionOverwrites = c.PermissionOverwrites
*c = *channel
return nil
}
}
s.Lock()
defer s.Unlock()
if channel.IsPrivate {
s.PrivateChannels = append(s.PrivateChannels, channel)
} else {
guild, err := s.Guild(channel.GuildID)
@ -226,22 +238,11 @@ func (s *State) ChannelAdd(channel *Channel) error {
return err
}
s.Lock()
defer s.Unlock()
// If the channel exists, replace it.
for i, c := range guild.Channels {
if c.ID == channel.ID {
// Don't stomp on messages.
channel.Messages = c.Messages
guild.Channels[i] = channel
return nil
}
}
guild.Channels = append(guild.Channels, channel)
}
s.channelMap[channel.ID] = channel
return nil
}
@ -251,6 +252,11 @@ func (s *State) ChannelRemove(channel *Channel) error {
return ErrNilState
}
_, err := s.Channel(channel.ID)
if err != nil {
return err
}
if channel.IsPrivate {
s.Lock()
defer s.Unlock()
@ -278,48 +284,21 @@ func (s *State) ChannelRemove(channel *Channel) error {
}
}
return errors.New("Channel not found.")
delete(s.channelMap, channel.ID)
return nil
}
// GuildChannel gets a channel by ID from a guild.
// This method is Deprecated, use Channel(channelID)
func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
if s == nil {
return nil, ErrNilState
}
guild, err := s.Guild(guildID)
if err != nil {
return nil, err
}
s.RLock()
defer s.RUnlock()
for _, c := range guild.Channels {
if c.ID == channelID {
return c, nil
}
}
return nil, errors.New("Channel not found.")
return s.Channel(channelID)
}
// PrivateChannel gets a private channel by ID.
// This method is Deprecated, use Channel(channelID)
func (s *State) PrivateChannel(channelID string) (*Channel, error) {
if s == nil {
return nil, ErrNilState
}
s.RLock()
defer s.RUnlock()
for _, c := range s.PrivateChannels {
if c.ID == channelID {
return c, nil
}
}
return nil, errors.New("Channel not found.")
return s.Channel(channelID)
}
// Channel gets a channel by ID, it will look in all guilds an private channels.
@ -328,18 +307,10 @@ func (s *State) Channel(channelID string) (*Channel, error) {
return nil, ErrNilState
}
c, err := s.PrivateChannel(channelID)
if err == nil {
if c, ok := s.channelMap[channelID]; ok {
return c, nil
}
for _, g := range s.Guilds {
c, err := s.GuildChannel(g.ID, channelID)
if err == nil {
return c, nil
}
}
return nil, errors.New("Channel not found.")
}

View file

@ -368,6 +368,9 @@ type State struct {
sync.RWMutex
Ready
MaxMessageCount int
guildMap map[string]*Guild
channelMap map[string]*Channel
}
// Constants for the different bit offsets of text channel permissions