diff --git a/chatbot/chatbot.go b/chatbot/chatbot.go index e4e0474..34a3363 100644 --- a/chatbot/chatbot.go +++ b/chatbot/chatbot.go @@ -138,21 +138,40 @@ func getDefaultResponse(s *discordgo.Session, question string) string { return result } -func getAIResponse(question string) string { - result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, genai.Text(question), ChatBot.config) +func getAIResponse(userId, question string) string { + contents, err := GetMemory(userId) if err != nil { ChatBot.Mode = ChatbotMuffin - fmt.Println(err) + log.Fatalln(err) return "AI에 문제가 생겼ㅇ어요." } - return result.Text() + + contents = append(contents, genai.NewContentFromText(question, genai.RoleUser)) + result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, contents, ChatBot.config) + if err != nil { + ChatBot.Mode = ChatbotMuffin + log.Fatalln(err) + return "AI에 문제가 생겼ㅇ어요." + } + + resultText := result.Text() + err = SaveMemory(&databases.InsertMemory{ + UserId: userId, + Content: question, + Answer: resultText, + }) + if err != nil { + log.Fatalln(err) + } + + return resultText } -func (c *Chatbot) GetResponse(question string) string { +func (c *Chatbot) GetResponse(userId, question string) string { switch c.Mode { case ChatbotMuffin: return getDefaultResponse(c.s, question) default: - return getAIResponse(question) + return getAIResponse(userId, question) } } diff --git a/chatbot/memory.go b/chatbot/memory.go new file mode 100644 index 0000000..db6d706 --- /dev/null +++ b/chatbot/memory.go @@ -0,0 +1,41 @@ +package chatbot + +import ( + "context" + + "git.wh64.net/muffin/goMuffin/databases" + "go.mongodb.org/mongo-driver/v2/bson" + "google.golang.org/genai" +) + +func SaveMemory(data *databases.InsertMemory) error { + _, err := databases.Database.Memory.InsertOne(context.TODO(), *data) + return err +} + +func GetMemory(userId string) ([]*genai.Content, error) { + var data []databases.Memory + + MAX_LENGTH := 50 + memory := []*genai.Content{} + + cur, err := databases.Database.Memory.Find(context.TODO(), bson.D{{Key: "user_id", Value: userId}}) + if err != nil { + return memory, err + } + + cur.All(context.TODO(), &data) + + if len(data) > MAX_LENGTH { + data = data[MAX_LENGTH:] + } + + for _, data := range data { + memory = append(memory, + genai.NewContentFromText(data.Content, genai.RoleUser), + genai.NewContentFromText(data.Answer, genai.RoleModel), + ) + } + + return memory, nil +} diff --git a/configs/version.go b/configs/version.go index 987bb77..cda5e75 100644 --- a/configs/version.go +++ b/configs/version.go @@ -7,7 +7,7 @@ import ( "git.wh64.net/muffin/goMuffin/utils" ) -const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601a-muffin-ai" +const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601b-muffin-ai" var updatedString string = utils.RegexpDecimals.FindAllStringSubmatch(MUFFIN_VERSION, -1)[3][0] diff --git a/databases/Memory.go b/databases/Memory.go new file mode 100644 index 0000000..3c1509b --- /dev/null +++ b/databases/Memory.go @@ -0,0 +1,16 @@ +package databases + +import "go.mongodb.org/mongo-driver/v2/bson" + +type InsertMemory struct { + UserId string `bson:"user_id"` + Content string `bson:"content"` + Answer string `bson:"answer"` +} + +type Memory struct { + Id bson.ObjectID `bson:"_id"` + UserId string `bson:"user_id"` + Content string `bson:"content"` + Answer string `bson:"answer"` +} diff --git a/databases/database.go b/databases/database.go index 273a671..75a9cf8 100644 --- a/databases/database.go +++ b/databases/database.go @@ -12,6 +12,7 @@ type MuffinDatabase struct { Client *mongo.Client Learns *mongo.Collection Texts *mongo.Collection + Memory *mongo.Collection } var Database *MuffinDatabase @@ -34,5 +35,6 @@ func Connect() (*MuffinDatabase, error) { Client: client, Learns: client.Database(configs.Config.Database.Name).Collection("learn"), Texts: client.Database(configs.Config.Database.Name).Collection("text"), + Memory: client.Database(configs.Config.Database.Name).Collection("memory"), }, nil } diff --git a/handler/messageCreate.go b/handler/messageCreate.go index ca47517..9f7d40c 100644 --- a/handler/messageCreate.go +++ b/handler/messageCreate.go @@ -40,7 +40,7 @@ func MessageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { if command == "" { s.ChannelTyping(m.ChannelID) - result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(content), s, m) + result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(m.Author.ID, content), s, m) utils.NewMessageSender(&utils.MessageCreate{ MessageCreate: m, Session: s,