package service

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"StuAcaWorksAI/dao"
	"StuAcaWorksAI/proto"
	"StuAcaWorksAI/worker"

	"github.com/gorilla/websocket"
	"github.com/ollama/ollama/api"
)

// sparkStatusFinal is the header.status value that marks the last frame of a
// streamed Spark reply.
const sparkStatusFinal = 2

/***************** Spark request structures ***********************/

// SparkRequestHeader is the "header" section of a Spark request.
type SparkRequestHeader struct {
	AppID string `json:"app_id"`
	UID   string `json:"uid,omitempty"`
}

// ChatParameter is the "parameter.chat" section of a Spark request.
type ChatParameter struct {
	Domain      string  `json:"domain"`
	Temperature float64 `json:"temperature"`
	MaxTokens   int     `json:"max_tokens"`
}

// Parameter is the "parameter" section of a Spark request.
type Parameter struct {
	Chat ChatParameter `json:"chat"`
}

// MessageContent is one conversation turn in "payload.message.text".
type MessageContent struct {
	Role        string `json:"role"`
	Content     string `json:"content"`
	ContentType string `json:"content_type,omitempty"`
}

// SparkMessage is the "payload.message" section of a Spark request.
type SparkMessage struct {
	Text []MessageContent `json:"text"`
}

// SparkRequestPayload is the "payload" section of a Spark request.
type SparkRequestPayload struct {
	Message SparkMessage `json:"message"`
}

// SparkRequest is the complete JSON document sent over the websocket.
type SparkRequest struct {
	Header    SparkRequestHeader  `json:"header"`
	Parameter Parameter           `json:"parameter"`
	Payload   SparkRequestPayload `json:"payload"`
}

/***************** Spark response structures ***********************/

// SparkResponseHeader is the "header" section of a Spark response frame.
type SparkResponseHeader struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	SID     string `json:"sid"`
	Status  int    `json:"status"` // Session status in [0,1,2]: 0 = first result, 1 = intermediate result, 2 = last result.
}

// ChoiceText maps one element of "payload.choices.text".
type ChoiceText struct {
	Content string `json:"content"`
	Role    string `json:"role"`
	Index   int    `json:"index"`
}

// Choices maps the "payload.choices" section.
type Choices struct {
	Status int          `json:"status"`
	Seq    int          `json:"seq"`
	Text   []ChoiceText `json:"text"`
}

