feat: memory
This commit is contained in:
parent
afbab98a6b
commit
7b0a35fbf6
6 changed files with 86 additions and 8 deletions
|
@ -138,21 +138,40 @@ func getDefaultResponse(s *discordgo.Session, question string) string {
|
|||
return result
|
||||
}
|
||||
|
||||
func getAIResponse(question string) string {
|
||||
result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, genai.Text(question), ChatBot.config)
|
||||
func getAIResponse(userId, question string) string {
|
||||
contents, err := GetMemory(userId)
|
||||
if err != nil {
|
||||
ChatBot.Mode = ChatbotMuffin
|
||||
fmt.Println(err)
|
||||
log.Fatalln(err)
|
||||
return "AI에 문제가 생겼ㅇ어요."
|
||||
}
|
||||
return result.Text()
|
||||
|
||||
contents = append(contents, genai.NewContentFromText(question, genai.RoleUser))
|
||||
result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, contents, ChatBot.config)
|
||||
if err != nil {
|
||||
ChatBot.Mode = ChatbotMuffin
|
||||
log.Fatalln(err)
|
||||
return "AI에 문제가 생겼ㅇ어요."
|
||||
}
|
||||
|
||||
resultText := result.Text()
|
||||
err = SaveMemory(&databases.InsertMemory{
|
||||
UserId: userId,
|
||||
Content: question,
|
||||
Answer: resultText,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
return resultText
|
||||
}
|
||||
|
||||
func (c *Chatbot) GetResponse(question string) string {
|
||||
func (c *Chatbot) GetResponse(userId, question string) string {
|
||||
switch c.Mode {
|
||||
case ChatbotMuffin:
|
||||
return getDefaultResponse(c.s, question)
|
||||
default:
|
||||
return getAIResponse(question)
|
||||
return getAIResponse(userId, question)
|
||||
}
|
||||
}
|
||||
|
|
41
chatbot/memory.go
Normal file
41
chatbot/memory.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package chatbot
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.wh64.net/muffin/goMuffin/databases"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
func SaveMemory(data *databases.InsertMemory) error {
|
||||
_, err := databases.Database.Memory.InsertOne(context.TODO(), *data)
|
||||
return err
|
||||
}
|
||||
|
||||
func GetMemory(userId string) ([]*genai.Content, error) {
|
||||
var data []databases.Memory
|
||||
|
||||
MAX_LENGTH := 50
|
||||
memory := []*genai.Content{}
|
||||
|
||||
cur, err := databases.Database.Memory.Find(context.TODO(), bson.D{{Key: "user_id", Value: userId}})
|
||||
if err != nil {
|
||||
return memory, err
|
||||
}
|
||||
|
||||
cur.All(context.TODO(), &data)
|
||||
|
||||
if len(data) > MAX_LENGTH {
|
||||
data = data[MAX_LENGTH:]
|
||||
}
|
||||
|
||||
for _, data := range data {
|
||||
memory = append(memory,
|
||||
genai.NewContentFromText(data.Content, genai.RoleUser),
|
||||
genai.NewContentFromText(data.Answer, genai.RoleModel),
|
||||
)
|
||||
}
|
||||
|
||||
return memory, nil
|
||||
}
|
|
@ -7,7 +7,7 @@ import (
|
|||
"git.wh64.net/muffin/goMuffin/utils"
|
||||
)
|
||||
|
||||
const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601a-muffin-ai"
|
||||
const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601b-muffin-ai"
|
||||
|
||||
var updatedString string = utils.RegexpDecimals.FindAllStringSubmatch(MUFFIN_VERSION, -1)[3][0]
|
||||
|
||||
|
|
16
databases/Memory.go
Normal file
16
databases/Memory.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package databases
|
||||
|
||||
import "go.mongodb.org/mongo-driver/v2/bson"
|
||||
|
||||
type InsertMemory struct {
|
||||
UserId string `bson:"user_id"`
|
||||
Content string `bson:"content"`
|
||||
Answer string `bson:"answer"`
|
||||
}
|
||||
|
||||
type Memory struct {
|
||||
Id bson.ObjectID `bson:"_id"`
|
||||
UserId string `bson:"user_id"`
|
||||
Content string `bson:"content"`
|
||||
Answer string `bson:"answer"`
|
||||
}
|
|
@ -12,6 +12,7 @@ type MuffinDatabase struct {
|
|||
Client *mongo.Client
|
||||
Learns *mongo.Collection
|
||||
Texts *mongo.Collection
|
||||
Memory *mongo.Collection
|
||||
}
|
||||
|
||||
var Database *MuffinDatabase
|
||||
|
@ -34,5 +35,6 @@ func Connect() (*MuffinDatabase, error) {
|
|||
Client: client,
|
||||
Learns: client.Database(configs.Config.Database.Name).Collection("learn"),
|
||||
Texts: client.Database(configs.Config.Database.Name).Collection("text"),
|
||||
Memory: client.Database(configs.Config.Database.Name).Collection("memory"),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func MessageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
|||
if command == "" {
|
||||
s.ChannelTyping(m.ChannelID)
|
||||
|
||||
result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(content), s, m)
|
||||
result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(m.Author.ID, content), s, m)
|
||||
utils.NewMessageSender(&utils.MessageCreate{
|
||||
MessageCreate: m,
|
||||
Session: s,
|
||||
|
|
Loading…
Reference in a new issue