- 修复 MaxPreview=0 仍被覆盖为默认值的 bug - 修复 API Endpoint 自动补全逻辑(避免 /v1/v1/chat/completions) - 为 AI 配置与匹配状态字段增加并发锁 - AI 增强未匹配行改为按索引跟踪,避免重复行误判 - 无时间列时 AI 匹配 B 表行数可配置并增加截断警告 - 导出时防御参差不齐行导致的数组越界 panic - Excel 读取时对单元格统一 TrimSpace - 删除未使用的 minInt 函数 - 修复 wails.json 开发服务器地址为 http://localhost:5173 - 重新生成 Wails 前端绑定 - 新增 ai_test.go / export_test.go 单元测试
249 lines
7.2 KiB
Go
249 lines
7.2 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// ---------- Deepseek API 类型 ----------
|
||
|
||
// deepseekMessage is one chat message as sent to an OpenAI-compatible
// chat completions API.
type deepseekMessage struct {
	// Role is serialized as "role" (this file sends "system" and "user").
	Role string `json:"role"`
	// Content is the message text, serialized as "content".
	Content string `json:"content"`
}
|
||
|
||
type deepseekRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []deepseekMessage `json:"messages"`
|
||
Temperature float64 `json:"temperature"`
|
||
MaxTokens int `json:"max_tokens,omitempty"`
|
||
}
|
||
|
||
// deepseekResponse holds the parts of a chat completions response that this
// file reads: the generated choices and an optional error object.
type deepseekResponse struct {
	// Choices lists the returned completions; callAIAPI reads the first one.
	Choices []struct {
		Message struct {
			// Content is the generated text.
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`

	// Error stays nil unless the response body contains an "error" object.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
|
||
|
||
// maxPromptBChars is the size threshold at which buildGenericAIPrompt stops
// appending B-table rows. It is compared against the byte length of the whole
// prompt built so far, not only the B-table section.
const maxPromptBChars = 80_000
|
||
|
||
// ---------- AI API 调用 ----------
|
||
|
||
// buildRowCacheKey 为单行匹配构建缓存键
|
||
func (a *App) buildRowCacheKey(matchValue, timeStr string, config MatchConfig) string {
|
||
parts := fmt.Sprintf("%s|%s|%s|%.1f|%s",
|
||
matchValue, timeStr, config.RegexPattern, config.TimeWindow,
|
||
filepath.Base(config.FileBPath))
|
||
h := sha256.Sum256([]byte(parts))
|
||
return hex.EncodeToString(h[:])
|
||
}
|
||
|
||
// hashPrompt 对 prompt 消息计算 SHA256(用于缓存键)
|
||
func hashPrompt(messages []deepseekMessage) string {
|
||
h := sha256.New()
|
||
for _, m := range messages {
|
||
h.Write([]byte(m.Role))
|
||
h.Write([]byte{0})
|
||
h.Write([]byte(m.Content))
|
||
h.Write([]byte{0})
|
||
}
|
||
return hex.EncodeToString(h.Sum(nil))
|
||
}
|
||
|
||
// resolveAIEndpoint expands a user-supplied endpoint into a full
// OpenAI-compatible chat completions URL.
//
// Surrounding whitespace and trailing slashes are ignored. The result is:
//   - the Deepseek default URL when the input is empty or blank;
//   - the input itself when it already ends in "/chat/completions";
//   - the input plus "/chat/completions" when it ends in "/v1";
//   - the input plus "/v1/chat/completions" otherwise.
func resolveAIEndpoint(apiEndpoint string) string {
	// Trim whitespace first: a pasted " https://host/v1/ " must not end up with
	// spaces embedded in the URL or a duplicated "/v1" segment.
	endpoint := strings.TrimRight(strings.TrimSpace(apiEndpoint), "/")
	switch {
	case endpoint == "":
		return "https://api.deepseek.com/v1/chat/completions"
	case strings.HasSuffix(endpoint, "/chat/completions"):
		return endpoint
	case strings.HasSuffix(endpoint, "/v1"):
		return endpoint + "/chat/completions"
	default:
		return endpoint + "/v1/chat/completions"
	}
}
|
||
|
||
// callAIAPI 调用 OpenAI 兼容 API(Deepseek / OpenAI / 本地模型 等)
|
||
func (a *App) callAIAPI(messages []deepseekMessage) (string, error) {
|
||
a.aiMu.RLock()
|
||
apiKey := a.apiKey
|
||
apiEndpoint := a.apiEndpoint
|
||
apiModel := a.apiModel
|
||
a.aiMu.RUnlock()
|
||
|
||
if apiKey == "" {
|
||
return "", fmt.Errorf("请先设置 AI API 密钥")
|
||
}
|
||
|
||
endpoint := resolveAIEndpoint(apiEndpoint)
|
||
model := apiModel
|
||
if model == "" {
|
||
model = deepseekModel
|
||
}
|
||
|
||
hash := hashPrompt(messages)
|
||
|
||
// 先查缓存
|
||
if cached, ok := a.aiCache.get(hash); ok {
|
||
fmt.Printf("[CACHE] ✓ 命中 AI 缓存 (hash=%s)\n", hash[:12])
|
||
return cached, nil
|
||
}
|
||
fmt.Printf("[CACHE] ✗ 缓存未命中 (hash=%s),调用 %s...\n", hash[:12], endpoint)
|
||
|
||
reqBody := deepseekRequest{
|
||
Model: model,
|
||
Messages: messages,
|
||
Temperature: deepseekTemperature,
|
||
MaxTokens: deepseekMaxTokens,
|
||
}
|
||
|
||
bodyBytes, _ := json.Marshal(reqBody)
|
||
httpReq, err := http.NewRequest("POST", endpoint, bytes.NewReader(bodyBytes))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
client := &http.Client{Timeout: 30 * time.Second}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return "", fmt.Errorf("调用 AI API 失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBytes, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取 AI 响应失败: %v", err)
|
||
}
|
||
var dr deepseekResponse
|
||
if err := json.Unmarshal(respBytes, &dr); err != nil {
|
||
return "", fmt.Errorf("解析 AI 响应失败: %v", err)
|
||
}
|
||
|
||
if dr.Error != nil {
|
||
return "", fmt.Errorf("AI API 错误: %s", dr.Error.Message)
|
||
}
|
||
|
||
if len(dr.Choices) == 0 {
|
||
return "", fmt.Errorf("AI 未返回有效结果")
|
||
}
|
||
|
||
result := strings.TrimSpace(dr.Choices[0].Message.Content)
|
||
|
||
// 写入缓存并持久化
|
||
a.aiCache.put(hash, result)
|
||
a.aiCache.saveToFile()
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// buildGenericAIPrompt 构建通用 AI 匹配提示词
|
||
func (a *App) buildGenericAIPrompt(unmatched, bRows [][]string, config MatchConfig, windowDuration time.Duration, hasTime bool) []deepseekMessage {
|
||
var sb strings.Builder
|
||
sb.WriteString("你是一个数据匹配专家。请根据以下 A 表记录,从 B 表数据中找出最匹配的记录。\n\n")
|
||
sb.WriteString("匹配规则:\n")
|
||
sb.WriteString("1. 根据文本相似度匹配(注意中文字段的核心含义,忽略字母数字前缀后缀)\n")
|
||
if hasTime {
|
||
sb.WriteString(fmt.Sprintf("2. 时间差应在 %.0f 小时内\n", windowDuration.Hours()))
|
||
}
|
||
sb.WriteString(fmt.Sprintf("3. 返回匹配到的 B 表记录的目标列值(第 %d 列)\n\n", config.ColBExtractIndex+1))
|
||
|
||
sb.WriteString("请严格按照以下 JSON 格式返回结果:\n")
|
||
sb.WriteString(`{"matches":[{"index":0,"value":"匹配到的目标列值"},{"index":1,"value":""}]}` + "\n")
|
||
sb.WriteString("如果某条无法匹配,value 设为空字符串。\n\n")
|
||
|
||
sb.WriteString(fmt.Sprintf("A 表记录(需要匹配,共 %d 条):\n", len(unmatched)))
|
||
for i, row := range unmatched {
|
||
matchVal := getCell(row, config.ColAMatchIndex)
|
||
sb.WriteString(fmt.Sprintf("- 索引 %d: 「%s」", i, matchVal))
|
||
if hasTime {
|
||
sb.WriteString(fmt.Sprintf(", 时间=%s", getCell(row, config.ColATimeIndex)))
|
||
}
|
||
sb.WriteString("\n")
|
||
}
|
||
|
||
sb.WriteString(fmt.Sprintf("\nB 表参考数据(共 %d 条):\n", len(bRows)))
|
||
truncated := false
|
||
for i, row := range bRows {
|
||
matchVal := getCell(row, config.ColBMatchIndex)
|
||
extractVal := getCell(row, config.ColBExtractIndex)
|
||
sb.WriteString(fmt.Sprintf(" 「%s」 → 目标列值: 「%s」", matchVal, extractVal))
|
||
if hasTime {
|
||
sb.WriteString(fmt.Sprintf(", 时间=%s", getCell(row, config.ColBTimeIndex)))
|
||
}
|
||
sb.WriteString("\n")
|
||
// 限制 B 表部分总字符数,防止 prompt 超出 token 限制
|
||
if sb.Len() > maxPromptBChars {
|
||
fmt.Printf("[AI-WARN] Prompt B 表数据超长 (%d 条,%d 字符),截断于第 %d 条\n", len(bRows), sb.Len(), i)
|
||
truncated = true
|
||
}
|
||
if truncated {
|
||
sb.WriteString(fmt.Sprintf(" ... 已截断,省略 %d 条\n", len(bRows)-i-1))
|
||
break
|
||
}
|
||
}
|
||
|
||
sb.WriteString("\n请返回 JSON 格式的匹配结果。")
|
||
|
||
return []deepseekMessage{
|
||
{Role: "system", Content: "你是一个数据匹配专家。请严格按照 JSON 格式返回结果,不要添加额外说明。"},
|
||
{Role: "user", Content: sb.String()},
|
||
}
|
||
}
|
||
|
||
// formatTimeDiff 格式化时间差为可读字符串
|
||
func formatTimeDiff(d time.Duration) string {
|
||
abs := d
|
||
if abs < 0 {
|
||
abs = -abs
|
||
}
|
||
hours := int(abs.Hours())
|
||
mins := int(abs.Minutes()) % 60
|
||
secs := int(abs.Seconds()) % 60
|
||
|
||
sign := ""
|
||
if d < 0 {
|
||
sign = "-"
|
||
}
|
||
if hours > 0 {
|
||
return fmt.Sprintf("%s%dh%dm%ds", sign, hours, mins, secs)
|
||
} else if mins > 0 {
|
||
return fmt.Sprintf("%s%dm%ds", sign, mins, secs)
|
||
}
|
||
return fmt.Sprintf("%s%ds", sign, secs)
|
||
}
|
||
|
||
// parseTimeDiffDuration converts a TimeDiff string such as "1h30m" or "-45s"
// back into a time.Duration (used for sorting).
//
// One leading "-" is stripped and applied as the sign; the rest is handed to
// time.ParseDuration. Empty or unparsable input yields 0.
func parseTimeDiffDuration(s string) time.Duration {
	negative := strings.HasPrefix(s, "-")
	if negative {
		s = s[1:]
	}

	parsed, err := time.ParseDuration(s)
	if err != nil {
		// Also covers "" and a lone "-".
		return 0
	}
	if negative {
		return -parsed
	}
	return parsed
}
|