// UsageText maps the "payload.usage.text" section.
type UsageText struct {
	QuestionTokens   int `json:"question_tokens"`
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Usage maps the "payload.usage" section.
type Usage struct {
	Text UsageText `json:"text"`
}

// SparkResponsePayload maps the "payload" section of a response frame.
type SparkResponsePayload struct {
	Choices Choices `json:"choices"`
	Usage   Usage   `json:"usage"`
}

// SparkResponse maps one complete response frame.
type SparkResponse struct {
	Header  SparkResponseHeader  `json:"header"`
	Payload SparkResponsePayload `json:"payload"`
}

// sparkChoiceContent returns the content of the first choice in a frame, or
// "" when the frame carries no choices, so callers never index an empty slice.
func sparkChoiceContent(data *SparkResponse) string {
	if data == nil || len(data.Payload.Choices.Text) == 0 {
		return ""
	}
	return data.Payload.Choices.Text[0].Content
}

// runSparkSession dials the Spark websocket endpoint, sends the request built
// from the session history, and reads response frames until the final frame,
// an error frame, or a read error.
//
// onFrame is invoked synchronously for every successfully decoded frame whose
// header code is 0. The function returns the concatenated answer, the
// total-token count reported by the final frame, and ok=false when the
// exchange failed in a way where nothing should be persisted (dial, send or
// frame-decoding failure). The connection is always closed before returning.
func runSparkSession(modelParam proto.ModelParam, userID, sessionID int, onFrame func(SparkResponse)) (string, int, bool) {
	dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second}
	conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil)
	if err != nil {
		// A failed upstream call must not take the whole process down, so
		// log and give up instead of panicking.
		log.Println("spark dial error:", readResp(resp), err)
		return "", 0, false
	}
	defer conn.Close()
	if resp != nil && resp.StatusCode != http.StatusSwitchingProtocols {
		log.Println("spark unexpected handshake status:", readResp(resp))
		return "", 0, false
	}

	// Send the request before reading. Doing it on this goroutine keeps all
	// error variables local and avoids racing with the read loop.
	req := genSparkParams(userID, modelParam.APPID, modelParam.Domain, sessionID, modelParam.System)
	payload, err := json.Marshal(req)
	if err != nil {
		log.Println("Error parsing JSON:", err)
		return "", 0, false
	}
	log.Println("spark send message:", string(payload))
	if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
		log.Println("write message error:", err)
		return "", 0, false
	}

	var answer strings.Builder
	tokens := 0
	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			log.Println("read message error:", err)
			break
		}
		var data SparkResponse
		if err := json.Unmarshal(msg, &data); err != nil {
			log.Println("Error parsing JSON:", err)
			return "", 0, false
		}
		if data.Header.Code != 0 {
			log.Println("Error response:", data.Header.Message, string(msg))
			break
		}

		answer.WriteString(sparkChoiceContent(&data))
		final := data.Header.Status == sparkStatusFinal
		if final {
			tokens += data.Payload.Usage.Text.TotalTokens
		}
		onFrame(data)
		if final {
			// Nothing follows the final frame; stop instead of waiting for
			// the peer to drop the connection.
			break
		}
	}
	return answer.String(), tokens, true
}

// Spark runs one streamed Spark completion for a session and stores the full
// answer as a model-to-user message.
//
// question is only logged; the request content is built from the session
// history by genSparkParams. Every frame is handed to CreateAIStreamMsg and
// published on channel. Failures are logged; nothing is returned.
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID, modelID int) {
	// NOTE(review): modelParam is logged in full and carries APIKey/APISecret
	// fields; consider redacting them.
	log.Println("spark model param:", modelParam)
	log.Println("spark question:", question)

	onFrame := func(data SparkResponse) {
		go CreateAIStreamMsg(UserID, modelID, SessionID, data, modelParam.Model)
		SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
	}
	answer, _, ok := runSparkSession(modelParam, UserID, SessionID, onFrame)
	if !ok {
		return
	}

	// Persist the assembled answer.
	err, msgID := dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2)
	if err != nil {
		log.Println("spark create message error:", err)
		return
	}
	log.Println("spark create message success,message id:", msgID)
}

// SparkV2 runs one streamed Spark completion for the session described by
// imCtx, stores the full answer as a model-to-user message, and records the
// total-token count reported by the final frame.
//
// Every frame is handed to CreateAIStreamMsg and published on imCtx.Channel.
// Failures are logged; nothing is returned.
func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
	// NOTE(review): modelParam is logged in full and carries APIKey/APISecret
	// fields; consider redacting them.
	log.Println("spark model param:", modelParam)
	log.Println("spark question:", imCtx.Question)

	onFrame := func(data SparkResponse) {
		go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, data, modelParam.Model)
		SparkToGeneralMassageAndSendMsgQueue(&data, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	}
	answer, tokens, ok := runSparkSession(modelParam, imCtx.UserID, imCtx.SessionID, onFrame)
	if !ok {
		return
	}

	// Persist the assembled answer.
	err, msgID := dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
	if err != nil {
		log.Println("spark create message error:", err)
	} else {
		log.Println("spark create message success,message id:", msgID)
	}

	// Persist the token consumption.
	err, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
	if err != nil {
		log.Println("spark create model token error:", err)
	} else {
		log.Println("spark create model tokens id:", modelTokenID)
	}
}

