Make state tracking optional.

This commit is contained in:
Chris Rhodes 2016-01-02 12:01:03 -08:00
parent b2ef55ae9c
commit 9ba6d5b7c1
4 changed files with 104 additions and 30 deletions

View file

@ -63,6 +63,7 @@ func New(args ...interface{}) (s *Session, err error) {
// Create an empty Session interface. // Create an empty Session interface.
s = &Session{ s = &Session{
State: NewState(), State: NewState(),
StateEnabled: true,
} }
// If no arguments are passed return the empty Session interface. // If no arguments are passed return the empty Session interface.

View file

@ -2,6 +2,8 @@ package discordgo
import "errors" import "errors"
var nilError error = errors.New("State not instantiated, please use discordgo.New() or assign session.State.")
// NewState creates an empty state. // NewState creates an empty state.
func NewState() *State { func NewState() *State {
return &State{ return &State{
@ -13,13 +15,22 @@ func NewState() *State {
} }
// OnReady takes a Ready event and updates all internal state. // OnReady takes a Ready event and updates all internal state.
func (s *State) OnReady(r *Ready) { func (s *State) OnReady(r *Ready) error {
if s == nil {
return nilError
}
s.Ready = *r s.Ready = *r
return nil
} }
// AddGuild adds a guild to the current world state, or // AddGuild adds a guild to the current world state, or
// updates it if it already exists. // updates it if it already exists.
func (s *State) AddGuild(guild *Guild) { func (s *State) AddGuild(guild *Guild) error {
if s == nil {
return nilError
}
for _, g := range s.Guilds { for _, g := range s.Guilds {
if g.ID == guild.ID { if g.ID == guild.ID {
// This could be a little faster ;) // This could be a little faster ;)
@ -29,14 +40,19 @@ func (s *State) AddGuild(guild *Guild) {
for _, c := range guild.Channels { for _, c := range guild.Channels {
s.AddChannel(&c) s.AddChannel(&c)
} }
return return nil
} }
} }
s.Guilds = append(s.Guilds, *guild) s.Guilds = append(s.Guilds, *guild)
return nil
} }
// RemoveGuild removes a guild from current world state. // RemoveGuild removes a guild from current world state.
func (s *State) RemoveGuild(guild *Guild) error { func (s *State) RemoveGuild(guild *Guild) error {
if s == nil {
return nilError
}
for i, g := range s.Guilds { for i, g := range s.Guilds {
if g.ID == guild.ID { if g.ID == guild.ID {
s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...) s.Guilds = append(s.Guilds[:i], s.Guilds[i+1:]...)
@ -46,11 +62,15 @@ func (s *State) RemoveGuild(guild *Guild) error {
return errors.New("Guild not found.") return errors.New("Guild not found.")
} }
// GetGuildByID gets a guild by ID. // Guild gets a guild by ID.
// Useful for querying if @me is in a guild: // Useful for querying if @me is in a guild:
// _, err := discordgo.Session.State.GetGuildById(guildID) // _, err := discordgo.Session.State.Guild(guildID)
// isInGuild := err == nil // isInGuild := err == nil
func (s *State) GetGuildByID(guildID string) (*Guild, error) { func (s *State) Guild(guildID string) (*Guild, error) {
if s == nil {
return nil, nilError
}
for _, g := range s.Guilds { for _, g := range s.Guilds {
if g.ID == guildID { if g.ID == guildID {
return &g, nil return &g, nil
@ -64,7 +84,11 @@ func (s *State) GetGuildByID(guildID string) (*Guild, error) {
// AddMember adds a member to the current world state, or // AddMember adds a member to the current world state, or
// updates it if it already exists. // updates it if it already exists.
func (s *State) AddMember(member *Member) error { func (s *State) AddMember(member *Member) error {
guild, err := s.GetGuildByID(member.GuildID) if s == nil {
return nilError
}
guild, err := s.Guild(member.GuildID)
if err != nil { if err != nil {
return err return err
} }
@ -82,7 +106,11 @@ func (s *State) AddMember(member *Member) error {
// RemoveMember removes a member from current world state. // RemoveMember removes a member from current world state.
func (s *State) RemoveMember(member *Member) error { func (s *State) RemoveMember(member *Member) error {
guild, err := s.GetGuildByID(member.GuildID) if s == nil {
return nilError
}
guild, err := s.Guild(member.GuildID)
if err != nil { if err != nil {
return err return err
} }
@ -93,12 +121,17 @@ func (s *State) RemoveMember(member *Member) error {
return nil return nil
} }
} }
return errors.New("Member not found.") return errors.New("Member not found.")
} }
// GetMemberByID gets a member by ID from a guild. // Member gets a member by ID from a guild.
func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) { func (s *State) Member(guildID string, userID string) (*Member, error) {
guild, err := s.GetGuildByID(guildID) if s == nil {
return nil, nilError
}
guild, err := s.Guild(guildID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,6 +141,7 @@ func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) {
return &m, nil return &m, nil
} }
} }
return nil, errors.New("Member not found.") return nil, errors.New("Member not found.")
} }
@ -116,6 +150,10 @@ func (s *State) GetMemberByID(guildID string, userID string) (*Member, error) {
// Channels may exist either as PrivateChannels or inside // Channels may exist either as PrivateChannels or inside
// a guild. // a guild.
func (s *State) AddChannel(channel *Channel) error { func (s *State) AddChannel(channel *Channel) error {
if s == nil {
return nilError
}
if channel.IsPrivate { if channel.IsPrivate {
for i, c := range s.PrivateChannels { for i, c := range s.PrivateChannels {
if c.ID == channel.ID { if c.ID == channel.ID {
@ -126,7 +164,7 @@ func (s *State) AddChannel(channel *Channel) error {
s.PrivateChannels = append(s.PrivateChannels, *channel) s.PrivateChannels = append(s.PrivateChannels, *channel)
} else { } else {
guild, err := s.GetGuildByID(channel.GuildID) guild, err := s.Guild(channel.GuildID)
if err != nil { if err != nil {
return err return err
} }
@ -145,6 +183,10 @@ func (s *State) AddChannel(channel *Channel) error {
// RemoveChannel removes a channel from current world state. // RemoveChannel removes a channel from current world state.
func (s *State) RemoveChannel(channel *Channel) error { func (s *State) RemoveChannel(channel *Channel) error {
if s == nil {
return nilError
}
if channel.IsPrivate { if channel.IsPrivate {
for i, c := range s.PrivateChannels { for i, c := range s.PrivateChannels {
if c.ID == channel.ID { if c.ID == channel.ID {
@ -153,7 +195,7 @@ func (s *State) RemoveChannel(channel *Channel) error {
} }
} }
} else { } else {
guild, err := s.GetGuildByID(channel.GuildID) guild, err := s.Guild(channel.GuildID)
if err != nil { if err != nil {
return err return err
} }
@ -169,9 +211,13 @@ func (s *State) RemoveChannel(channel *Channel) error {
return errors.New("Channel not found.") return errors.New("Channel not found.")
} }
// GetGuildChannelById gets a channel by ID from a guild. // GuildChannel gets a channel by ID from a guild.
func (s *State) GetGuildChannelByID(guildID string, channelID string) (*Channel, error) { func (s *State) GuildChannel(guildID string, channelID string) (*Channel, error) {
guild, err := s.GetGuildByID(guildID) if s == nil {
return nil, nilError
}
guild, err := s.Guild(guildID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -181,15 +227,21 @@ func (s *State) GetGuildChannelByID(guildID string, channelID string) (*Channel,
return &c, nil return &c, nil
} }
} }
return nil, errors.New("Channel not found.") return nil, errors.New("Channel not found.")
} }
// GetPrivateChannelByID gets a private channel by ID. // PrivateChannel gets a private channel by ID.
func (s *State) GetPrivateChannelByID(channelID string) (*Channel, error) { func (s *State) PrivateChannel(channelID string) (*Channel, error) {
if s == nil {
return nil, nilError
}
for _, c := range s.PrivateChannels { for _, c := range s.PrivateChannels {
if c.ID == channelID { if c.ID == channelID {
return &c, nil return &c, nil
} }
} }
return nil, errors.New("Channel not found.") return nil, errors.New("Channel not found.")
} }

View file

@ -87,6 +87,7 @@ type Session struct {
// Managed state object, updated with events. // Managed state object, updated with events.
State *State State *State
StateEnabled bool
} }
// A Message stores all data related to a specific Discord message. // A Message stores all data related to a specific Discord message.

View file

@ -153,7 +153,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "READY": case "READY":
var st Ready var st Ready
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.OnReady(&st) s.State.OnReady(&st)
}
if s.OnReady != nil { if s.OnReady != nil {
s.OnReady(s, st) s.OnReady(s, st)
} }
@ -238,7 +240,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "CHANNEL_CREATE": case "CHANNEL_CREATE":
var st Channel var st Channel
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.AddChannel(&st) s.State.AddChannel(&st)
}
if s.OnChannelCreate != nil { if s.OnChannelCreate != nil {
s.OnChannelCreate(s, st) s.OnChannelCreate(s, st)
} }
@ -247,7 +251,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "CHANNEL_UPDATE": case "CHANNEL_UPDATE":
var st Channel var st Channel
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.AddChannel(&st) s.State.AddChannel(&st)
}
if s.OnChannelUpdate != nil { if s.OnChannelUpdate != nil {
s.OnChannelUpdate(s, st) s.OnChannelUpdate(s, st)
} }
@ -256,7 +262,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "CHANNEL_DELETE": case "CHANNEL_DELETE":
var st Channel var st Channel
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.RemoveChannel(&st) s.State.RemoveChannel(&st)
}
if s.OnChannelDelete != nil { if s.OnChannelDelete != nil {
s.OnChannelDelete(s, st) s.OnChannelDelete(s, st)
} }
@ -265,7 +273,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "GUILD_CREATE": case "GUILD_CREATE":
var st Guild var st Guild
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.AddGuild(&st) s.State.AddGuild(&st)
}
if s.OnGuildCreate != nil { if s.OnGuildCreate != nil {
s.OnGuildCreate(s, st) s.OnGuildCreate(s, st)
} }
@ -274,7 +284,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "GUILD_UPDATE": case "GUILD_UPDATE":
var st Guild var st Guild
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.AddGuild(&st) s.State.AddGuild(&st)
}
if s.OnGuildCreate != nil { if s.OnGuildCreate != nil {
s.OnGuildUpdate(s, st) s.OnGuildUpdate(s, st)
} }
@ -283,7 +295,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "GUILD_DELETE": case "GUILD_DELETE":
var st Guild var st Guild
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.RemoveGuild(&st) s.State.RemoveGuild(&st)
}
if s.OnGuildDelete != nil { if s.OnGuildDelete != nil {
s.OnGuildDelete(s, st) s.OnGuildDelete(s, st)
} }
@ -292,7 +306,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "GUILD_MEMBER_ADD": case "GUILD_MEMBER_ADD":
var st Member var st Member
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.AddMember(&st) s.State.AddMember(&st)
}
if s.OnGuildMemberAdd != nil { if s.OnGuildMemberAdd != nil {
s.OnGuildMemberAdd(s, st) s.OnGuildMemberAdd(s, st)
} }
@ -301,7 +317,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "GUILD_MEMBER_REMOVE": case "GUILD_MEMBER_REMOVE":
var st Member var st Member
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.RemoveMember(&st) s.State.RemoveMember(&st)
}
if s.OnGuildMemberRemove != nil { if s.OnGuildMemberRemove != nil {
s.OnGuildMemberRemove(s, st) s.OnGuildMemberRemove(s, st)
} }
@ -310,7 +328,9 @@ func (s *Session) event(messageType int, message []byte) (err error) {
case "GUILD_MEMBER_UPDATE": case "GUILD_MEMBER_UPDATE":
var st Member var st Member
if err = unmarshalEvent(e, &st); err == nil { if err = unmarshalEvent(e, &st); err == nil {
if s.StateEnabled {
s.State.AddMember(&st) s.State.AddMember(&st)
}
if s.OnGuildMemberUpdate != nil { if s.OnGuildMemberUpdate != nil {
s.OnGuildMemberUpdate(s, st) s.OnGuildMemberUpdate(s, st)
} }