diff --git a/proto/status.go b/proto/status.go index 2567129..8f7a148 100644 --- a/proto/status.go +++ b/proto/status.go @@ -117,3 +117,10 @@ const ( SparkRoleAssistant = "assistant" SparRoleSystem = "system" ) + +// 支持模型类型 +const ( + ModelTypeSpark = "spark" + ModelTypeDouBao = "doubao" + ModelTypeOllama = "ollama" +) diff --git a/service/doubao.go b/service/doubao.go index 68599b0..e986e54 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -81,7 +81,7 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo var wsMsg proto.WSMessage var apiMsg api.GenerateResponse - apiMsg.Model = "doubao" + apiMsg.Model = proto.ModelTypeDouBao choices := data.Choices[0] apiMsg.Response = choices.Delta.Content if choices.FinishReason == proto.FinishReasonStop { @@ -90,13 +90,13 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo apiMsg.Done = false } //消息队列部分 - aiMsg.Type = "doubao" + aiMsg.Type = proto.ModelTypeDouBao aiMsg.Msg = apiMsg //ws发送消息部分 wsMsg.Msg = aiMsg wsMsg.SessionID = SessionID wsMsg.ToID = UserID - wsMsg.Type = "doubao" + wsMsg.Type = proto.ModelTypeDouBao //发送消息 wsMsgStr, _ := json.Marshal(wsMsg) diff --git a/service/imService.go b/service/imService.go index 2bd513a..051dc38 100644 --- a/service/imService.go +++ b/service/imService.go @@ -99,11 +99,11 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS var resErr error var resID uint //处理消息 - if model.Type == "ollama" { + if model.Type == proto.ModelTypeOllama { resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model) - } else if model.Type == "spark" { + } else if model.Type == proto.ModelTypeSpark { resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model) - } else if model.Type == "doubao" { + } else if model.Type == proto.ModelTypeDouBao { doubao(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID)) } else { return errors.New("model not exist"), 0 @@ -136,7 +136,7 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes var aiMsg proto.AIQueueMessage var data proto.WSMessage respFunc := func(resp api.GenerateResponse) error { - aiMsg.Type = "ollama" + aiMsg.Type = proto.ModelTypeOllama aiMsg.Msg = resp robotMsg += resp.Response if resp.Done { //该消息完成 diff --git a/service/spark.go b/service/spark.go index 33b606e..002b31f 100644 --- a/service/spark.go +++ b/service/spark.go @@ -182,7 +182,7 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S var aiMsg proto.AIQueueMessage var wsMsg proto.WSMessage var apiMsg api.GenerateResponse - apiMsg.Model = "spark" + apiMsg.Model = proto.ModelTypeSpark apiMsg.Response = data.Payload.Choices.Text[0].Content if data.Header.Status == 2 { apiMsg.Done = true @@ -190,13 +190,13 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S apiMsg.Done = false } //消息队列部分 - aiMsg.Type = "spark" + aiMsg.Type = proto.ModelTypeSpark aiMsg.Msg = apiMsg //ws发送消息部分 wsMsg.Msg = aiMsg wsMsg.SessionID = SessionID wsMsg.ToID = UserID - wsMsg.Type = "spark" + wsMsg.Type = proto.ModelTypeSpark //发送消息 wsMsgStr, _ := json.Marshal(wsMsg)