diff --git a/discord.go b/discord.go index f683bff..dd436c8 100644 --- a/discord.go +++ b/discord.go @@ -144,6 +144,11 @@ func (s *Session) AddHandler(handler interface{}) { eventType := handlerType.In(1) + // Support handlers of type interface{}, this is a special handler, which is triggered on every event. + if eventType.Kind() == reflect.Interface { + eventType = nil + } + handlers := s.Handlers[eventType] if handlers == nil { handlers = []interface{}{} @@ -153,39 +158,39 @@ func (s *Session) AddHandler(handler interface{}) { s.Handlers[eventType] = handlers } -func (s *Session) Handle(event interface{}) (handled bool) { +func (s *Session) handle(event interface{}) { s.RLock() defer s.RUnlock() - eventType := reflect.TypeOf(event) - - handlers, ok := s.Handlers[eventType] - if !ok { - return + if handlers, ok := s.Handlers[reflect.TypeOf(event)]; ok { + for _, handler := range handlers { + reflect.ValueOf(handler).Call([]reflect.Value{reflect.ValueOf(s), reflect.ValueOf(event)}) + } } - for _, handler := range handlers { - reflect.ValueOf(handler).Call([]reflect.Value{reflect.ValueOf(s), reflect.ValueOf(event)}) - handled = true + if handlers, ok := s.Handlers[nil]; ok { + for _, handler := range handlers { + reflect.ValueOf(handler).Call([]reflect.Value{reflect.ValueOf(s), reflect.ValueOf(event)}) + } } - - return } // initialize adds all internal handlers and state tracking handlers. func (s *Session) initialize() { s.Lock() - defer s.Unlock() - s.Handlers = map[interface{}][]interface{}{} + s.Unlock() + + s.AddHandler(s.onEvent) s.AddHandler(s.onReady) s.AddHandler(s.onVoiceServerUpdate) s.AddHandler(s.onVoiceStateUpdate) + s.AddHandler(s.State.onInterface) +} - s.AddHandler(s.State.onReady) - s.AddHandler(s.State.onMessageCreate) - s.AddHandler(s.State.onMessageUpdate) - s.AddHandler(s.State.onMessageDelete) +// onEvent handles events that are unhandled or errored while unmarshalling +func (s *Session) onEvent(se *Session, e *Event) { + printEvent(e) } // onReady handles the ready event. diff --git a/events.go b/events.go new file mode 100644 index 0000000..d81ecef --- /dev/null +++ b/events.go @@ -0,0 +1,100 @@ +package discordgo + +// Connect is an empty struct for an event. +type Connect struct{} + +// Disconnect is an empty struct for an event. +type Disconnect struct{} + +// MessageCreate is a wrapper struct for an event. +type MessageCreate struct { + *Message +} + +// MessageUpdate is a wrapper struct for an event. +type MessageUpdate struct { + *Message +} + +// MessageDelete is a wrapper struct for an event. +type MessageDelete struct { + *Message +} + +// ChannelCreate is a wrapper struct for an event. +type ChannelCreate struct { + *Channel +} + +// ChannelUpdate is a wrapper struct for an event. +type ChannelUpdate struct { + *Channel +} + +// ChannelDelete is a wrapper struct for an event. +type ChannelDelete struct { + *Channel +} + +// GuildCreate is a wrapper struct for an event. +type GuildCreate struct { + *Guild +} + +// GuildUpdate is a wrapper struct for an event. +type GuildUpdate struct { + *Guild +} + +// GuildDelete is a wrapper struct for an event. +type GuildDelete struct { + *Guild +} + +// GuildBanAdd is a wrapper struct for an event. +type GuildBanAdd struct { + *GuildBan +} + +// GuildBanRemove is a wrapper struct for an event. +type GuildBanRemove struct { + *GuildBan +} + +// GuildMemberAdd is a wrapper struct for an event. +type GuildMemberAdd struct { + *Member +} + +// GuildMemberUpdate is a wrapper struct for an event. +type GuildMemberUpdate struct { + *Member +} + +// GuildMemberRemove is a wrapper struct for an event. +type GuildMemberRemove struct { + *Member +} + +// GuildRoleCreate is a wrapper struct for an event. +type GuildRoleCreate struct { + *GuildRole +} + +// GuildRoleUpdate is a wrapper struct for an event. +type GuildRoleUpdate struct { + *GuildRole +} + +// VoiceStateUpdate is a wrapper struct for an event. +type VoiceStateUpdate struct { + *VoiceState +} + +// UserUpdate is a wrapper struct for an event. +type UserUpdate struct { + *UserUpdate +} + +// UserSettingsUpdate is a map for an event. +type UserSettingsUpdate map[string]interface{} diff --git a/examples/api_basic/api_basic.go b/examples/api_basic/api_basic.go index aea4e56..733cece 100644 --- a/examples/api_basic/api_basic.go +++ b/examples/api_basic/api_basic.go @@ -23,9 +23,10 @@ func main() { // Create a new Discord Session interface and set a handler for the // OnMessageCreate event that happens for every new message on any channel - dg := discordgo.Session{ - OnMessageCreate: messageCreate, - } + dg := discordgo.Session{} + + // Register messageCreate as a callback for the messageCreate events. + dg.AddHandler(messageCreate) // Login to the Discord server and store the authentication token err = dg.Login(os.Args[1], os.Args[2]) @@ -46,9 +47,9 @@ func main() { return } -// This function will be called (due to above assignment) every time a new +// This function will be called (due to AddHandler above) every time a new // message is created on any channel that the autenticated user has access to. -func messageCreate(s *discordgo.Session, m *discordgo.Message) { +func messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { // Print message to stdout. fmt.Printf("%20s %20s %20s > %s\n", m.ChannelID, time.Now().Format(time.Stamp), m.Author.Username, m.Content) diff --git a/examples/new_basic/new_basic.go b/examples/new_basic/new_basic.go index 43492ec..6d524ac 100644 --- a/examples/new_basic/new_basic.go +++ b/examples/new_basic/new_basic.go @@ -28,8 +28,8 @@ func main() { return } - // Register messageCreate as a callback for the OnMessageCreate event. - dg.OnMessageCreate = messageCreate + // Register messageCreate as a callback for the messageCreate events. + dg.AddHandler(messageCreate) // Open the websocket and begin listening. dg.Open() @@ -40,9 +40,9 @@ func main() { return } -// This function will be called (due to above assignment) every time a new +// This function will be called (due to AddHandler above) every time a new // message is created on any channel that the autenticated user has access to. -func messageCreate(s *discordgo.Session, m *discordgo.Message) { +func messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { // Print message to stdout. fmt.Printf("%20s %20s %20s > %s\n", m.ChannelID, time.Now().Format(time.Stamp), m.Author.Username, m.Content) diff --git a/state.go b/state.go index 0fe8a95..8074853 100644 --- a/state.go +++ b/state.go @@ -35,6 +35,7 @@ func (s *State) OnReady(r *Ready) error { if s == nil { return ErrNilState } + s.Lock() defer s.Unlock() @@ -48,6 +49,7 @@ func (s *State) GuildAdd(guild *Guild) error { if s == nil { return ErrNilState } + s.Lock() defer s.Unlock() @@ -73,6 +75,7 @@ func (s *State) GuildRemove(guild *Guild) error { if s == nil { return ErrNilState } + s.Lock() defer s.Unlock() @@ -94,6 +97,7 @@ func (s *State) Guild(guildID string) (*Guild, error) { if s == nil { return nil, ErrNilState } + s.RLock() defer s.RUnlock() @@ -294,6 +298,7 @@ func (s *State) PrivateChannel(channelID string) (*Channel, error) { if s == nil { return nil, ErrNilState } + s.RLock() defer s.RUnlock() @@ -429,6 +434,7 @@ func (s *State) MessageRemove(message *Message) error { if s == nil { return ErrNilState } + c, err := s.Channel(message.ChannelID) if err != nil { return err @@ -452,6 +458,7 @@ func (s *State) Message(channelID, messageID string) (*Message, error) { if s == nil { return nil, ErrNilState } + c, err := s.Channel(channelID) if err != nil { return nil, err @@ -469,30 +476,40 @@ func (s *State) Message(channelID, messageID string) (*Message, error) { return nil, errors.New("Message not found.") } -// onReady handles the ready event. -func (s *State) onReady(se *Session, r *Ready) { - if se.StateEnabled { - s.OnReady(r) +// onInterface handles all events related to states. +func (s *State) onInterface(se *Session, i interface{}) { + if s == nil || !se.StateEnabled { + return } -} -// onMessageCreate handles the messageCreate event. -func (s *State) onMessageCreate(se *Session, m *MessageCreate) { - if se.StateEnabled { - s.MessageAdd(m.Message) - } -} - -// onMessageUpdate handles the messageUpdate event. -func (s *State) onMessageUpdate(se *Session, m *MessageUpdate) { - if se.StateEnabled { - s.MessageAdd(m.Message) - } -} - -// onMessageDelete handles the messageDelete event. -func (s *State) onMessageDelete(se *Session, m *MessageDelete) { - if se.StateEnabled { - s.MessageRemove(m.Message) + switch t := i.(type) { + case *Ready: + s.OnReady(t) + case *GuildCreate: + s.GuildAdd(t.Guild) + case *GuildUpdate: + s.GuildAdd(t.Guild) + case *GuildDelete: + s.GuildRemove(t.Guild) + case *GuildMemberAdd: + s.MemberAdd(t.Member) + case *GuildMemberUpdate: + s.MemberAdd(t.Member) + case *GuildMemberRemove: + s.MemberRemove(t.Member) + case *GuildEmojisUpdate: + s.EmojisAdd(t.GuildID, t.Emojis) + case *ChannelCreate: + s.ChannelAdd(t.Channel) + case *ChannelUpdate: + s.ChannelAdd(t.Channel) + case *ChannelDelete: + s.ChannelRemove(t.Channel) + case *MessageCreate: + s.MessageAdd(t.Message) + case *MessageUpdate: + s.MessageAdd(t.Message) + case *MessageDelete: + s.MessageRemove(t.Message) } } diff --git a/wsapi.go b/wsapi.go index 281fe98..6584072 100644 --- a/wsapi.go +++ b/wsapi.go @@ -18,6 +18,7 @@ import ( "fmt" "io" "net/http" + "reflect" "runtime" "time" @@ -86,7 +87,7 @@ func (s *Session) Open() (err error) { s.Unlock() - s.Handle(&Connect{}) + s.handle(&Connect{}) return } @@ -110,7 +111,7 @@ func (s *Session) Close() (err error) { s.Unlock() - s.Handle(&Disconnect{}) + s.handle(&Disconnect{}) return } @@ -243,6 +244,38 @@ func (s *Session) UpdateStatus(idle int, game string) (err error) { return } +// eventToInterface is a mapping of Discord WSAPI events to their +// DiscordGo event container. +var eventToInterface = map[string]interface{}{ + "CHANNEL_CREATE": ChannelCreate{}, + "CHANNEL_UPDATE": ChannelUpdate{}, + "CHANNEL_DELETE": ChannelDelete{}, + "GUILD_CREATE": GuildCreate{}, + "GUILD_UPDATE": GuildUpdate{}, + "GUILD_DELETE": GuildDelete{}, + "GUILD_BAN_ADD": GuildBanAdd{}, + "GUILD_BAN_REMOVE": GuildBanRemove{}, + "GUILD_MEMBER_ADD": GuildMemberAdd{}, + "GUILD_MEMBER_UPDATE": GuildMemberUpdate{}, + "GUILD_MEMBER_REMOVE": GuildMemberRemove{}, + "GUILD_ROLE_CREATE": GuildRoleCreate{}, + "GUILD_ROLE_UPDATE": GuildRoleUpdate{}, + "GUILD_ROLE_DELETE": GuildRoleDelete{}, + "GUILD_INTEGRATIONS_UPDATE": GuildIntegrationsUpdate{}, + "GUILD_EMOJIS_UPDATE": GuildEmojisUpdate{}, + "MESSAGE_ACK": MessageAck{}, + "MESSAGE_CREATE": MessageCreate{}, + "MESSAGE_UPDATE": MessageUpdate{}, + "MESSAGE_DELETE": MessageDelete{}, + "PRESENCE_UPDATE": PresenceUpdate{}, + "READY": Ready{}, + "USER_UPDATE": UserUpdate{}, + "USER_SETTINGS_UPDATE": UserSettingsUpdate{}, + "TYPING_START": TypingStart{}, + "VOICE_SERVER_UPDATE": VoiceServerUpdate{}, + "VOICE_STATE_UPDATE": VoiceStateUpdate{}, +} + // Front line handler for all Websocket Events. Determines the // event type and passes the message along to the next handler. @@ -289,312 +322,23 @@ func (s *Session) event(messageType int, message []byte) { printEvent(e) } - var i interface{} + i := eventToInterface[e.Type] + if i != nil { + // Create a new instance of the event type. + i = reflect.New(reflect.TypeOf(i)).Interface() - // TODO(iopred): Figure out a clean way to do this with a map, simply - // creating a map[string]interface{} will not work, as that will reuse - // the same instance for each event. - switch e.Type { - case "READY": - i = &Ready{} - case "MESSAGE_CREATE": - i = &MessageCreate{} - case "MESSAGE_UPDATE": - i = &MessageUpdate{} - case "MESSAGE_DELETE": - i = &MessageDelete{} - case "PRESENCE_UPDATE": - i = &PresenceUpdate{} - case "TYPING_START": - i = &TypingStart{} - case "VOICE_SERVER_UPDATE": - i = &VoiceServerUpdate{} - case "VOICE_STATE_UPDATE": - i = &VoiceStateUpdate{} - case "USER_UPDATE": - i = &UserUpdate{} - case "MESSAGE_ACK": - i = &MessageAck{} - case "GUILD_ROLE_CREATE": - i = &GuildRoleCreate{} - case "GUILD_ROLE_UPDATE": - i = &GuildRoleUpdate{} - case "GUILD_ROLE_DELETE": - i = &GuildRoleDelete{} - case "GUILD_INTEGRATIONS_UPDATE": - i = &GuildIntegrationsUpdate{} - case "GUILD_BAN_ADD": - i = &GuildBanAdd{} - case "GUILD_BAN_REMOVE": - i = &GuildBanRemove{} - } - - // case "CHANNEL_CREATE": - // if !s.StateEnabled && s.OnChannelCreate == nil { - // break - // } - // var st *Channel - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.ChannelAdd(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnChannelCreate != nil { - // s.OnChannelCreate(s, st) - // } - // } - // if s.OnChannelCreate != nil { - // return - // } - // case "CHANNEL_UPDATE": - // if !s.StateEnabled && s.OnChannelUpdate == nil { - // break - // } - // var st *Channel - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.ChannelAdd(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnChannelUpdate != nil { - // s.OnChannelUpdate(s, st) - // } - // } - // if s.OnChannelUpdate != nil { - // return - // } - // case "CHANNEL_DELETE": - // if !s.StateEnabled && s.OnChannelDelete == nil { - // break - // } - // var st *Channel - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.ChannelRemove(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnChannelDelete != nil { - // s.OnChannelDelete(s, st) - // } - // } - // if s.OnChannelDelete != nil { - // return - // } - // case "GUILD_CREATE": - // if !s.StateEnabled && s.OnGuildCreate == nil { - // break - // } - // var st *Guild - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.GuildAdd(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnGuildCreate != nil { - // s.OnGuildCreate(s, st) - // } - // } - // if s.OnGuildCreate != nil { - // return - // } - // case "GUILD_UPDATE": - // if !s.StateEnabled && s.OnGuildUpdate == nil { - // break - // } - // var st *Guild - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.GuildAdd(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnGuildCreate != nil { - // s.OnGuildUpdate(s, st) - // } - // } - // if s.OnGuildUpdate != nil { - // return - // } - // case "GUILD_DELETE": - // if !s.StateEnabled && s.OnGuildDelete == nil { - // break - // } - // var st *Guild - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.GuildRemove(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnGuildDelete != nil { - // s.OnGuildDelete(s, st) - // } - // } - // if s.OnGuildDelete != nil { - // return - // } - // case "GUILD_MEMBER_ADD": - // if !s.StateEnabled && s.OnGuildMemberAdd == nil { - // break - // } - // var st *Member - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.MemberAdd(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnGuildMemberAdd != nil { - // s.OnGuildMemberAdd(s, st) - // } - // } - // if s.OnGuildMemberAdd != nil { - // return - // } - // case "GUILD_MEMBER_REMOVE": - // if !s.StateEnabled && s.OnGuildMemberRemove == nil { - // break - // } - // var st *Member - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.MemberRemove(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnGuildMemberRemove != nil { - // s.OnGuildMemberRemove(s, st) - // } - // } - // if s.OnGuildMemberRemove != nil { - // return - // } - // case "GUILD_MEMBER_UPDATE": - // if !s.StateEnabled && s.OnGuildMemberUpdate == nil { - // break - // } - // var st *Member - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.MemberAdd(st) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnGuildMemberUpdate != nil { - // s.OnGuildMemberUpdate(s, st) - // } - // } - // if s.OnGuildMemberUpdate != nil { - // return - // } - // case "GUILD_ROLE_CREATE": - // if s.OnGuildRoleCreate != nil { - // var st *GuildRole - // if err = unmarshalEvent(e, &st); err == nil { - // s.OnGuildRoleCreate(s, st) - // } - // return - // } - // case "GUILD_ROLE_UPDATE": - // if s.OnGuildRoleUpdate != nil { - // var st *GuildRole - // if err = unmarshalEvent(e, &st); err == nil { - // s.OnGuildRoleUpdate(s, st) - // } - // return - // } - // case "GUILD_ROLE_DELETE": - // if s.OnGuildRoleDelete != nil { - // var st *GuildRoleDelete - // if err = unmarshalEvent(e, &st); err == nil { - // s.OnGuildRoleDelete(s, st) - // } - // return - // } - // case "GUILD_INTEGRATIONS_UPDATE": - // if s.OnGuildIntegrationsUpdate != nil { - // var st *GuildIntegrationsUpdate - // if err = unmarshalEvent(e, &st); err == nil { - // s.OnGuildIntegrationsUpdate(s, st) - // } - // return - // } - // case "GUILD_BAN_ADD": - // if s.OnGuildBanAdd != nil { - // var st *GuildBan - // if err = unmarshalEvent(e, &st); err == nil { - // s.OnGuildBanAdd(s, st) - // } - // return - // } - // case "GUILD_BAN_REMOVE": - // if s.OnGuildBanRemove != nil { - // var st *GuildBan - // if err = unmarshalEvent(e, &st); err == nil { - // s.OnGuildBanRemove(s, st) - // } - // return - // } - // case "GUILD_EMOJIS_UPDATE": - // if !s.StateEnabled && s.OnGuildEmojisUpdate == nil { - // break - // } - // var st *GuildEmojisUpdate - // if err = unmarshalEvent(e, &st); err == nil { - // if s.StateEnabled { - // err := s.State.EmojisAdd(st.GuildID, st.Emojis) - // if err != nil { - // fmt.Println("error :", err) - // } - // } - // if s.OnGuildEmojisUpdate != nil { - // s.OnGuildEmojisUpdate(s, st) - // } - // } - // if s.OnGuildEmojisUpdate != nil { - // return - // } - // case "USER_SETTINGS_UPDATE": - // if s.OnUserSettingsUpdate != nil { - // var st map[string]interface{} - // if err = unmarshalEvent(e, &st); err == nil { - // s.OnUserSettingsUpdate(s, st) - // } - // return - // } - // default: - // fmt.Println("Unknown Event.") - // printEvent(e) - // } - - // Attempt to unmarshal our event. - // If there is an error (eg. we don't know how to handle it) we should handle the event itself. - if err = unmarshal(e.RawData, i); err != nil { - fmt.Println("Unable to unmarshal event data.") - printEvent(e) + // Attempt to unmarshal our event. + // If there is an error we should handle the event itself. + if err = unmarshal(e.RawData, i); err != nil { + fmt.Println("Unable to unmarshal event data.") + i = e + } + } else { + fmt.Println("Unknown event.") i = e } - if !s.Handle(i) { - if i != e { - // If there was not a handler for the struct, handle the event, as long as it wasn't the - // event we were trying to handle. - s.Handle(e) - } - } + s.handle(i) return }