feat: memory
This commit is contained in:
parent
afbab98a6b
commit
7b0a35fbf6
6 changed files with 86 additions and 8 deletions
|
@ -138,21 +138,40 @@ func getDefaultResponse(s *discordgo.Session, question string) string {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAIResponse(question string) string {
|
func getAIResponse(userId, question string) string {
|
||||||
result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, genai.Text(question), ChatBot.config)
|
contents, err := GetMemory(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ChatBot.Mode = ChatbotMuffin
|
ChatBot.Mode = ChatbotMuffin
|
||||||
fmt.Println(err)
|
log.Fatalln(err)
|
||||||
return "AI에 문제가 생겼ㅇ어요."
|
return "AI에 문제가 생겼ㅇ어요."
|
||||||
}
|
}
|
||||||
return result.Text()
|
|
||||||
|
contents = append(contents, genai.NewContentFromText(question, genai.RoleUser))
|
||||||
|
result, err := ChatBot.Gemini.Models.GenerateContent(context.TODO(), configs.Config.Chatbot.Gemini.Model, contents, ChatBot.config)
|
||||||
|
if err != nil {
|
||||||
|
ChatBot.Mode = ChatbotMuffin
|
||||||
|
log.Fatalln(err)
|
||||||
|
return "AI에 문제가 생겼ㅇ어요."
|
||||||
|
}
|
||||||
|
|
||||||
|
resultText := result.Text()
|
||||||
|
err = SaveMemory(&databases.InsertMemory{
|
||||||
|
UserId: userId,
|
||||||
|
Content: question,
|
||||||
|
Answer: resultText,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultText
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Chatbot) GetResponse(question string) string {
|
func (c *Chatbot) GetResponse(userId, question string) string {
|
||||||
switch c.Mode {
|
switch c.Mode {
|
||||||
case ChatbotMuffin:
|
case ChatbotMuffin:
|
||||||
return getDefaultResponse(c.s, question)
|
return getDefaultResponse(c.s, question)
|
||||||
default:
|
default:
|
||||||
return getAIResponse(question)
|
return getAIResponse(userId, question)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
41
chatbot/memory.go
Normal file
41
chatbot/memory.go
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
package chatbot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"git.wh64.net/muffin/goMuffin/databases"
|
||||||
|
"go.mongodb.org/mongo-driver/v2/bson"
|
||||||
|
"google.golang.org/genai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SaveMemory(data *databases.InsertMemory) error {
|
||||||
|
_, err := databases.Database.Memory.InsertOne(context.TODO(), *data)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMemory(userId string) ([]*genai.Content, error) {
|
||||||
|
var data []databases.Memory
|
||||||
|
|
||||||
|
MAX_LENGTH := 50
|
||||||
|
memory := []*genai.Content{}
|
||||||
|
|
||||||
|
cur, err := databases.Database.Memory.Find(context.TODO(), bson.D{{Key: "user_id", Value: userId}})
|
||||||
|
if err != nil {
|
||||||
|
return memory, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cur.All(context.TODO(), &data)
|
||||||
|
|
||||||
|
if len(data) > MAX_LENGTH {
|
||||||
|
data = data[MAX_LENGTH:]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, data := range data {
|
||||||
|
memory = append(memory,
|
||||||
|
genai.NewContentFromText(data.Content, genai.RoleUser),
|
||||||
|
genai.NewContentFromText(data.Answer, genai.RoleModel),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return memory, nil
|
||||||
|
}
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"git.wh64.net/muffin/goMuffin/utils"
|
"git.wh64.net/muffin/goMuffin/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601a-muffin-ai"
|
const MUFFIN_VERSION = "0.0.0-madeleine_canary.250601b-muffin-ai"
|
||||||
|
|
||||||
var updatedString string = utils.RegexpDecimals.FindAllStringSubmatch(MUFFIN_VERSION, -1)[3][0]
|
var updatedString string = utils.RegexpDecimals.FindAllStringSubmatch(MUFFIN_VERSION, -1)[3][0]
|
||||||
|
|
||||||
|
|
16
databases/Memory.go
Normal file
16
databases/Memory.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package databases
|
||||||
|
|
||||||
|
import "go.mongodb.org/mongo-driver/v2/bson"
|
||||||
|
|
||||||
|
type InsertMemory struct {
|
||||||
|
UserId string `bson:"user_id"`
|
||||||
|
Content string `bson:"content"`
|
||||||
|
Answer string `bson:"answer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Memory struct {
|
||||||
|
Id bson.ObjectID `bson:"_id"`
|
||||||
|
UserId string `bson:"user_id"`
|
||||||
|
Content string `bson:"content"`
|
||||||
|
Answer string `bson:"answer"`
|
||||||
|
}
|
|
@ -12,6 +12,7 @@ type MuffinDatabase struct {
|
||||||
Client *mongo.Client
|
Client *mongo.Client
|
||||||
Learns *mongo.Collection
|
Learns *mongo.Collection
|
||||||
Texts *mongo.Collection
|
Texts *mongo.Collection
|
||||||
|
Memory *mongo.Collection
|
||||||
}
|
}
|
||||||
|
|
||||||
var Database *MuffinDatabase
|
var Database *MuffinDatabase
|
||||||
|
@ -34,5 +35,6 @@ func Connect() (*MuffinDatabase, error) {
|
||||||
Client: client,
|
Client: client,
|
||||||
Learns: client.Database(configs.Config.Database.Name).Collection("learn"),
|
Learns: client.Database(configs.Config.Database.Name).Collection("learn"),
|
||||||
Texts: client.Database(configs.Config.Database.Name).Collection("text"),
|
Texts: client.Database(configs.Config.Database.Name).Collection("text"),
|
||||||
|
Memory: client.Database(configs.Config.Database.Name).Collection("memory"),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ func MessageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
if command == "" {
|
if command == "" {
|
||||||
s.ChannelTyping(m.ChannelID)
|
s.ChannelTyping(m.ChannelID)
|
||||||
|
|
||||||
result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(content), s, m)
|
result := chatbot.ParseResult(chatbot.ChatBot.GetResponse(m.Author.ID, content), s, m)
|
||||||
utils.NewMessageSender(&utils.MessageCreate{
|
utils.NewMessageSender(&utils.MessageCreate{
|
||||||
MessageCreate: m,
|
MessageCreate: m,
|
||||||
Session: s,
|
Session: s,
|
||||||
|
|
Loading…
Reference in a new issue