saw-go/worker/spark.go

277 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package worker
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"github.com/ollama/ollama/api"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
)
/***************** Spark request structures ***********************/

// Wire types for the JSON body sent to the Spark chat API. The struct tags
// define the JSON key names; encoding/json emits keys in field order.
type (
	// SparkRequestHeader is the "header" object of a request.
	SparkRequestHeader struct {
		AppID string `json:"app_id"` // application ID
		UID   string `json:"uid"`    // user identifier
	}

	// ChatParameter holds the generation settings under "parameter.chat".
	ChatParameter struct {
		Domain      string  `json:"domain"`
		Temperature float64 `json:"temperature"`
		MaxTokens   int     `json:"max_tokens"`
	}

	// Parameter is the "parameter" object of a request.
	Parameter struct {
		Chat ChatParameter `json:"chat"`
	}

	// MessageContent is one role/content message entry.
	MessageContent struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// SparkMessage is "payload.message": the ordered list of message entries.
	SparkMessage struct {
		Text []MessageContent `json:"text"`
	}

	// SparkRequestPayload is the "payload" object of a request.
	SparkRequestPayload struct {
		Message SparkMessage `json:"message"`
	}

	// SparkRequest is the complete request body.
	SparkRequest struct {
		Header    SparkRequestHeader  `json:"header"`
		Parameter Parameter           `json:"parameter"`
		Payload   SparkRequestPayload `json:"payload"`
	}
)

/***************** end of Spark request structures ***********************/
/***************** Spark response structures ***********************/

// Types mirroring one JSON frame streamed back by the Spark chat API.
type (
	// SparkResponseHeader is the frame's "header" object.
	SparkResponseHeader struct {
		Code    int    `json:"code"` // Spark treats any non-zero value as an error response
		Message string `json:"message"`
		SID     string `json:"sid"`
		// Session status, one of [0,1,2]: 0 = first result,
		// 1 = intermediate result, 2 = final result.
		Status int `json:"status"`
	}

	// ChoiceText is one element of "payload.choices.text".
	ChoiceText struct {
		Content string `json:"content"`
		Role    string `json:"role"`
		Index   int    `json:"index"`
	}

	// Choices is the "payload.choices" object.
	Choices struct {
		Status int          `json:"status"`
		Seq    int          `json:"seq"`
		Text   []ChoiceText `json:"text"`
	}

	// UsageText is "payload.usage.text": token accounting for the exchange.
	UsageText struct {
		QuestionTokens   int `json:"question_tokens"`
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	}

	// Usage is the "payload.usage" object.
	Usage struct {
		Text UsageText `json:"text"`
	}

	// SparkResponsePayload is the frame's "payload" object.
	SparkResponsePayload struct {
		Choices Choices `json:"choices"`
		Usage   Usage   `json:"usage"`
	}

	// SparkResponse is one complete response frame.
	SparkResponse struct {
		Header  SparkResponseHeader  `json:"header"`
		Payload SparkResponsePayload `json:"payload"`
	}
)

