package worker import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "github.com/gorilla/websocket" "github.com/ollama/ollama/api" "io" "log" "net/http" "net/url" "strings" "time" ) /***************** Spark请求结构 ***********************/ type SparkRequestHeader struct { AppID string `json:"app_id"` UID string `json:"uid"` } type ChatParameter struct { Domain string `json:"domain"` Temperature float64 `json:"temperature"` MaxTokens int `json:"max_tokens"` } type Parameter struct { Chat ChatParameter `json:"chat"` } type MessageContent struct { Role string `json:"role"` Content string `json:"content"` } type SparkMessage struct { Text []MessageContent `json:"text"` } type SparkRequestPayload struct { Message SparkMessage `json:"message"` } type SparkRequest struct { Header SparkRequestHeader `json:"header"` Parameter Parameter `json:"parameter"` Payload SparkRequestPayload `json:"payload"` } /***************** Spark请求结构 ***********************/ /***************** Spark响应结构 ***********************/ type SparkResponseHeader struct { Code int `json:"code"` Message string `json:"message"` SID string `json:"sid"` Status int `json:"status"` //会话状态,取值为[0,1,2];0代表首次结果;1代表中间结果;2代表最后一个结果 } // ChoiceText 结构体对应 JSON 中 "choices.text[].text" 部分 type ChoiceText struct { Content string `json:"content"` Role string `json:"role"` Index int `json:"index"` } // Choices 结构体对应 JSON 中 "choices" 部分 type Choices struct { Status int `json:"status"` Seq int `json:"seq"` Text []ChoiceText `json:"text"` } // UsageText 结构体对应 JSON 中 "usage.text" 部分 type UsageText struct { QuestionTokens int `json:"question_tokens"` PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } // Usage 结构体对应 JSON 中 "usage" 部分 type Usage struct { Text UsageText `json:"text"` } // Payload 结构体对应 JSON 中 "payload" 部分 type SparkResponsePayload struct { Choices Choices `json:"choices"` Usage Usage `json:"usage"` } // Request 结构体对应整个 JSON 数据 type SparkResponse struct { Header SparkResponseHeader `json:"header"` Payload SparkResponsePayload `json:"payload"` } /***************** Spark响应结构 ***********************/ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) { //建立websocket连接 dialer := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } //建立连接 conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil) if err != nil { panic(readResp(resp) + err.Error()) return } else if resp.StatusCode != 101 { panic(readResp(resp) + err.Error()) } //发送消息 go func() { data := genSparkParams(modelParam.APPID, question) //将数据转换为json //dataByte, err := json.Marshal(data) if err != nil { fmt.Println("Error parsing JSON:", err) return } err2 := conn.WriteJSON(data) if err != nil { fmt.Println("write message error:", err2) return } }() answer := "" //接收消息 for { _, msg, err2 := conn.ReadMessage() if err2 != nil { log.Println("read message error:", err2) break } var data SparkResponse err = json.Unmarshal(msg, &data) if err != nil { log.Println("Error parsing JSON:", err) return } if data.Header.Code != 0 { log.Println("Error response:", data.Header.Message) break } if data.Header.Status == 2 { //最后一个结果 answer += data.Payload.Choices.Text[0].Content } else { answer += data.Payload.Choices.Text[0].Content } SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID) } conn.Close() //将消息存入数据库 err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer) if err != nil { log.Println("spark create message error:", err) } } func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) { var aiMsg proto.AIQueueMessage var wsMsg proto.WSMessage var apiMsg api.GenerateResponse apiMsg.Model = "spark" apiMsg.Response = data.Payload.Choices.Text[0].Content if data.Header.Status == 2 { apiMsg.Done = true } else { apiMsg.Done = false } //消息队列部分 aiMsg.Type = "spark" aiMsg.Msg = apiMsg //ws发送消息部分 wsMsg.Msg = aiMsg wsMsg.SessionID = SessionID wsMsg.ToID = UserID wsMsg.Type = "spark" //发送消息 wsMsgStr, _ := json.Marshal(wsMsg) Publish(channel, string(wsMsgStr), time.Second*60) } // 生成参数 func genSparkParams(question, appid string) SparkRequest { messages := []MessageContent{ {Role: "user", Content: question}, } data := SparkRequest{ Header: SparkRequestHeader{ AppID: appid, }, Parameter: Parameter{ Chat: ChatParameter{ Domain: "4.0Ultra", Temperature: 0.8, MaxTokens: 2048, }, }, Payload: SparkRequestPayload{ Message: SparkMessage{ Text: messages, }, }, } return data } // 创建鉴权url apikey 即 hmac username func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string { ul, err := url.Parse(hostUrl) if err != nil { fmt.Println(err) } //签名时间 date := time.Now().UTC().Format(time.RFC1123) //date = "Tue, 28 May 2019 09:10:42 MST" //参与签名的字段 host ,date, request-line signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} //拼接签名字符串 sgin := strings.Join(signString, "\n") // fmt.Println(sgin) //签名结果 sha := HmacWithShaTobase64("hmac-sha256", sgin, apiSecret) // fmt.Println(sha) //构建请求参数 此时不需要urlencoding authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha) //将请求参数使用base64编码 authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) v := url.Values{} v.Add("host", ul.Host) v.Add("date", date) v.Add("authorization", authorization) //将编码后的字符串url encode后添加到url后面 callurl := hostUrl + "?" + v.Encode() return callurl } func HmacWithShaTobase64(algorithm, data, key string) string { mac := hmac.New(sha256.New, []byte(key)) mac.Write([]byte(data)) encodeData := mac.Sum(nil) return base64.StdEncoding.EncodeToString(encodeData) } func readResp(resp *http.Response) string { if resp == nil { return "" } b, err := io.ReadAll(resp.Body) if err != nil { panic(err) } return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) }