diff --git a/wsapi.go b/wsapi.go index 85fe700..904f24a 100644 --- a/wsapi.go +++ b/wsapi.go @@ -15,6 +15,7 @@ import ( "compress/zlib" "encoding/json" "errors" + "fmt" "io" "net/http" "runtime" @@ -45,19 +46,114 @@ type resumePacket struct { } `json:"d"` } -// Open opens a websocket connection to Discord. -func (s *Session) Open() (err error) { - +// Open creates a websocket connection to Discord. +// See: https://discordapp.com/developers/docs/topics/gateway#connecting +func (s *Session) Open() error { s.log(LogInformational, "called") + var err error + + // Prevent Open or other major Session functions from + // being called while Open is still running. s.Lock() - defer func() { + defer s.Unlock() + + // If the websock is already open, bail out here. + if s.wsConn != nil { + return ErrWSAlreadyOpen + } + + // Get the gateway to use for the Websocket connection + if s.gateway == "" { + s.gateway, err = s.Gateway() if err != nil { - s.Unlock() + return err + } + + // Add the version and encoding to the URL + s.gateway = s.gateway + "?v=" + APIVersion + "&encoding=json" + } + + // Connect to the Gateway + s.log(LogInformational, "connecting to gateway %s", s.gateway) + header := http.Header{} + header.Add("accept-encoding", "zlib") + s.wsConn, _, err = websocket.DefaultDialer.Dial(s.gateway, header) + if err != nil { + s.log(LogWarning, "error connecting to gateway %s, %s", s.gateway, err) + s.gateway = "" // clear cached gateway + s.wsConn = nil // Just to be safe. + return err + } + + defer func() { + // because of this, all code below must set err to the error + // when exiting with an error :) Maybe someone has a better + // way :) + if err != nil { + s.wsConn.Close() + s.wsConn = nil } }() + // The first response from Discord should be an Op 10 (Hello) Packet. + // When processed by onEvent the heartbeat goroutine will be started. + mt, m, err := s.wsConn.ReadMessage() + if err != nil { + return err + } + e, err := s.onEvent(mt, m) + if err != nil { + return err + } + if e.Operation != 10 { + err = fmt.Errorf("Expecting Op 10, got Op %d instead.", e.Operation) + return err + } + s.log(LogInformational, "Op 10 Hello Packet received from Discord") + s.LastHeartbeatAck = time.Now().UTC() + var h helloOp + if err = json.Unmarshal(e.RawData, &h); err != nil { + err = fmt.Errorf("error unmarshalling helloOp, %s", err) + return err + } + + // Now we send either an Op 2 Identity if this is a brand new + // connection or Op 6 Resume if we are resuming an existing connection. + sequence := atomic.LoadInt64(s.sequence) + if s.sessionID == "" && sequence == 0 { + + // Send Op 2 Identity Packet + err = s.identify() + if err != nil { + err = fmt.Errorf("error sending identify packet to gateway, %s, %s", s.gateway, err) + return err + } + + } else { + + // Send Op 6 Resume Packet + p := resumePacket{} + p.Op = 6 + p.Data.Token = s.Token + p.Data.SessionID = s.sessionID + p.Data.Sequence = sequence + + s.log(LogInformational, "sending resume packet to gateway") + s.wsMutex.Lock() + err = s.wsConn.WriteJSON(p) + s.wsMutex.Unlock() + if err != nil { + err = fmt.Errorf("error sending gateway resume packet, %s, %s", s.gateway, err) + return err + } + + } + // A basic state is a hard requirement for Voice. + // We create it here so the below READY/RESUMED packet can populate + // the state :) + // XXX: Move to New() func? if s.State == nil { state := NewState() state.TrackChannels = false @@ -68,77 +164,42 @@ func (s *Session) Open() (err error) { s.State = state } - if s.wsConn != nil { - err = ErrWSAlreadyOpen - return + // Now Discord should send us a READY or RESUMED packet. + mt, m, err = s.wsConn.ReadMessage() + if err != nil { + return err } + e, err = s.onEvent(mt, m) + if err != nil { + return err + } + if e.Type != `READY` && e.Type != `RESUMED` { + // This is not fatal, but it does not follow their API documentation. + s.log(LogWarning, "Expected READY/RESUMED, instead got:\n%#v\n", e) + } + s.log(LogInformational, "First Packet:\n%#v\n", e) + s.log(LogInformational, "We are now connected to Discord, emitting connect event") + s.handleEvent(connectEventType, &Connect{}) + + // A VoiceConnections map is a hard requirement for Voice. + // XXX: can this be moved to when opening a voice connection? if s.VoiceConnections == nil { s.log(LogInformational, "creating new VoiceConnections map") s.VoiceConnections = make(map[string]*VoiceConnection) } - // Get the gateway to use for the Websocket connection - if s.gateway == "" { - s.gateway, err = s.Gateway() - if err != nil { - return - } - - // Add the version and encoding to the URL - s.gateway = s.gateway + "?v=" + APIVersion + "&encoding=json" - } - - header := http.Header{} - header.Add("accept-encoding", "zlib") - - s.log(LogInformational, "connecting to gateway %s", s.gateway) - s.wsConn, _, err = websocket.DefaultDialer.Dial(s.gateway, header) - if err != nil { - s.log(LogWarning, "error connecting to gateway %s, %s", s.gateway, err) - s.gateway = "" // clear cached gateway - // TODO: should we add a retry block here? - return - } - - sequence := atomic.LoadInt64(s.sequence) - if s.sessionID != "" && sequence > 0 { - - p := resumePacket{} - p.Op = 6 - p.Data.Token = s.Token - p.Data.SessionID = s.sessionID - p.Data.Sequence = sequence - - s.log(LogInformational, "sending resume packet to gateway") - err = s.wsConn.WriteJSON(p) - if err != nil { - s.log(LogWarning, "error sending gateway resume packet, %s, %s", s.gateway, err) - return - } - - } else { - - err = s.identify() - if err != nil { - s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err) - return - } - } - - // Create listening outside of listen, as it needs to happen inside the mutex - // lock. + // Create listening chan outside of listen, as it needs to happen inside the + // mutex lock and needs to exist before calling heartbeat and listen + // go rountines. s.listening = make(chan interface{}) + + // Start sending heartbeats and reading messages from Discord. + go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval) go s.listen(s.wsConn, s.listening) - s.LastHeartbeatAck = time.Now().UTC() - - s.Unlock() - - s.log(LogInformational, "emit connect event") - s.handleEvent(connectEventType, &Connect{}) s.log(LogInformational, "exiting") - return + return nil } // listen polls the websocket connection for events, it will stop when the @@ -364,9 +425,7 @@ func (s *Session) RequestGuildMembers(guildID, query string, limit int) (err err // // If you use the AddHandler() function to register a handler for the // "OnEvent" event then all events will be passed to that handler. -// -// TODO: You may also register a custom event handler entirely using... -func (s *Session) onEvent(messageType int, message []byte) { +func (s *Session) onEvent(messageType int, message []byte) (*Event, error) { var err error var reader io.Reader @@ -378,7 +437,7 @@ func (s *Session) onEvent(messageType int, message []byte) { z, err2 := zlib.NewReader(reader) if err2 != nil { s.log(LogError, "error uncompressing websocket message, %s", err) - return + return nil, err2 } defer func() { @@ -396,7 +455,7 @@ func (s *Session) onEvent(messageType int, message []byte) { decoder := json.NewDecoder(reader) if err = decoder.Decode(&e); err != nil { s.log(LogError, "error decoding websocket message, %s", err) - return + return e, err } s.log(LogDebug, "Op: %d, Seq: %d, Type: %s, Data: %s\n\n", e.Operation, e.Sequence, e.Type, string(e.RawData)) @@ -410,10 +469,10 @@ func (s *Session) onEvent(messageType int, message []byte) { s.wsMutex.Unlock() if err != nil { s.log(LogError, "error sending heartbeat in response to Op1") - return + return e, err } - return + return e, nil } // Reconnect @@ -422,7 +481,7 @@ func (s *Session) onEvent(messageType int, message []byte) { s.log(LogInformational, "Closing and reconnecting in response to Op7") s.Close() s.reconnect() - return + return e, nil } // Invalid Session @@ -434,20 +493,15 @@ func (s *Session) onEvent(messageType int, message []byte) { err = s.identify() if err != nil { s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err) - return + return e, err } - return + return e, nil } if e.Operation == 10 { - var h helloOp - if err = json.Unmarshal(e.RawData, &h); err != nil { - s.log(LogError, "error unmarshalling helloOp, %s", err) - } else { - go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval) - } - return + // Op10 is handled by Open() + return e, nil } if e.Operation == 11 { @@ -455,7 +509,7 @@ func (s *Session) onEvent(messageType int, message []byte) { s.LastHeartbeatAck = time.Now().UTC() s.Unlock() s.log(LogInformational, "got heartbeat ACK") - return + return e, nil } // Do not try to Dispatch a non-Dispatch Message @@ -463,7 +517,7 @@ func (s *Session) onEvent(messageType int, message []byte) { // But we probably should be doing something with them. // TEMP s.log(LogWarning, "unknown Op: %d, Seq: %d, Type: %s, Data: %s, message: %s", e.Operation, e.Sequence, e.Type, string(e.RawData), string(message)) - return + return e, nil } // Store the message sequence @@ -492,6 +546,8 @@ func (s *Session) onEvent(messageType int, message []byte) { // For legacy reasons, we send the raw event also, this could be useful for handling unknown events. s.handleEvent(eventEventType, e) + + return e, nil } // ------------------------------------------------------------------------------------------------