修改功能、模型获取逻辑,添加会话系统参数
This commit is contained in:
parent
7c4f3ef66f
commit
0089d2a1a3
|
|
@ -19,7 +19,8 @@ type FunctionModel struct {
|
|||
Function string `gorm:"column:function"` //功能函数,唯一标识
|
||||
UserID uint `gorm:"column:user_id"` //用户id
|
||||
ModelID uint `gorm:"column:model_id"` //模型id
|
||||
Info string `gorm:"column:info"` //功能信息
|
||||
Info string `gorm:"column:info"` //功能信息,系统功能,对应模型系统参数
|
||||
//System string `gorm:"column:system"` //系统功能,对应模型系统参数
|
||||
}
|
||||
|
||||
// 创建模型
|
||||
|
|
|
|||
|
|
@ -196,11 +196,14 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
}
|
||||
var model dao.Model
|
||||
//查看请求功能类型
|
||||
if data.Function == "gen-ai-chat" {
|
||||
model = service.FindFuncModelByFunction(data.Function, userId)[0]
|
||||
} else {
|
||||
models, funcs := service.FindFuncModelByFunction(data.Function, userId)
|
||||
if funcs == nil || len(funcs) == 0 {
|
||||
return errors.New("function not exist")
|
||||
}
|
||||
if models == nil || len(models) == 0 {
|
||||
return errors.New("model not exist")
|
||||
}
|
||||
model = models[0]
|
||||
//解析模型参数
|
||||
var modelParam proto.ModelParam
|
||||
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
|
||||
|
|
@ -208,6 +211,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
return err
|
||||
}
|
||||
modelParam.Url = model.Url
|
||||
modelParam.System = funcs[0].Info //系统功能
|
||||
|
||||
var userMsgID uint
|
||||
//将消息存入数据库
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ type ModelParam struct {
|
|||
APPID string `json:"appid"` //应用id
|
||||
APISecret string `json:"apiSecret"` //应用密钥
|
||||
APIKey string `json:"apiKey"` //应用key
|
||||
System string `json:"system"` //系统参数
|
||||
}
|
||||
|
||||
type SessionResponse struct {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
|
|||
)
|
||||
ctx := context.Background()
|
||||
messages := make([]*model.ChatCompletionMessage, 0)
|
||||
err := GetDouBaoSessionHistoryMsg(SessionID, "你是豆包,是由字节跳动开发的 AI 人工智能助手", &messages, question)
|
||||
err := GetDouBaoSessionHistoryMsg(SessionID, modelParam.System, &messages, question)
|
||||
if err != nil {
|
||||
log.Println("get doubao session history message error:", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,13 +108,13 @@ func UpdateFuncModelByID(id int, userID, modelID uint, name, info, function stri
|
|||
}
|
||||
|
||||
// 根据功能查找对应功能
|
||||
func FindFuncModelByFunction(function string, userID int) []dao.Model {
|
||||
func FindFuncModelByFunction(function string, userID int) ([]dao.Model, []dao.FunctionModel) {
|
||||
//先查找对应功能
|
||||
funcModels := dao.FindFunctionModelByFunction(function)
|
||||
if len(funcModels) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
modelID := funcModels[0].ModelID
|
||||
//再查找对应模型
|
||||
return dao.FindModelByID(int(modelID), userID)
|
||||
return dao.FindModelByID(int(modelID), userID), funcModels
|
||||
}
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
|
|||
|
||||
//发送消息
|
||||
go func() {
|
||||
data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, "")
|
||||
data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System)
|
||||
//将数据转换为json
|
||||
//dataByte, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Reference in New Issue