diff --git a/restapi.go b/restapi.go index 589185c..6286b65 100644 --- a/restapi.go +++ b/restapi.go @@ -18,7 +18,9 @@ import ( "image" _ "image/jpeg" // For JPEG decoding _ "image/png" // For PNG decoding + "io" "io/ioutil" + "mime/multipart" "net/http" "net/url" "strconv" @@ -28,12 +30,11 @@ import ( // ErrJSONUnmarshal is returned for JSON Unmarshall errors. var ErrJSONUnmarshal = errors.New("json unmarshal") -// Request makes a (GET/POST/...) Requests to Discord REST API. +// Request makes a (GET/POST/...) Requests to Discord REST API with JSON data. // All the other Discord REST Calls in this file use this function. func (s *Session) Request(method, urlStr string, data interface{}) (response []byte, err error) { if s.Debug { - fmt.Printf("API REQUEST %8s :: %s\n", method, urlStr) fmt.Println("API REQUEST PAYLOAD :: [" + fmt.Sprintf("%+v", data) + "]") } @@ -45,7 +46,17 @@ func (s *Session) Request(method, urlStr string, data interface{}) (response []b } } - req, err := http.NewRequest(method, urlStr, bytes.NewBuffer(body)) + return s.request(method, urlStr, "application/json", body) +} + +// request makes a (GET/POST/...) Requests to Discord REST API. +func (s *Session) request(method, urlStr, contentType string, b []byte) (response []byte, err error) { + + if s.Debug { + fmt.Printf("API REQUEST %8s :: %s\n", method, urlStr) + } + + req, err := http.NewRequest(method, urlStr, bytes.NewBuffer(b)) if err != nil { return } @@ -56,7 +67,7 @@ func (s *Session) Request(method, urlStr string, data interface{}) (response []b req.Header.Set("authorization", s.Token) } - req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", contentType) // TODO: Make a configurable static variable. req.Header.Set("User-Agent", fmt.Sprintf("DiscordBot (https://github.com/bwmarrin/discordgo, v%s)", VERSION)) @@ -112,7 +123,7 @@ func (s *Session) Request(method, urlStr string, data interface{}) (response []b return } time.Sleep(rl.RetryAfter) - response, err = s.Request(method, urlStr, data) + response, err = s.request(method, urlStr, contentType, b) default: // Error condition err = fmt.Errorf("HTTP %s, %s", resp.Status, response) @@ -796,6 +807,36 @@ func (s *Session) ChannelMessageDelete(channelID, messageID string) (err error) return } +// ChannelFileSend sends a file to the given channel. +// channelID : The ID of a Channel. +// io.Reader : A reader for the file contents. +func (s *Session) ChannelFileSend(channelID, name string, r io.Reader) (st *Message, err error) { + + body := &bytes.Buffer{} + bodywriter := multipart.NewWriter(body) + + writer, err := bodywriter.CreateFormFile("file", name) + if err != nil { + return nil, err + } + + io.Copy(writer, r) + + bodywriter.Close() + + se := body.String() + fmt.Println(se) + + // Send the message to the given channel + response, err := s.request("POST", CHANNEL_MESSAGES(channelID), bodywriter.FormDataContentType(), body.Bytes()) + if err != nil { + return + } + + err = unmarshal(response, &st) + return +} + // ChannelInvites returns an array of Invite structures for the given channel // channelID : The ID of a Channel func (s *Session) ChannelInvites(channelID string) (st []*Invite, err error) {