diff --git a/discord.go b/discord.go index e77ec94..6b3f6eb 100644 --- a/discord.go +++ b/discord.go @@ -111,56 +111,8 @@ func New(args ...interface{}) (s *Session, err error) { } } - // TODO: Add code here to fetch authenticated user info like settings, - // avatar, User ID, etc. If fails, return error. - - // Open websocket connection - err = s.Open() - if err != nil { - fmt.Println(err) - return - } - - // Do websocket handshake. - err = s.Handshake() - if err != nil { - fmt.Println(err) - return - } - - // Listen for events. - go s.Listen() + // The Session is now able to have RestAPI methods called on it. + // It is recommended that you now call Open() so that events will trigger. return } - -// Close closes a Discord session -// TODO: Add support for Voice WS/UDP connections -func (s *Session) Close() { - - s.DataReady = false - - if s.heartbeatChan != nil { - select { - case <-s.heartbeatChan: - break - default: - close(s.heartbeatChan) - } - s.heartbeatChan = nil - } - - if s.listenChan != nil { - select { - case <-s.listenChan: - break - default: - close(s.listenChan) - } - s.listenChan = nil - } - - if s.wsConn != nil { - s.wsConn.Close() - } -} diff --git a/discord_test.go b/discord_test.go index a9ddc50..deed4f4 100644 --- a/discord_test.go +++ b/discord_test.go @@ -21,6 +21,16 @@ var ( envAdmin string = os.Getenv("DG_ADMIN") // User ID of admin user to use for tests ) +func init() { + if envEmail == "" || envPassword == "" || envToken == "" { + return + } + + if d, err := New(envEmail, envPassword, envToken); err == nil { + dg = d + } +} + ////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////// HELPER FUNCTIONS USED FOR TESTING @@ -76,12 +86,16 @@ func TestNew(t *testing.T) { // TestInvalidToken tests the New() function with an invalid token func TestInvalidToken(t *testing.T) { - - _, err := New("asjkldhflkjasdh") - if err == nil { - t.Errorf("New(InvalidToken) returned nil error.") + d, err := New("asjkldhflkjasdh") + if err != nil { + t.Fatalf("New(InvalidToken) returned error: %+v", err) } + // New with just a token does not do any communication, so attempt an api call. + _, err = d.UserSettings() + if err == nil { + t.Errorf("New(InvalidToken), d.UserSettings returned nil error.") + } } // TestInvalidUserPass tests the New() function with an invalid Email and Pass @@ -108,8 +122,7 @@ func TestInvalidPass(t *testing.T) { } // TestNewUserPass tests the New() function with a username and password. -// This should return a valid Session{}, a valid Session.Token, and open -// a websocket connection to Discord. +// This should return a valid Session{}, a valid Session.Token. func TestNewUserPass(t *testing.T) { if envEmail == "" || envPassword == "" { @@ -129,22 +142,6 @@ func TestNewUserPass(t *testing.T) { if d.Token == "" { t.Fatal("New(user,pass), d.Token is empty, should be a valid Token.") } - - if !waitBoolEqual(10*time.Second, &d.DataReady, true) { - t.Fatal("New(user,pass), d.DataReady is false after 10 seconds. Should be true.") - } - - t.Log("Successfully connected to Discord via New(user,pass).") - dg = d - if envToken == "" { - envToken = dg.Token - } -} - -func TestClose(t *testing.T) { - if dg != nil { - dg.Close() - } } // TestNewToken tests the New() function with a Token. This should return @@ -167,12 +164,46 @@ func TestNewToken(t *testing.T) { if d.Token == "" { t.Fatal("New(envToken), d.Token is empty, should be a valid Token.") } +} - if !waitBoolEqual(10*time.Second, &d.DataReady, true) { - t.Fatal("New(envToken), d.DataReady is false after 10 seconds. Should be true.") +// TestNewUserPassToken tests the New() function with a username, password and token. +// This should return the same as the TestNewUserPass function. +func TestNewUserPassToken(t *testing.T) { + + if envEmail == "" || envPassword == "" || envToken == "" { + t.Skip("Skipping New(username,password,token), DG_EMAIL, DG_PASSWORD or DG_TOKEN not set") + return } - t.Log("Successfully connected to Discord via New(token).") - dg = d + d, err := New(envEmail, envPassword, envToken) + if err != nil { + t.Fatalf("New(user,pass,token) returned error: %+v", err) + } + if d == nil { + t.Fatal("New(user,pass,token), d is nil, should be Session{}") + } + + if d.Token == "" { + t.Fatal("New(user,pass,token), d.Token is empty, should be a valid Token.") + } +} + +func TestOpenClose(t *testing.T) { + if envToken == "" { + t.Skip("Skipping TestClose, DG_TOKEN not set") + } + + d, err := New(envToken) + if err != nil { + t.Fatalf("TestClose, New(envToken) returned error: %+v", err) + } + + if err = d.Open(); err != nil { + t.Fatalf("TestClose, d.Open failed: %+v", err) + } + + if err = d.Close(); err != nil { + t.Fatalf("TestClose, d.Close failed: %+v", err) + } } diff --git a/examples/api_basic/api_basic.go b/examples/api_basic/api_basic.go index 890fa9f..aea4e56 100644 --- a/examples/api_basic/api_basic.go +++ b/examples/api_basic/api_basic.go @@ -40,15 +40,6 @@ func main() { fmt.Println(err) } - // Do websocket handshake. - err = dg.Handshake() - if err != nil { - fmt.Println(err) - } - - // Listen for events. - go dg.Listen() - // Simple way to keep program running until any key press. var input string fmt.Scanln(&input) diff --git a/examples/new_basic/new_basic.go b/examples/new_basic/new_basic.go index 88fca12..43492ec 100644 --- a/examples/new_basic/new_basic.go +++ b/examples/new_basic/new_basic.go @@ -31,6 +31,9 @@ func main() { // Register messageCreate as a callback for the OnMessageCreate event. dg.OnMessageCreate = messageCreate + // Open the websocket and begin listening. + dg.Open() + // Simple way to keep program running until any key press. var input string fmt.Scanln(&input) diff --git a/restapi_test.go b/restapi_test.go index b24b61a..d0f3e80 100644 --- a/restapi_test.go +++ b/restapi_test.go @@ -10,8 +10,8 @@ import ( // TestLogout tests the Logout() function. This should not return an error. func TestLogout(t *testing.T) { - if dg == nil || dg.Token == "" { - t.Skip("Cannot test logout, dg.Token not set.") + if dg == nil { + t.Skip("Cannot TestLogout, dg not set.") } err := dg.Logout() @@ -21,8 +21,8 @@ func TestLogout(t *testing.T) { } func TestUserAvatar(t *testing.T) { - if !isConnected() { - t.Skip("Skipped, Not connected to Discord.") + if dg == nil { + t.Skip("Cannot TestUserAvatar, dg not set.") } a, err := dg.UserAvatar("@me") @@ -39,14 +39,8 @@ func TestUserAvatar(t *testing.T) { } func TestUserUpdate(t *testing.T) { - - if envEmail == "" || envPassword == "" { - t.Skip("Skipping, DG_USERNAME or DG_PASSWORD not set") - return - } - - if !isConnected() { - t.Skip("Skipped, Not connected to Discord.") + if dg == nil { + t.Skip("Cannot test logout, dg not set.") } u, err := dg.User("@me") @@ -73,9 +67,8 @@ func TestUserUpdate(t *testing.T) { //func (s *Session) UserChannelCreate(recipientID string) (st *Channel, err error) { func TestUserChannelCreate(t *testing.T) { - - if !isConnected() { - t.Skip("Skipped, Not connected to Discord.") + if dg == nil { + t.Skip("Cannot TestUserChannelCreate, dg not set.") } if envAdmin == "" { @@ -91,9 +84,8 @@ func TestUserChannelCreate(t *testing.T) { } func TestUserChannels(t *testing.T) { - - if !isConnected() { - t.Skip("Skipped, Not connected to Discord.") + if dg == nil { + t.Skip("Cannot TestUserChannels, dg not set.") } _, err := dg.UserChannels() @@ -103,9 +95,8 @@ func TestUserChannels(t *testing.T) { } func TestUserGuilds(t *testing.T) { - - if !isConnected() { - t.Skip("Skipped, Not connected to Discord.") + if dg == nil { + t.Skip("Cannot TestUserGuilds, dg not set.") } _, err := dg.UserGuilds() @@ -115,9 +106,8 @@ func TestUserGuilds(t *testing.T) { } func TestUserSettings(t *testing.T) { - - if !isConnected() { - t.Skip("Skipped, Not connected to Discord.") + if dg == nil { + t.Skip("Cannot TestUserSettings, dg not set.") } _, err := dg.UserSettings() diff --git a/structs.go b/structs.go index 091eda2..68866ca 100644 --- a/structs.go +++ b/structs.go @@ -23,10 +23,19 @@ import ( // token : The authentication token returned from Discord // Debug : If set to ture debug logging will be displayed. type Session struct { + sync.RWMutex + // General configurable settings. Token string // Authentication token for this session Debug bool // Debug for printing JSON request/responses + // Settable Callback functions for Internal Events + // OnConnect is called when the websocket connection opens. + OnConnect func(*Session) + // OnDisconnect is called when the websocket connection closes. + // This is a good handler to add reconnection logic to. + OnDisconnect func(*Session) + // Settable Callback functions for Websocket Events OnEvent func(*Session, *Event) OnReady func(*Session, *Ready) @@ -81,14 +90,8 @@ type Session struct { StateEnabled bool StateMaxMessageCount int - // Mutex/Bools for locks that prevent accidents. - // TODO: Add channels. - - heartbeatLock sync.Mutex - heartbeatChan chan struct{} - - listenLock sync.Mutex - listenChan chan struct{} + // When nil, the session is not listening. + listening chan interface{} } // A VoiceRegion stores data for a specific voice region server. diff --git a/wsapi.go b/wsapi.go index 3f8bafc..342efdd 100644 --- a/wsapi.go +++ b/wsapi.go @@ -11,6 +11,7 @@ package discordgo import ( + "errors" "fmt" "runtime" "time" @@ -18,21 +19,6 @@ import ( "github.com/gorilla/websocket" ) -// Open opens a websocket connection to Discord. -func (s *Session) Open() (err error) { - - // Get the gateway to use for the Websocket connection - g, err := s.Gateway() - if err != nil { - return - } - - // TODO: See if there's a use for the http response. - // conn, response, err := websocket.DefaultDialer.Dial(session.Gateway, nil) - s.wsConn, _, err = websocket.DefaultDialer.Dial(g, nil) - return -} - type handshakeProperties struct { OS string `json:"$os"` Browser string `json:"$browser"` @@ -52,12 +38,98 @@ type handshakeOp struct { Data handshakeData `json:"d"` } -// Handshake sends the client data to Discord during websocket initial connection. -func (s *Session) Handshake() (err error) { - // maybe this is SendOrigin? not sure the right name here +// Open opens a websocket connection to Discord. +func (s *Session) Open() (err error) { + s.Lock() + defer func() { + if err != nil { + s.Unlock() + } + }() + + if s.wsConn != nil { + err = errors.New("Web socket already opened.") + return + } + + // Get the gateway to use for the Websocket connection + g, err := s.Gateway() + if err != nil { + return + } + + // TODO: See if there's a use for the http response. + // conn, response, err := websocket.DefaultDialer.Dial(session.Gateway, nil) + s.wsConn, _, err = websocket.DefaultDialer.Dial(g, nil) + if err != nil { + return + } + + err = s.wsConn.WriteJSON(handshakeOp{2, handshakeData{3, s.Token, handshakeProperties{runtime.GOOS, "Discordgo v" + VERSION, "", "", ""}}}) + if err != nil { + return + } + + // Create listening outside of listen, as it needs to happen inside the mutex + // lock. + s.listening = make(chan interface{}) + go s.listen(s.wsConn, s.listening) + + s.Unlock() + + if s.OnConnect != nil { + s.OnConnect(s) + } + + return +} + +// Close closes a websocket and stops all listening/heartbeat goroutines. +// TODO: Add support for Voice WS/UDP connections +func (s *Session) Close() (err error) { + s.Lock() + + s.DataReady = false + + if s.listening != nil { + close(s.listening) + s.listening = nil + } + + if s.wsConn != nil { + err = s.wsConn.Close() + s.wsConn = nil + } + + s.Unlock() + + if s.OnDisconnect != nil { + s.OnDisconnect(s) + } + + return +} + +// listen polls the websocket connection for events, it will stop when +// the listening channel is closed, or an error occurs. +func (s *Session) listen(wsConn *websocket.Conn, listening <-chan interface{}) { + for { + messageType, message, err := wsConn.ReadMessage() + if err != nil { + // There has been an error reading, Close() the websocket so that + // OnDisconnect is fired. + s.Close() + return + } + + select { + case <-listening: + return + default: + go s.event(messageType, message) + } + } - data := handshakeOp{2, handshakeData{3, s.Token, handshakeProperties{runtime.GOOS, "Discordgo v" + VERSION, "", "", ""}}} - err = s.wsConn.WriteJSON(data) return } @@ -79,6 +151,11 @@ type updateStatusOp struct { // If idle>0 then set status to idle. If game>0 then set game. // if otherwise, set status to active, and no game. func (s *Session) UpdateStatus(idle int, game string) (err error) { + s.RLock() + defer s.RUnlock() + if s.wsConn == nil { + return errors.New("No websocket connection exists.") + } var usd updateStatusData if idle > 0 { @@ -88,75 +165,7 @@ func (s *Session) UpdateStatus(idle int, game string) (err error) { usd.Game = &updateStatusGame{game} } - data := updateStatusOp{3, usd} - err = s.wsConn.WriteJSON(data) - - return -} - -// Listen starts listening to the websocket connection for events. -func (s *Session) Listen() (err error) { - // TODO: need a channel or something to communicate - // to this so I can tell it to stop listening - - if s.wsConn == nil { - fmt.Println("No websocket connection exists.") - return // TODO need to return an error. - } - - // Make sure Listen is not already running - s.listenLock.Lock() - if s.listenChan != nil { - s.listenLock.Unlock() - return - } - s.listenChan = make(chan struct{}) - s.listenLock.Unlock() - - // this is ugly. - defer func() { - if s.listenChan == nil { - return - } - select { - case <-s.listenChan: - break - default: - close(s.listenChan) - } - s.listenChan = nil - }() - - // this is ugly. - defer func() { - if s.heartbeatChan == nil { - return - } - select { - case <-s.heartbeatChan: - break - default: - close(s.heartbeatChan) - } - s.listenChan = nil - }() - - for { - messageType, message, err := s.wsConn.ReadMessage() - if err != nil { - fmt.Println("Websocket Listen Error", err) - // TODO Log error - break - } - go s.event(messageType, message) - - // If our chan gets closed, exit out of this loop. - // TODO: Can we make this smarter, using select - // and some other trickery? http://www.goinggo.net/2013/10/my-channel-select-bug.html - if s.listenChan == nil { - return nil - } - } + err = s.wsConn.WriteJSON(updateStatusOp{3, usd}) return } @@ -192,17 +201,16 @@ func (s *Session) event(messageType int, message []byte) (err error) { } switch e.Type { - case "READY": var st *Ready if err = unmarshalEvent(e, &st); err == nil { + go s.heartbeat(s.wsConn, s.listening, st.HeartbeatInterval) if s.StateEnabled { s.State.OnReady(st) } if s.OnReady != nil { s.OnReady(s, st) } - go s.Heartbeat(st.HeartbeatInterval) } if s.OnReady != nil { return @@ -541,58 +549,45 @@ func (s *Session) event(messageType int, message []byte) (err error) { return } -// Heartbeat sends regular heartbeats to Discord so it knows the client +type heartbeatOp struct { + Op int `json:"op"` + Data int `json:"d"` +} + +func (s *Session) sendHeartbeat(wsConn *websocket.Conn) error { + return wsConn.WriteJSON(heartbeatOp{1, int(time.Now().Unix())}) +} + +// heartbeat sends regular heartbeats to Discord so it knows the client // is still connected. If you do not send these heartbeats Discord will // disconnect the websocket connection after a few seconds. -func (s *Session) Heartbeat(i time.Duration) { - - if s.wsConn == nil { - fmt.Println("No websocket connection exists.") - return // TODO need to return/log an error. - } - - // Make sure Heartbeat is not already running - s.heartbeatLock.Lock() - if s.heartbeatChan != nil { - s.heartbeatLock.Unlock() +func (s *Session) heartbeat(wsConn *websocket.Conn, listening <-chan interface{}, i time.Duration) { + if listening == nil || wsConn == nil { return } - s.heartbeatChan = make(chan struct{}) - s.heartbeatLock.Unlock() - // this is ugly. - defer func() { - if s.heartbeatChan == nil { - return - } - select { - case <-s.heartbeatChan: - break - default: - close(s.heartbeatChan) - } - s.listenChan = nil - }() + s.Lock() + s.DataReady = true + s.Unlock() + + // Send first heartbeat immediately because lag could put the + // first heartbeat outside the required heartbeat interval window. + err := s.sendHeartbeat(wsConn) + if err != nil { + fmt.Println("Error sending initial heartbeat:", err) + return + } - // send first heartbeat immediately because lag could put the - // first heartbeat outside the required heartbeat interval window ticker := time.NewTicker(i * time.Millisecond) for { - - err := s.wsConn.WriteJSON(map[string]int{ - "op": 1, - "d": int(time.Now().Unix()), - }) - if err != nil { - fmt.Println("error sending data heartbeat:", err) - s.DataReady = false - return // TODO log error? - } - s.DataReady = true - select { case <-ticker.C: - case <-s.heartbeatChan: + err := s.sendHeartbeat(wsConn) + if err != nil { + fmt.Println("Error sending heartbeat:", err) + return + } + case <-listening: return } }