Add fast lookups for guilds and channels in state.

This commit is contained in:
Chris Rhodes 2016-04-10 16:17:05 -07:00
parent 6101abeb27
commit fa63c00fa2
2 changed files with 60 additions and 86 deletions

123
state.go
View file

@ -24,6 +24,8 @@ func NewState() *State {
PrivateChannels: []*Channel{}, PrivateChannels: []*Channel{},
Guilds: []*Guild{}, Guilds: []*Guild{},
}, },
guildMap: make(map[string]*Guild),
channelMap: make(map[string]*Channel),
} }
} }
@ -42,6 +44,8 @@ func (s *State) OnReady(r *Ready) error {
for _, c := range g.Channels { for _, c := range g.Channels {
c.GuildID = g.ID c.GuildID = g.ID
} }
s.guildMap[g.ID] = g
} }
return nil return nil
@ -54,30 +58,31 @@ func (s *State) GuildAdd(guild *Guild) error {
return ErrNilState return ErrNilState
} }
s.Lock()
defer s.Unlock()
// Otherwise, update the channels to point to the right guild // Otherwise, update the channels to point to the right guild
for _, c := range guild.Channels { for _, c := range guild.Channels {
c.GuildID = guild.ID c.GuildID = guild.ID
} }
// If the guild exists, replace it. if g, err := s.Guild(guild.ID); err == nil {
for i, g := range s.Guilds { s.Lock()
if g.ID == guild.ID { defer s.Unlock()
// If this guild already exists with data, don't stomp on props
if !g.Unavailable { if !g.Unavailable {
guild.Members = g.Members guild.Members = g.Members
guild.Presences = g.Presences guild.Presences = g.Presences
guild.Channels = g.Channels guild.Channels = g.Channels
guild.VoiceStates = g.VoiceStates guild.VoiceStates = g.VoiceStates
} }
s.Guilds[i] = guild *g = *guild
return nil return nil
} }
}
s.Lock()
defer s.Unlock()
s.Guilds = append(s.Guilds, guild) s.Guilds = append(s.Guilds, guild)
s.guildMap[guild.ID] = guild
return nil return nil
} }
@ -87,6 +92,11 @@ func (s *State) GuildRemove(guild *Guild) error {
return ErrNilState return ErrNilState
} }
_, err := s.Guild(guild.ID)
if err != nil {
return err
}
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -97,7 +107,9 @@ func (s *State) GuildRemove(guild *Guild) error {
} }
} }
return errors.New("Guild not found.") delete(s.guildMap, guild.ID)
return nil
} }
// Guild gets a guild by ID. // Guild gets a guild by ID.
@ -112,11 +124,9 @@ func (s *State) Guild(guildID string) (*Guild, error) {
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
for _, g := range s.Guilds { if g, ok := s.guildMap[guildID]; ok {
if g.ID == guildID {
return g, nil return g, nil
} }
}
return nil, errors.New("Guild not found.") return nil, errors.New("Guild not found.")
} }
@ -205,20 +215,22 @@ func (s *State) ChannelAdd(channel *Channel) error {
return ErrNilState return ErrNilState
} }
if channel.IsPrivate { if c, err := s.Channel(channel.ID); err == nil {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
// If the channel exists, replace it.
for i, c := range s.PrivateChannels {
if c.ID == channel.ID {
// Don't stomp on messages.
channel.Messages = c.Messages channel.Messages = c.Messages
s.PrivateChannels[i] = channel channel.PermissionOverwrites = c.PermissionOverwrites
*c = *channel
return nil return nil
} }
}
s.Lock()
defer s.Unlock()
if channel.IsPrivate {
s.PrivateChannels = append(s.PrivateChannels, channel) s.PrivateChannels = append(s.PrivateChannels, channel)
} else { } else {
guild, err := s.Guild(channel.GuildID) guild, err := s.Guild(channel.GuildID)
@ -226,22 +238,11 @@ func (s *State) ChannelAdd(channel *Channel) error {
return err return err
} }
s.Lock()
defer s.Unlock()
// If the channel exists, replace it.
for i, c := range guild.Channels {
if c.ID == channel.ID {
// Don't stomp on messages.
channel.Messages = c.Messages
guild.Channels[i] = channel
return nil
}
}
guild.Channels = append(guild.Channels, channel) guild.Channels = append(guild.Channels, channel)
} }
s.channelMap[channel.ID] = channel
return nil return nil
} }
@ -251,6 +252,11 @@ func (s *State) ChannelRemove(channel *Channel) error {
return ErrNilState return ErrNilState
} }
_, err := s.Channel(channel.ID)
if err != nil {
return err
}
if channel.IsPrivate { if channel.IsPrivate {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -278,48 +284,21 @@ func (s *State) ChannelRemove(channel *Channel) error {
} }
} }
return errors.New("Channel not found.") delete(s.channelMap, channel.ID)
return nil
} }
// GuildChannel gets a channel by ID from a guild. // GuildChannel gets a channel by ID from a guild.
// This method is Deprecated, use Channel(channelID)
func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) { func (s *State) GuildChannel(guildID, channelID string) (*Channel, error) {
if s == nil { return s.Channel(channelID)
return nil, ErrNilState
}
guild, err := s.Guild(guildID)
if err != nil {
return nil, err
}
s.RLock()
defer s.RUnlock()
for _, c := range guild.Channels {
if c.ID == channelID {
return c, nil
}
}
return nil, errors.New("Channel not found.")
} }
// PrivateChannel gets a private channel by ID. // PrivateChannel gets a private channel by ID.
// This method is Deprecated, use Channel(channelID)
func (s *State) PrivateChannel(channelID string) (*Channel, error) { func (s *State) PrivateChannel(channelID string) (*Channel, error) {
if s == nil { return s.Channel(channelID)
return nil, ErrNilState
}
s.RLock()
defer s.RUnlock()
for _, c := range s.PrivateChannels {
if c.ID == channelID {
return c, nil
}
}
return nil, errors.New("Channel not found.")
} }
// Channel gets a channel by ID, it will look in all guilds an private channels. // Channel gets a channel by ID, it will look in all guilds an private channels.
@ -328,18 +307,10 @@ func (s *State) Channel(channelID string) (*Channel, error) {
return nil, ErrNilState return nil, ErrNilState
} }
c, err := s.PrivateChannel(channelID) if c, ok := s.channelMap[channelID]; ok {
if err == nil {
return c, nil return c, nil
} }
for _, g := range s.Guilds {
c, err := s.GuildChannel(g.ID, channelID)
if err == nil {
return c, nil
}
}
return nil, errors.New("Channel not found.") return nil, errors.New("Channel not found.")
} }

View file

@ -368,6 +368,9 @@ type State struct {
sync.RWMutex sync.RWMutex
Ready Ready
MaxMessageCount int MaxMessageCount int
guildMap map[string]*Guild
channelMap map[string]*Channel
} }
// Constants for the different bit offsets of text channel permissions // Constants for the different bit offsets of text channel permissions