// SparkToGeneralMassageAndSendMsgQueue converts one Spark frame into the
// generic websocket message envelope and publishes it on channel with a
// 60-second duration argument.
//
// A frame without choices is forwarded with an empty response text. If the
// envelope cannot be marshalled the error is logged and nothing is published.
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) {
	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeSpark
	apiMsg.Response = sparkChoiceContent(data)
	apiMsg.Done = data.Header.Status == sparkStatusFinal

	// Message-queue part.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeSpark
	aiMsg.Msg = apiMsg

	// Websocket envelope.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeSpark

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		log.Println("spark marshal ws message error:", err)
		return
	}
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}

// genSparkParams builds the Spark request for a session. The message list is
// filled by GetSparkSessionHistoryMsgV2; if that fails the error is logged and
// the request is built from whatever was collected.
func genSparkParams(userID int, appid string, domain string, sessionID int, systemPrompt string) SparkRequest {
	// Start from an empty, non-nil slice so "text" is encoded as [] rather
	// than null when there is no history.
	history := make([]MessageContent, 0)
	if err := GetSparkSessionHistoryMsgV2(sessionID, systemPrompt, &history, userID); err != nil {
		log.Println("get spark session history message error:", err)
	}

	return SparkRequest{
		Header: SparkRequestHeader{AppID: appid},
		Parameter: Parameter{
			Chat: ChatParameter{
				Domain:      domain,
				Temperature: 0.8,
				MaxTokens:   4096,
			},
		},
		Payload: SparkRequestPayload{
			Message: SparkMessage{Text: history},
		},
	}
}

// assembleAuthUrl returns hostUrl with the "host", "date" and "authorization"
// query parameters appended. The authorization value is the base64 encoding of
// an "hmac username=..." string whose signature is an HMAC-SHA256 over the
// host, date and request line, keyed with apiSecret; apiKey is the hmac
// username.
//
// If hostUrl cannot be parsed, the error is logged and hostUrl is returned
// unchanged.
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
	ul, err := url.Parse(hostUrl)
	if err != nil {
		// There is no error result to report through, and ul is nil here, so
		// return the address unsigned and let the dialer reject it.
		log.Println("spark parse host url error:", err)
		return hostUrl
	}

	// Signing time.
	date := time.Now().UTC().Format(time.RFC1123)

	// Fields that take part in the signature: host, date, request-line.
	signingInput := strings.Join([]string{
		"host: " + ul.Host,
		"date: " + date,
		"GET " + ul.Path + " HTTP/1.1",
	}, "\n")
	signature := HmacWithShaTobase64("hmac-sha256", signingInput, apiSecret)

	// The authorization string is base64 encoded first; URL encoding is
	// applied afterwards by url.Values.Encode.
	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
		"hmac-sha256", "host date request-line", signature)
	authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))

	v := url.Values{}
	v.Add("host", ul.Host)
	v.Add("date", date)
	v.Add("authorization", authorization)
	return hostUrl + "?" + v.Encode()
}

// HmacWithShaTobase64 returns the base64 (standard encoding) HMAC-SHA256 of
// data keyed with key. The algorithm argument is not consulted; SHA-256 is
// always used.
func HmacWithShaTobase64(algorithm, data, key string) string {
	mac := hmac.New(sha256.New, []byte(key))
	mac.Write([]byte(data))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}

// readResp renders an HTTP response as "code=<status>,body=<body>" for
// diagnostics. It returns "" for a nil response and never panics: a body that
// cannot be read is reported inside the returned string.
func readResp(resp *http.Response) string {
	if resp == nil {
		return ""
	}
	if resp.Body == nil {
		return fmt.Sprintf("code=%d,body=", resp.StatusCode)
	}
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Sprintf("code=%d,body=<unreadable: %v>", resp.StatusCode, err)
	}
	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}

