From 9ba6d5b7c12d846894cd727984228b7a129e1daa Mon Sep 17 00:00:00 2001 From: Chris Rhodes Date: Sat, 2 Jan 2016 12:01:03 -0800 Subject: [PATCH] Make state tracking optional. --- discord.go | 3 +- state.go | 88 +++++++++++++++++++++++++++++++++++++++++++----------- structs.go | 3 +- wsapi.go | 40 ++++++++++++++++++------- 4 files changed, 104 insertions(+), 30 deletions(-) diff --git a/discord.go b/discord.go index 799e725..50d3686 100644 --- a/discord.go +++ b/discord.go @@ -62,7 +62,8 @@ func New(args ...interface{}) (s *Session, err error) { // Create an empty Session interface. s = &Session{ - State: NewState(), + State: NewState(), + StateEnabled: true, } // If no arguments are passed return the empty Session interface. diff --git a/state.go b/state.go index 5645aa3..69d3746 100644 --- a/state.go +++ b/state.go @@ -2,6 +2,8 @@ package discordgo import "errors" +var nilError error = errors.New("State not instantiated, please use discordgo.New() or assign session.State.") + // NewState creates an empty state. func NewState() *State { return &State{ @@ -13,13 +15,22 @@ func NewState() *State { } // OnReady takes a Ready event and updates all internal state. -func (s *State) OnReady(r *Ready) { +func (s *State) OnReady(r *Ready) error { + if s == nil { + return nilError + } + s.Ready = *r + return nil } // AddGuild adds a guild to the current world state, or // updates it if it already exists. -func (s *State) AddGuild(guild *Guild) { +func (s *State) AddGuild(guild *Guild) error { + if s == nil { + return nilError + } + for _, g := range s.Guilds { if g.ID == guild.ID { // This could be a little faster ;) @@ -29,14 +40,19 @@ func (s *State) AddGuild(guild *Guild) { for _, c := range guild.Channels { s.AddChannel(&c) } - return + return nil } } s.Guilds = append(s.Guilds, *guild) + return nil } // RemoveGuild removes a guild from current world state. func (s *State) RemoveGuild(guild *Guild) error { + if s == nil { + return nilError + } + for i, g := range s.Guilds { if g.ID == guild.ID { s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...) @@ -46,11 +62,15 @@ func (s *State) RemoveGuild(guild *Guild) error { return errors.New("Guild not found.") } -// GetGuildByID gets a guild by ID. +// Guild gets a guild by ID. // Useful for querying if @me is in a guild: -// _, err := discordgo.Session.State.GetGuildById(guildID) +// _, err := discordgo.Session.State.Guild(guildID) // isInGuild := err == nil -func (s *State) GetGuildByID(guildID string) (*Guild, error) { +func (s *State) Guild(guildID string) (*Guild, error) { + if s == nil { + return nil, nilError + } + for _, g := range s.Guilds { if g.ID == guildID { return &g, nil @@ -64,7 +84,11 @@ func (s *State) GetGuildByID(guildID string) (*Guild, error) { // AddMember adds a member to the current world state, or // updates it if it already exists. func (s *State) AddMember(member *Member) error { - guild, err := s.GetGuildByID(member.GuildID) + if s == nil { + return nilError + } + + guild, err := s.Guild(member.GuildID) if err != nil { return err } @@ -82,7 +106,11 @@ func (s *State) AddMember(member *Member) error { // RemoveMember removes a member from current world state. func (s *State) RemoveMember(member *Member) error { - guild, err := s.GetGuildByID(member.GuildID) + if s == nil { + return nilError + } + + guild, err := s.Guild(member.GuildID) if err != nil { return err } @@ -93,12 +121,17 @@ func (s *State) RemoveMember(member *Member) error { return nil } } + return errors.New("Member not found.") } -// GetMemberByID gets a member by ID from a guild. -func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) { - guild, err := s.GetGuildByID(guildID) +// Member gets a member by ID from a guild. +func (s *State) Member(guildID string, userID string) (*Member, error) { + if s == nil { + return nil, nilError + } + + guild, err := s.Guild(guildID) if err != nil { return nil, err } @@ -108,6 +141,7 @@ func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) { return &m, nil } } + return nil, errors.New("Member not found.") } @@ -116,6 +150,10 @@ func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) { // Channels may exist either as PrivateChannels or inside // a guild. func (s *State) AddChannel(channel *Channel) error { + if s == nil { + return nilError + } + if channel.IsPrivate { for i, c := range s.PrivateChannels { if c.ID == channel.ID { @@ -126,7 +164,7 @@ func (s *State) AddChannel(channel *Channel) error { s.PrivateChannels = append(s.PrivateChannels, *channel) } else { - guild, err := s.GetGuildByID(channel.GuildID) + guild, err := s.Guild(channel.GuildID) if err != nil { return err } @@ -145,6 +183,10 @@ func (s *State) AddChannel(channel *Channel) error { // RemoveChannel removes a channel from current world state. func (s *State) RemoveChannel(channel *Channel) error { + if s == nil { + return nilError + } + if channel.IsPrivate { for i, c := range s.PrivateChannels { if c.ID == channel.ID { @@ -153,7 +195,7 @@ func (s *State) RemoveChannel(channel *Channel) error { } } } else { - guild, err := s.GetGuildByID(channel.GuildID) + guild, err := s.Guild(channel.GuildID) if err != nil { return err } @@ -169,9 +211,13 @@ func (s *State) RemoveChannel(channel *Channel) error { return errors.New("Channel not found.") } -// GetGuildChannelById gets a channel by ID from a guild. -func (s *State) GetGuildChannelByID(guildID string, channelID string) (*Channel, error) { - guild, err := s.GetGuildByID(guildID) +// GuildChannel gets a channel by ID from a guild. +func (s *State) GuildChannel(guildID string, channelID string) (*Channel, error) { + if s == nil { + return nil, nilError + } + + guild, err := s.Guild(guildID) if err != nil { return nil, err } @@ -181,15 +227,21 @@ func (s *State) GetGuildChannelByID(guildID string, channelID string) (*Channel, return &c, nil } } + return nil, errors.New("Channel not found.") } -// GetPrivateChannelByID gets a private channel by ID. -func (s *State) GetPrivateChannelByID(channelID string) (*Channel, error) { +// PrivateChannel gets a private channel by ID. +func (s *State) PrivateChannel(channelID string) (*Channel, error) { + if s == nil { + return nil, nilError + } + for _, c := range s.PrivateChannels { if c.ID == channelID { return &c, nil } } + return nil, errors.New("Channel not found.") } diff --git a/structs.go b/structs.go index b716624..c1b3e14 100644 --- a/structs.go +++ b/structs.go @@ -86,7 +86,8 @@ type Session struct { UDPConn *net.UDPConn // Managed state object, updated with events. - State *State + State *State + StateEnabled bool } // A Message stores all data related to a specific Discord message. diff --git a/wsapi.go b/wsapi.go index 7b40eda..cbe3332 100644 --- a/wsapi.go +++ b/wsapi.go @@ -153,7 +153,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "READY": var st Ready if err = unmarshalEvent(e, &st); err == nil { - s.State.OnReady(&st) + if s.StateEnabled { + s.State.OnReady(&st) + } if s.OnReady != nil { s.OnReady(s, st) } @@ -238,7 +240,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "CHANNEL_CREATE": var st Channel if err = unmarshalEvent(e, &st); err == nil { - s.State.AddChannel(&st) + if s.StateEnabled { + s.State.AddChannel(&st) + } if s.OnChannelCreate != nil { s.OnChannelCreate(s, st) } @@ -247,7 +251,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "CHANNEL_UPDATE": var st Channel if err = unmarshalEvent(e, &st); err == nil { - s.State.AddChannel(&st) + if s.StateEnabled { + s.State.AddChannel(&st) + } if s.OnChannelUpdate != nil { s.OnChannelUpdate(s, st) } @@ -256,7 +262,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "CHANNEL_DELETE": var st Channel if err = unmarshalEvent(e, &st); err == nil { - s.State.RemoveChannel(&st) + if s.StateEnabled { + s.State.RemoveChannel(&st) + } if s.OnChannelDelete != nil { s.OnChannelDelete(s, st) } @@ -265,7 +273,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "GUILD_CREATE": var st Guild if err = unmarshalEvent(e, &st); err == nil { - s.State.AddGuild(&st) + if s.StateEnabled { + s.State.AddGuild(&st) + } if s.OnGuildCreate != nil { s.OnGuildCreate(s, st) } @@ -274,7 +284,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "GUILD_UPDATE": var st Guild if err = unmarshalEvent(e, &st); err == nil { - s.State.AddGuild(&st) + if s.StateEnabled { + s.State.AddGuild(&st) + } if s.OnGuildCreate != nil { s.OnGuildUpdate(s, st) } @@ -283,7 +295,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "GUILD_DELETE": var st Guild if err = unmarshalEvent(e, &st); err == nil { - s.State.RemoveGuild(&st) + if s.StateEnabled { + s.State.RemoveGuild(&st) + } if s.OnGuildDelete != nil { s.OnGuildDelete(s, st) } @@ -292,7 +306,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "GUILD_MEMBER_ADD": var st Member if err = unmarshalEvent(e, &st); err == nil { - s.State.AddMember(&st) + if s.StateEnabled { + s.State.AddMember(&st) + } if s.OnGuildMemberAdd != nil { s.OnGuildMemberAdd(s, st) } @@ -301,7 +317,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "GUILD_MEMBER_REMOVE": var st Member if err = unmarshalEvent(e, &st); err == nil { - s.State.RemoveMember(&st) + if s.StateEnabled { + s.State.RemoveMember(&st) + } if s.OnGuildMemberRemove != nil { s.OnGuildMemberRemove(s, st) } @@ -310,7 +328,9 @@ func (s *Session) event(messageType int, message []byte) (err error) { case "GUILD_MEMBER_UPDATE": var st Member if err = unmarshalEvent(e, &st); err == nil { - s.State.AddMember(&st) + if s.StateEnabled { + s.State.AddMember(&st) + } if s.OnGuildMemberUpdate != nil { s.OnGuildMemberUpdate(s, st) }