saw-go/service/spark.go

323 lines
8.6 KiB
Go
Raw Normal View History

2025-03-25 14:44:58 +08:00
package service
2025-03-23 16:03:17 +08:00
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
2025-03-25 14:44:58 +08:00
"StuAcaWorksAI/worker"
2025-03-23 16:03:17 +08:00
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
"github.com/ollama/ollama/api"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
)
/***************** Spark请求结构 ***********************/
// SparkRequestHeader is the "header" object of a Spark chat request.
type SparkRequestHeader struct {
	// AppID is serialized as "app_id".
	AppID string `json:"app_id"`
	// UID is serialized as "uid" and left out of the JSON when empty.
	UID string `json:"uid,omitempty"`
}
// ChatParameter holds the "chat" settings carried in a Spark request's
// "parameter" object.
type ChatParameter struct {
// Domain is serialized as "domain"; genSparkParams copies it from its domain argument.
Domain string `json:"domain"`
// Temperature is serialized as "temperature"; genSparkParams sets it to 0.8.
Temperature float64 `json:"temperature"`
// MaxTokens is serialized as "max_tokens"; genSparkParams sets it to 2048.
MaxTokens int `json:"max_tokens"`
}
// Parameter is the "parameter" object of a Spark request; it only wraps the
// chat settings.
type Parameter struct {
// Chat is serialized as "chat".
Chat ChatParameter `json:"chat"`
}
// MessageContent is one chat message sent to Spark inside
// "payload.message.text".
type MessageContent struct {
// Role is serialized as "role". GetSparkSessionHistoryMsg fills it with
// "system", proto.SparkRoleUser or proto.SparkRoleAssistant.
Role string `json:"role"`
// Content is serialized as "content" and carries the message text.
Content string `json:"content"`
}
// SparkMessage is the "message" object of a Spark request payload.
type SparkMessage struct {
// Text is the ordered list of chat messages, serialized as "text".
Text []MessageContent `json:"text"`
}
// SparkRequestPayload is the "payload" object of a Spark request.
type SparkRequestPayload struct {
// Message is serialized as "message".
Message SparkMessage `json:"message"`
}
// SparkRequest is the complete JSON document written to the Spark WebSocket
// for one chat turn; genSparkParams builds it.
type SparkRequest struct {
// Header is serialized as "header".
Header SparkRequestHeader `json:"header"`
// Parameter is serialized as "parameter".
Parameter Parameter `json:"parameter"`
// Payload is serialized as "payload".
Payload SparkRequestPayload `json:"payload"`
}
/***************** Spark请求结构 ***********************/
/***************** Spark响应结构 ***********************/
// SparkResponseHeader is the "header" object of a Spark response frame.
type SparkResponseHeader struct {
// Code is decoded from "code"; Spark treats any non-zero value as an error frame.
Code int `json:"code"`
// Message is decoded from "message"; Spark logs it when Code is non-zero.
Message string `json:"message"`
// SID is decoded from "sid". It is not read by the code in this file.
SID string `json:"sid"`
Status int `json:"status"` // session status, one of [0,1,2]: 0 = first result, 1 = intermediate result, 2 = last result
}
// ChoiceText is one element of the "payload.choices.text" array in a Spark
// response frame.
type ChoiceText struct {
// Content is decoded from "content"; Spark concatenates it into the answer.
Content string `json:"content"`
// Role is decoded from "role".
Role string `json:"role"`
// Index is decoded from "index".
Index int `json:"index"`
}
// Choices maps the "choices" object of a Spark response payload.
type Choices struct {
// Status is decoded from "status". The code in this file reads the header
// status instead of this field.
Status int `json:"status"`
// Seq is decoded from "seq" — presumably the frame sequence number; verify
// against the Spark API documentation.
Seq int `json:"seq"`
// Text is decoded from "text"; callers in this file use only its first element.
Text []ChoiceText `json:"text"`
}
// UsageText maps the "usage.text" object of a Spark response payload.
// NOTE(review): the fields look like token counters reported by the service;
// their exact semantics are not established by this file.
type UsageText struct {
// QuestionTokens is decoded from "question_tokens".
QuestionTokens int `json:"question_tokens"`
// PromptTokens is decoded from "prompt_tokens".
PromptTokens int `json:"prompt_tokens"`
// CompletionTokens is decoded from "completion_tokens".
CompletionTokens int `json:"completion_tokens"`
// TotalTokens is decoded from "total_tokens".
TotalTokens int `json:"total_tokens"`
}
// Usage maps the "usage" object of a Spark response payload.
type Usage struct {
// Text is decoded from "text".
Text UsageText `json:"text"`
}
// SparkResponsePayload maps the "payload" object of a Spark response frame.
type SparkResponsePayload struct {
// Choices is decoded from "choices" and carries the generated text.
Choices Choices `json:"choices"`
// Usage is decoded from "usage".
Usage Usage `json:"usage"`
}
// SparkResponse maps one complete JSON frame received from the Spark
// WebSocket.
type SparkResponse struct {
// Header is decoded from "header".
Header SparkResponseHeader `json:"header"`
// Payload is decoded from "payload".
Payload SparkResponsePayload `json:"payload"`
}
/***************** Spark响应结构 ***********************/
2025-03-25 15:51:05 +08:00
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID, modelID int) {
2025-03-24 15:22:43 +08:00
log.Println("spark model param:", modelParam)
log.Println("spark question:", question)
2025-03-23 16:03:17 +08:00
//建立websocket连接
dialer := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
//建立连接
conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil)
2025-03-23 16:03:17 +08:00
if err != nil {
panic(readResp(resp) + err.Error())
return
} else if resp.StatusCode != 101 {
panic(readResp(resp) + err.Error())
}
//发送消息
go func() {
data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System)
2025-03-23 16:03:17 +08:00
//将数据转换为json
//dataByte, err := json.Marshal(data)
if err != nil {
fmt.Println("Error parsing JSON:", err)
return
}
err2 := conn.WriteJSON(data)
if err != nil {
fmt.Println("write message error:", err2)
return
}
}()
answer := ""
//接收消息
for {
_, msg, err2 := conn.ReadMessage()
if err2 != nil {
log.Println("read message error:", err2)
break
}
var data SparkResponse
err = json.Unmarshal(msg, &data)
if err != nil {
log.Println("Error parsing JSON:", err)
return
}
if data.Header.Code != 0 {
2025-03-24 15:42:32 +08:00
log.Println("Error response:", data.Header.Message, string(msg))
2025-03-23 16:03:17 +08:00
break
}
if data.Header.Status == 2 {
//最后一个结果
answer += data.Payload.Choices.Text[0].Content
} else {
answer += data.Payload.Choices.Text[0].Content
}
go CreateAIStreamMsg(UserID, modelID, SessionID, data, modelParam.Model)
2025-03-23 16:03:17 +08:00
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
}
conn.Close()
//将消息存入数据库
err2, msgID := dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer)
if err2 != nil {
log.Println("spark create message error:", err2)
} else {
log.Println("spark create message success,message id:", msgID)
2025-03-23 16:03:17 +08:00
}
}
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) {
2025-03-23 16:03:17 +08:00
var aiMsg proto.AIQueueMessage
var wsMsg proto.WSMessage
var apiMsg api.GenerateResponse
2025-03-26 20:28:22 +08:00
apiMsg.Model = proto.ModelTypeSpark
2025-03-23 16:03:17 +08:00
apiMsg.Response = data.Payload.Choices.Text[0].Content
if data.Header.Status == 2 {
apiMsg.Done = true
} else {
apiMsg.Done = false
}
//消息队列部分
2025-03-26 20:28:22 +08:00
aiMsg.Type = proto.ModelTypeSpark
2025-03-23 16:03:17 +08:00
aiMsg.Msg = apiMsg
//ws发送消息部分
wsMsg.Msg = aiMsg
wsMsg.SessionID = SessionID
wsMsg.ToID = UserID
2025-03-26 20:28:22 +08:00
wsMsg.Type = proto.ModelTypeSpark
2025-03-23 16:03:17 +08:00
//发送消息
wsMsgStr, _ := json.Marshal(wsMsg)
2025-03-25 14:44:58 +08:00
worker.Publish(channel, string(wsMsgStr), time.Second*60)
2025-03-23 16:03:17 +08:00
}
// 生成参数
2025-03-26 14:07:27 +08:00
func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest {
2025-03-26 14:26:09 +08:00
//messages := []MessageContent{
// {Role: "user", Content: question},
//}
2025-03-26 14:07:27 +08:00
messages2 := make([]MessageContent, 0)
err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
if err != nil {
log.Println("get spark session history message error:", err)
}
2025-03-26 14:26:09 +08:00
//messages2Str, _ := json.Marshal(messages2)
//log.Println("spark messages2:", string(messages2Str))
2025-03-23 16:03:17 +08:00
data := SparkRequest{
Header: SparkRequestHeader{
AppID: appid,
},
Parameter: Parameter{
Chat: ChatParameter{
2025-03-24 19:57:28 +08:00
Domain: domain,
2025-03-23 16:03:17 +08:00
Temperature: 0.8,
MaxTokens: 2048,
},
},
Payload: SparkRequestPayload{
Message: SparkMessage{
2025-03-26 14:26:09 +08:00
Text: messages2,
2025-03-23 16:03:17 +08:00
},
},
}
return data
}
// assembleAuthUrl returns hostUrl extended with the signed query parameters
// (host, date, authorization) used to authenticate the WebSocket handshake.
//
// The signature is the HMAC-SHA256 (keyed with apiSecret) over the lines
// "host: <host>", "date: <date>" and "GET <path> HTTP/1.1". apiKey is sent as
// the hmac username. The authorization value is base64 encoded and then URL
// encoded together with the other parameters.
//
// If hostUrl cannot be parsed, the error is logged and "" is returned so that
// the following dial fails cleanly.
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
	ul, err := url.Parse(hostUrl)
	if err != nil {
		log.Println("spark: invalid host url:", err)
		return ""
	}

	// Signing time.
	date := time.Now().UTC().Format(time.RFC1123)

	// Fields covered by the signature: host, date, request-line.
	signLines := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
	signature := HmacWithShaTobase64("hmac-sha256", strings.Join(signLines, "\n"), apiSecret)

	// The authorization string is base64 encoded here; URL encoding happens below.
	authOrigin := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
		"hmac-sha256", "host date request-line", signature)
	authorization := base64.StdEncoding.EncodeToString([]byte(authOrigin))

	v := url.Values{}
	v.Add("host", ul.Host)
	v.Add("date", date)
	v.Add("authorization", authorization)

	// Extend an existing query string instead of starting a second one.
	sep := "?"
	if strings.Contains(hostUrl, "?") {
		sep = "&"
	}
	return hostUrl + sep + v.Encode()
}
// HmacWithShaTobase64 returns the standard base64 encoding of the HMAC-SHA256
// of data keyed with key. The algorithm argument is accepted but not
// consulted: SHA-256 is always used.
func HmacWithShaTobase64(algorithm, data, key string) string {
	signer := hmac.New(sha256.New, []byte(key))
	signer.Write([]byte(data))
	digest := signer.Sum(make([]byte, 0, sha256.Size))
	return base64.StdEncoding.EncodeToString(digest)
}
// readResp renders an HTTP response as "code=<status>,body=<body>" for use in
// log messages and closes the body afterwards.
//
// It returns "" for a nil response. A response without a body yields an empty
// body part. If reading the body fails, the bytes read so far are kept and the
// read error is appended instead of panicking, since this helper runs while
// another failure is already being reported.
func readResp(resp *http.Response) string {
	if resp == nil {
		return ""
	}
	if resp.Body == nil {
		return fmt.Sprintf("code=%d,body=", resp.StatusCode)
	}
	defer resp.Body.Close()

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Sprintf("code=%d,body=%s,read error=%v", resp.StatusCode, string(b), err)
	}
	return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
}
2025-03-26 14:07:27 +08:00
// 获取spark模型会话的历史消息
func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]MessageContent, prompt string) error {
//获取会话的历史消息
msgs := dao.FindMessageBySessionIDV2(sessionID)
if len(msgs) > proto.SparkContextLength {
//只取最后的 SparkContextLength 条消息
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
}
2025-03-26 14:07:27 +08:00
//添加系统消息
if systemPrompt != "" {
var message MessageContent
message.Role = "system"
message.Content = systemPrompt
*messages = append(*messages, message)
}
for _, msg := range msgs {
var message MessageContent
if msg.Type == proto.UserToModelMsgType {
2025-03-26 14:26:09 +08:00
message.Role = proto.SparkRoleUser
} else if msg.Type == proto.ModelToUserMsgType {
message.Role = proto.SparkRoleAssistant
2025-03-26 14:07:27 +08:00
} else {
2025-03-26 14:26:09 +08:00
continue
2025-03-26 14:07:27 +08:00
}
2025-03-26 14:26:09 +08:00
message.Content = msg.Msg
2025-03-26 14:07:27 +08:00
*messages = append(*messages, message)
}
2025-03-26 14:26:09 +08:00
//添加本次请求消息(上面获取到的已包含本次请求消息)
//var message MessageContent
//message.Role = "user"
//message.Content = prompt
//*messages = append(*messages, message)
2025-03-26 14:07:27 +08:00
return nil
}