feat: memory

This commit is contained in:
Siwoo Jeon 2025-06-01 17:03:28 +09:00
parent afbab98a6b
commit 7b0a35fbf6
Signed by: migan
GPG key ID: 036E9A8C5E8E48DA
6 changed files with 86 additions and 8 deletions

View file

@ -138,21 +138,40 @@ func getDefaultResponse(s *discordgo.Session, question string) string {
return result return result
} }
func getAIResponse(question string) string { func getAIResponse(userId, question string) string {
result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, genai.Text(question), ChatBot.config) contents, err := GetMemory(userId)
if err != nil { if err != nil {
ChatBot.Mode = ChatbotMuffin ChatBot.Mode = ChatbotMuffin
fmt.Println(err) log.Fatalln(err)
return "AI에 문제가 생겼ㅇ어요." return "AI에 문제가 생겼ㅇ어요."
} }
return result.Text()
contents = append(contents, genai.NewContentFromText(question, genai.RoleUser))
result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, contents, ChatBot.config)
if err != nil {
ChatBot.Mode = ChatbotMuffin
log.Fatalln(err)
return "AI에 문제가 생겼ㅇ어요."
}
resultText := result.Text()
err = SaveMemory(&databases.InsertMemory{
UserId: userId,
Content: question,
Answer: resultText,
})
if err != nil {
log.Fatalln(err)
}
return resultText
} }
func (c *Chatbot) GetResponse(question string) string { func (c *Chatbot) GetResponse(userId, question string) string {
switch c.Mode { switch c.Mode {
case ChatbotMuffin: case ChatbotMuffin:
return getDefaultResponse(c.s, question) return getDefaultResponse(c.s, question)
default: default:
return getAIResponse(question) return getAIResponse(userId, question)
} }
} }

41
chatbot/memory.go Normal file
View file

@ -0,0 +1,41 @@
package chatbot
import (
"context"
"git.wh64.net/muffin/goMuffin/databases"
"go.mongodb.org/mongo-driver/v2/bson"
"google.golang.org/genai"
)
func SaveMemory(data *databases.InsertMemory) error {
_, err := databases.Database.Memory.InsertOne(context.TODO(), *data)
return err
}
func GetMemory(userId string) ([]*genai.Content, error) {
var data []databases.Memory
MAX_LENGTH := 50
memory := []*genai.Content{}
cur, err := databases.Database.Memory.Find(context.TODO(), bson.D{{Key: "user_id", Value: userId}})
if err != nil {
return memory, err
}
cur.All(context.TODO(), &data)
if len(data) > MAX_LENGTH {
data = data[MAX_LENGTH:]
}
for _, data := range data {
memory = append(memory,
genai.NewContentFromText(data.Content, genai.RoleUser),
genai.NewContentFromText(data.Answer, genai.RoleModel),
)
}
return memory, nil
}

View file

@ -7,7 +7,7 @@ import (
"git.wh64.net/muffin/goMuffin/utils" "git.wh64.net/muffin/goMuffin/utils"
) )
const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601a-muffin-ai" const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601b-muffin-ai"
var updatedString string = utils.RegexpDecimals.FindAllStringSubmatch(MUFFIN_VERSION, -1)[3][0] var updatedString string = utils.RegexpDecimals.FindAllStringSubmatch(MUFFIN_VERSION, -1)[3][0]

16
databases/Memory.go Normal file
View file

@ -0,0 +1,16 @@
package databases
import "go.mongodb.org/mongo-driver/v2/bson"
type InsertMemory struct {
UserId string `bson:"user_id"`
Content string `bson:"content"`
Answer string `bson:"answer"`
}
type Memory struct {
Id bson.ObjectID `bson:"_id"`
UserId string `bson:"user_id"`
Content string `bson:"content"`
Answer string `bson:"answer"`
}

View file

@ -12,6 +12,7 @@ type MuffinDatabase struct {
Client *mongo.Client Client *mongo.Client
Learns *mongo.Collection Learns *mongo.Collection
Texts *mongo.Collection Texts *mongo.Collection
Memory *mongo.Collection
} }
var Database *MuffinDatabase var Database *MuffinDatabase
@ -34,5 +35,6 @@ func Connect() (*MuffinDatabase, error) {
Client: client, Client: client,
Learns: client.Database(configs.Config.Database.Name).Collection("learn"), Learns: client.Database(configs.Config.Database.Name).Collection("learn"),
Texts: client.Database(configs.Config.Database.Name).Collection("text"), Texts: client.Database(configs.Config.Database.Name).Collection("text"),
Memory: client.Database(configs.Config.Database.Name).Collection("memory"),
}, nil }, nil
} }

View file

@ -40,7 +40,7 @@ func MessageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
if command == "" { if command == "" {
s.ChannelTyping(m.ChannelID) s.ChannelTyping(m.ChannelID)
result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(content), s, m) result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(m.Author.ID, content), s, m)
utils.NewMessageSender(&utils.MessageCreate{ utils.NewMessageSender(&utils.MessageCreate{
MessageCreate: m, MessageCreate: m,
Session: s, Session: s,