修改功能、模型获取逻辑,添加会话系统参数

This commit is contained in:
junleea 2025-03-27 14:36:21 +08:00
parent 7c4f3ef66f
commit 0089d2a1a3
6 changed files with 15 additions and 9 deletions

View File

@ -19,7 +19,8 @@ type FunctionModel struct {
Function string `gorm:"column:function"` //功能函数,唯一标识
UserID uint `gorm:"column:user_id"` //用户id
ModelID uint `gorm:"column:model_id"` //模型id
Info string `gorm:"column:info"` //功能信息
Info string `gorm:"column:info"` //功能信息,系统功能,对应模型系统参数
//System string `gorm:"column:system"` //系统功能,对应模型系统参数
}
// 创建模型

View File

@ -196,11 +196,14 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
}
var model dao.Model
//查看请求功能类型
if data.Function == "gen-ai-chat" {
model = service.FindFuncModelByFunction(data.Function, userId)[0]
} else {
models, funcs := service.FindFuncModelByFunction(data.Function, userId)
if funcs == nil || len(funcs) == 0 {
return errors.New("function not exist")
}
if models == nil || len(models) == 0 {
return errors.New("model not exist")
}
model = models[0]
//解析模型参数
var modelParam proto.ModelParam
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
@ -208,6 +211,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
return err
}
modelParam.Url = model.Url
modelParam.System = funcs[0].Info //系统功能
var userMsgID uint
//将消息存入数据库

View File

@ -33,6 +33,7 @@ type ModelParam struct {
APPID string `json:"appid"` //应用id
APISecret string `json:"apiSecret"` //应用密钥
APIKey string `json:"apiKey"` //应用key
System string `json:"system"` //系统参数
}
type SessionResponse struct {

View File

@ -24,7 +24,7 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
)
ctx := context.Background()
messages := make([]*model.ChatCompletionMessage, 0)
err := GetDouBaoSessionHistoryMsg(SessionID, "你是豆包,是由字节跳动开发的 AI 人工智能助手", &messages, question)
err := GetDouBaoSessionHistoryMsg(SessionID, modelParam.System, &messages, question)
if err != nil {
log.Println("get doubao session history message error:", err)
}

View File

@ -108,13 +108,13 @@ func UpdateFuncModelByID(id int, userID, modelID uint, name, info, function stri
}
// 根据功能查找对应功能
func FindFuncModelByFunction(function string, userID int) []dao.Model {
func FindFuncModelByFunction(function string, userID int) ([]dao.Model, []dao.FunctionModel) {
//先查找对应功能
funcModels := dao.FindFunctionModelByFunction(function)
if len(funcModels) == 0 {
return nil
return nil, nil
}
modelID := funcModels[0].ModelID
//再查找对应模型
return dao.FindModelByID(int(modelID), userID)
return dao.FindModelByID(int(modelID), userID), funcModels
}

View File

@ -124,7 +124,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
//发送消息
go func() {
data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, "")
data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System)
//将数据转换为json
//dataByte, err := json.Marshal(data)
if err != nil {