diff --git a/dao/model.go b/dao/model.go index 75c6fb4..7a5f45a 100644 --- a/dao/model.go +++ b/dao/model.go @@ -19,7 +19,8 @@ type FunctionModel struct { Function string `gorm:"column:function"` //功能函数,唯一标识 UserID uint `gorm:"column:user_id"` //用户id ModelID uint `gorm:"column:model_id"` //模型id - Info string `gorm:"column:info"` //功能信息 + Info string `gorm:"column:info"` //功能信息,系统功能,对应模型系统参数 + //System string `gorm:"column:system"` //系统功能,对应模型系统参数 } // 创建模型 diff --git a/handler/im.go b/handler/im.go index 9d76a1c..eff1040 100644 --- a/handler/im.go +++ b/handler/im.go @@ -196,11 +196,14 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR } var model dao.Model //查看请求功能类型 - if data.Function == "gen-ai-chat" { - model = service.FindFuncModelByFunction(data.Function, userId)[0] - } else { + models, funcs := service.FindFuncModelByFunction(data.Function, userId) + if funcs == nil || len(funcs) == 0 { return errors.New("function not exist") } + if models == nil || len(models) == 0 { + return errors.New("model not exist") + } + model = models[0] //解析模型参数 var modelParam proto.ModelParam err = json.Unmarshal([]byte(model.Parameter), &modelParam) @@ -208,6 +211,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR return err } modelParam.Url = model.Url + modelParam.System = funcs[0].Info //系统功能 var userMsgID uint //将消息存入数据库 diff --git a/proto/im.go b/proto/im.go index 00040b8..0cabe7b 100644 --- a/proto/im.go +++ b/proto/im.go @@ -33,6 +33,7 @@ type ModelParam struct { APPID string `json:"appid"` //应用id APISecret string `json:"apiSecret"` //应用密钥 APIKey string `json:"apiKey"` //应用key + System string `json:"system"` //系统参数 } type SessionResponse struct { diff --git a/service/doubao.go b/service/doubao.go index e986e54..9a40975 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -24,7 +24,7 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us ) ctx := context.Background() messages := make([]*model.ChatCompletionMessage, 0) - err := GetDouBaoSessionHistoryMsg(SessionID, "你是豆包,是由字节跳动开发的 AI 人工智能助手", &messages, question) + err := GetDouBaoSessionHistoryMsg(SessionID, modelParam.System, &messages, question) if err != nil { log.Println("get doubao session history message error:", err) } diff --git a/service/modelService.go b/service/modelService.go index e5038e3..8d546b6 100644 --- a/service/modelService.go +++ b/service/modelService.go @@ -108,13 +108,13 @@ func UpdateFuncModelByID(id int, userID, modelID uint, name, info, function stri } // 根据功能查找对应功能 -func FindFuncModelByFunction(function string, userID int) []dao.Model { +func FindFuncModelByFunction(function string, userID int) ([]dao.Model, []dao.FunctionModel) { //先查找对应功能 funcModels := dao.FindFunctionModelByFunction(function) if len(funcModels) == 0 { - return nil + return nil, nil } modelID := funcModels[0].ModelID //再查找对应模型 - return dao.FindModelByID(int(modelID), userID) + return dao.FindModelByID(int(modelID), userID), funcModels } diff --git a/service/spark.go b/service/spark.go index 002b31f..ef43442 100644 --- a/service/spark.go +++ b/service/spark.go @@ -124,7 +124,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use //发送消息 go func() { - data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, "") + data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System) //将数据转换为json //dataByte, err := json.Marshal(data) if err != nil {