// GetSparkSessionHistoryMsg appends the session's recent history to *messages
// as Spark conversation turns, preceded by a "system" turn when systemPrompt
// is non-empty. Messages that are neither user-to-model nor model-to-user are
// skipped. prompt is unused: the history already contains the current
// request. The function always returns nil.
func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]MessageContent, prompt string) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// NOTE(review): this keeps SparkContextLength+1 messages, although the
		// original comment says SparkContextLength. Confirm whether the extra
		// one (the current request?) is intended.
		msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
	}

	if systemPrompt != "" {
		*messages = append(*messages, MessageContent{Role: "system", Content: systemPrompt})
	}

	for _, msg := range msgs {
		var role string
		switch msg.Type {
		case proto.UserToModelMsgType:
			role = proto.SparkRoleUser
		case proto.ModelToUserMsgType:
			role = proto.SparkRoleAssistant
		default:
			continue
		}
		*messages = append(*messages, MessageContent{Role: role, Content: msg.Msg})
	}
	return nil
}

// GetSparkSessionHistoryMsgV2 appends the session's recent history to
// *messages as Spark conversation turns, preceded by a "system" turn when
// systemPrompt is non-empty.
//
// User messages whose status is proto.UserToModelFileMsgType are expanded by
// decodeFileMessageToSparkMessage (a failure there is logged and that message
// is skipped); other user and model messages are appended as plain turns.
// Messages of any other type are skipped. The history already contains the
// current request, so nothing else is appended. The function always returns
// nil.
func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]MessageContent, userID int) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// NOTE(review): this keeps SparkContextLength+1 messages, although the
		// original comment says SparkContextLength. Confirm whether the extra
		// one (the current request?) is intended.
		msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
	}

	if systemPrompt != "" {
		*messages = append(*messages, MessageContent{Role: "system", Content: systemPrompt})
	}

	for _, msg := range msgs {
		switch msg.Type {
		case proto.UserToModelMsgType:
			if msg.Status == proto.UserToModelFileMsgType {
				if err := decodeFileMessageToSparkMessage(msg.Msg, userID, messages); err != nil {
					log.Println("spark decode file message error:", err)
				}
				continue
			}
			*messages = append(*messages, MessageContent{Role: proto.SparkRoleUser, Content: msg.Msg})
		case proto.ModelToUserMsgType:
			*messages = append(*messages, MessageContent{Role: proto.SparkRoleAssistant, Content: msg.Msg})
		}
	}
	return nil
}

// decodeFileMessageToSparkMessage decodes a JSON-encoded proto.UserFileMessage
// and appends one "user" turn per attached file (content type "text" for text
// files, "image" for image files) followed by a "user" text turn holding the
// message text.
//
// Files whose content cannot be loaded, or whose type is neither text nor
// image, are logged and skipped. An error is returned only when msg is not
// valid JSON for proto.UserFileMessage.
func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error {
	var fileMsgContent proto.UserFileMessage
	if err := json.Unmarshal([]byte(msg), &fileMsgContent); err != nil {
		log.Println("spark unmarshal file message error:", err)
		return err
	}
	log.Println("sparkv2,decode file message:", fileMsgContent)

	for _, file := range fileMsgContent.FileContent {
		switch file.FileType {
		case proto.FileTypeText:
			fileContents, err := FindFileContentV2(file.FileContent.FileID, userID)
			if err != nil {
				log.Println("spark find file content error:", err, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
				continue
			}
			if len(fileContents) == 0 {
				// A successful lookup may still yield nothing; indexing [0]
				// would panic.
				log.Println("spark find file content empty,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
				continue
			}
			*message = append(*message, MessageContent{
				Role:        "user",
				Content:     fileContents[0].FileContent,
				ContentType: "text",
			})
		case proto.FileTypeImage:
			imgBase64, err := FindImageFileContent(file.FileContent.FileID, userID)
			if err != nil {
				log.Println("spark find image file content error:", err, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
				continue
			}
			*message = append(*message, MessageContent{
				Role:        "user",
				Content:     imgBase64,
				ContentType: "image",
			})
		default:
			log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID)
		}
	}

	// The text the user typed goes last, after the attachments.
	*message = append(*message, MessageContent{
		Role:        "user",
		Content:     fileMsgContent.Text,
		ContentType: "text",
	})
	return nil
}