diff --git a/state.go b/state.go index 4ebfb1e..b545c53 100644 --- a/state.go +++ b/state.go @@ -42,6 +42,7 @@ type State struct { guildMap map[string]*Guild channelMap map[string]*Channel + memberMap map[string]map[string]*Member } // NewState creates an empty state. @@ -59,9 +60,18 @@ func NewState() *State { TrackPresences: true, guildMap: make(map[string]*Guild), channelMap: make(map[string]*Channel), + memberMap: make(map[string]map[string]*Member), } } +func (s *State) createMemberMap(guild *Guild) { + members := make(map[string]*Member) + for _, m := range guild.Members { + members[m.User.ID] = m + } + s.memberMap[guild.ID] = members +} + // GuildAdd adds a guild to the current world state, or // updates it if it already exists. func (s *State) GuildAdd(guild *Guild) error { @@ -77,6 +87,14 @@ func (s *State) GuildAdd(guild *Guild) error { s.channelMap[c.ID] = c } + // If this guild contains a new member slice, we must regenerate the member map so the pointers stay valid + if guild.Members != nil { + s.createMemberMap(guild) + } else if _, ok := s.memberMap[guild.ID]; !ok { + // Even if we have no new member slice, we still initialize the member map for this guild if it doesn't exist + s.memberMap[guild.ID] = make(map[string]*Member) + } + if g, ok := s.guildMap[guild.ID]; ok { // We are about to replace `g` in the state with `guild`, but first we need to // make sure we preserve any fields that the `guild` doesn't contain from `g`. @@ -271,14 +289,19 @@ func (s *State) MemberAdd(member *Member) error { s.Lock() defer s.Unlock() - for i, m := range guild.Members { - if m.User.ID == member.User.ID { - guild.Members[i] = member - return nil - } + members, ok := s.memberMap[member.GuildID] + if !ok { + return ErrStateNotFound + } + + m, ok := members[member.User.ID] + if !ok { + members[member.User.ID] = member + guild.Members = append(guild.Members, member) + } else { + *m = *member // Update the actual data, which will also update the member pointer in the slice } - guild.Members = append(guild.Members, member) return nil } @@ -296,6 +319,17 @@ func (s *State) MemberRemove(member *Member) error { s.Lock() defer s.Unlock() + members, ok := s.memberMap[member.GuildID] + if !ok { + return ErrStateNotFound + } + + _, ok = members[member.User.ID] + if !ok { + return ErrStateNotFound + } + delete(members, member.User.ID) + for i, m := range guild.Members { if m.User.ID == member.User.ID { guild.Members = append(guild.Members[:i], guild.Members[i+1:]...) @@ -312,18 +346,17 @@ func (s *State) Member(guildID, userID string) (*Member, error) { return nil, ErrNilState } - guild, err := s.Guild(guildID) - if err != nil { - return nil, err - } - s.RLock() defer s.RUnlock() - for _, m := range guild.Members { - if m.User.ID == userID { - return m, nil - } + members, ok := s.memberMap[guildID] + if !ok { + return nil, ErrStateNotFound + } + + m, ok := members[userID] + if ok { + return m, nil } return nil, ErrStateNotFound @@ -735,6 +768,7 @@ func (s *State) onReady(se *Session, r *Ready) (err error) { for _, g := range s.Guilds { s.guildMap[g.ID] = g + s.createMemberMap(g) for _, c := range g.Channels { s.channelMap[c.ID] = c