/***************** end of Spark response structures ***********************/
// Spark sends question to the Spark chat API over a websocket, forwards every
// streamed frame to channel via SparkToGeneralMassageAndSendMsgQueue, and
// finally stores the concatenated answer with dao.CreateMessage.
//
// Errors are logged, not returned. A failed handshake or request write aborts
// without writing to the database. A server error frame or a read error ends
// the stream, and whatever answer was collected so far is still stored.
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) {
	log.Println("spark model param:", modelParam)
	log.Println("spark question:", question)

	dialer := websocket.Dialer{
		HandshakeTimeout: 5 * time.Second,
	}
	conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil)
	if err != nil {
		// On a rejected handshake the HTTP response body usually explains why
		// (e.g. a bad signature); include it in the log.
		log.Println("spark dial error:", readResp(resp)+err.Error())
		return
	}
	defer conn.Close()

	// genSparkParams takes (question, appid) in that order. The request is
	// written synchronously before reading: a single request is sent per
	// connection, so no writer goroutine (and no shared error variable) is needed.
	if err := conn.WriteJSON(genSparkParams(question, modelParam.APPID)); err != nil {
		log.Println("write message error:", err)
		return
	}

	var answer strings.Builder
	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			log.Println("read message error:", err)
			break
		}
		var data SparkResponse
		if err := json.Unmarshal(msg, &data); err != nil {
			log.Println("Error parsing JSON:", err)
			return
		}
		if data.Header.Code != 0 {
			log.Println("Error response:", data.Header.Message, string(msg))
			break
		}
		// Guard against frames that carry no text entry.
		if len(data.Payload.Choices.Text) > 0 {
			answer.WriteString(data.Payload.Choices.Text[0].Content)
		}
		SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
		// Status 2 marks the final frame; stop reading instead of waiting
		// for the peer to close the connection.
		if data.Header.Status == 2 {
			break
		}
	}

	// Persist the full answer.
	err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer.String())
	if err != nil {
		log.Println("spark create message error:", err)
	}
}
// SparkToGeneralMassageAndSendMsgQueue converts one streamed Spark frame into an
// ollama api.GenerateResponse (Model "spark"). It wraps that in a
// proto.AIQueueMessage and a proto.WSMessage addressed to UserID within
// SessionID, then publishes the JSON encoding to channel.
//
// The frame's first text entry becomes the response chunk. A frame with no text
// entries is published as an empty chunk rather than causing a panic, so the
// Done flag of a final frame is still delivered.
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) {
	var apiMsg api.GenerateResponse
	apiMsg.Model = "spark"
	if len(data.Payload.Choices.Text) > 0 {
		apiMsg.Response = data.Payload.Choices.Text[0].Content
	}
	// Header status 2 marks the final frame of the reply.
	apiMsg.Done = data.Header.Status == 2

	// Message-queue envelope.
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = "spark"
	aiMsg.Msg = apiMsg

	// Websocket envelope.
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = "spark"

	wsMsgStr, err := json.Marshal(wsMsg)
	if err != nil {
		log.Println("spark marshal ws message error:", err)
		return
	}
	// NOTE(review): the 60s duration's meaning depends on Publish (likely an
	// expiry/timeout) — confirm against its definition.
	Publish(channel, string(wsMsgStr), time.Second*60)
}
// genSparkParams builds a single-turn chat request. The request contains one
// message with role "user" and content question, and is sent under app ID
// appid. Note the argument order: question first, then appid.
//
// Generation settings are fixed: domain "4.0Ultra", temperature 0.8 and
// max_tokens 2048. The header UID is left empty.
func genSparkParams(question, appid string) SparkRequest {
	var req SparkRequest
	req.Header.AppID = appid
	req.Parameter.Chat = ChatParameter{
		Domain:      "4.0Ultra",
		Temperature: 0.8,
		MaxTokens:   2048,
	}
	req.Payload.Message.Text = []MessageContent{
		{Role: "user", Content: question},
	}
	return req
}
// assembleAuthUrl returns hostUrl with the query parameters "host", "date" and
// "authorization" appended.
//
// The authorization value is a base64-encoded string of the form
//
//	hmac username="<apiKey>", algorithm="hmac-sha256", headers="host date request-line", signature="<sig>"
//
// where sig is the HMAC-SHA256, keyed by apiSecret, of the newline-joined lines
// "host: <host>", "date: <date>" and "GET <path> HTTP/1.1".
//
// If hostUrl cannot be parsed, the error is logged and hostUrl is returned
// unchanged, so the caller's dial fails with an error instead of this function
// dereferencing a nil URL.
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
	ul, err := url.Parse(hostUrl)
	if err != nil {
		log.Println("spark auth url parse error:", err)
		return hostUrl
	}

	// Signing time; the same value is sent as the "date" query parameter.
	date := time.Now().UTC().Format(time.RFC1123)

	// Sign the path exactly as it appears on the request line: the escaped
	// path, or "/" when the URL has none (mirrors url.URL.RequestURI).
	path := ul.EscapedPath()
	if path == "" {
		path = "/"
	}
	signString := strings.Join([]string{
		"host: " + ul.Host,
		"date: " + date,
		"GET " + path + " HTTP/1.1",
	}, "\n")
	signature := HmacWithShaTobase64("hmac-sha256", signString, apiSecret)

	// The authorization string is base64-encoded here; url.Values.Encode
	// performs the URL escaping below.
	authorization := fmt.Sprintf(`hmac username="%s", algorithm="%s", headers="%s", signature="%s"`,
		apiKey, "hmac-sha256", "host date request-line", signature)

	v := url.Values{}
	v.Add("host", ul.Host)
	v.Add("date", date)
	v.Add("authorization", base64.StdEncoding.EncodeToString([]byte(authorization)))

	// Append to an existing query string instead of starting a second one.
	sep := "?"
	if strings.Contains(hostUrl, "?") {
		sep = "&"
	}
	return hostUrl + sep + v.Encode()
}
// HmacWithShaTobase64 computes the HMAC of data keyed by key and returns the
// digest as padded standard base64.
//
// The algorithm argument is not consulted: SHA-256 is always used.
func HmacWithShaTobase64(algorithm, data, key string) string {
	signer := hmac.New(sha256.New, []byte(key))
	signer.Write([]byte(data)) // hash.Hash writes never return an error
	return base64.StdEncoding.EncodeToString(signer.Sum(nil))
}
// readResp renders resp as "code=<status>,body=<body>" for inclusion in error
// messages. It returns "" for a nil response.
//
// It is used while reporting a failed dial, so it never panics. A nil body
// renders as empty, and a read error is appended to whatever part of the body
// was read.
func readResp(resp *http.Response) string {
	if resp == nil {
		return ""
	}
	if resp.Body == nil {
		return fmt.Sprintf("code=%d,body=", resp.StatusCode)
	}
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Sprintf("code=%d,body=%s<read error: %v>", resp.StatusCode, string(b), err)
	}
	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}