Files
office-data-matcher/ai.go
sakuradairong 2b17760fbd Merge remote changes, split app.go, remove V1 dead code, fix AICache (#2)
- Merge remote improvements: generic AI API, row-level cache,
  CSV export, matchPrep, prompt truncation, O(1) cache index
- Split app.go (1645 -> 5 files: app.go, cache.go, ai.go,
  matcher.go, export.go)
- Remove V1 dead code: 6 methods, 4 helpers, ~300 lines
- Fix AICache 3 bugs: TOCTOU saveToFile, silent loadFromFile,
  full-sort put
- Extract 8 named constants (threshold, time window, batch size...)
- Frontend: isRunning guard, buildMatchConfig dedup, CSS variables
- Upgrade Go to 1.24.0
2026-06-05 14:46:55 +08:00

235 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"time"
)
// ---------- Deepseek API types ----------

// deepseekMessage is one chat message of a request. It is serialized to JSON
// with the keys "role" and "content".
type deepseekMessage struct {
// Role names the speaker; buildGenericAIPrompt emits "system" and "user".
Role string `json:"role"`
// Content is the message text.
Content string `json:"content"`
}
// deepseekRequest is the JSON body that callAIAPI POSTs to the chat
// completions endpoint.
type deepseekRequest struct {
// Model is the model identifier, sent as "model".
Model string `json:"model"`
// Messages is the conversation, sent as "messages".
Messages []deepseekMessage `json:"messages"`
// Temperature is sent as "temperature"; it is always serialized, even when zero.
Temperature float64 `json:"temperature"`
// MaxTokens is sent as "max_tokens" and left out of the JSON when it is zero.
MaxTokens int `json:"max_tokens,omitempty"`
}
// deepseekResponse is the part of the chat completions response that
// callAIAPI decodes; JSON fields not listed here are ignored by the decoder.
type deepseekResponse struct {
// Choices holds the returned completions; callAIAPI reads only the first one.
Choices []struct {
Message struct {
// Content is the generated text.
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
// Error is non-nil when the body carries an "error" object; callAIAPI
// turns its Message into a Go error.
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
// maxPromptBChars is the size budget applied while B-table rows are appended
// to the AI prompt. buildGenericAIPrompt compares it with
// strings.Builder.Len(), so the unit is bytes (not runes) and the measured
// size is the whole prompt built so far, not only the B-table section.
const maxPromptBChars = 80000
// ---------- AI API calls ----------

// buildRowCacheKey derives the cache key for matching a single row. The key
// is the hex-encoded SHA-256 of the match value, the time string, the regex
// pattern, the time window (formatted with one decimal place) and the base
// name of the B-table file, joined with "|".
func (a *App) buildRowCacheKey(matchValue, timeStr string, config MatchConfig) string {
	hasher := sha256.New()
	// Only the file name of the B table, not its directory, is part of the key.
	fmt.Fprintf(hasher, "%s|%s|%s|%.1f|%s",
		matchValue,
		timeStr,
		config.RegexPattern,
		config.TimeWindow,
		filepath.Base(config.FileBPath),
	)
	return hex.EncodeToString(hasher.Sum(nil))
}
// hashPrompt returns the hex-encoded SHA-256 digest of a prompt; callAIAPI
// uses it as the AI cache key. Every role and every content string is
// followed by a NUL byte, so the field boundaries take part in the digest.
func hashPrompt(messages []deepseekMessage) string {
	var payload []byte
	for i := range messages {
		payload = append(payload, messages[i].Role...)
		payload = append(payload, 0)
		payload = append(payload, messages[i].Content...)
		payload = append(payload, 0)
	}
	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:])
}
// resolveAIEndpoint turns the user-configured endpoint into the full chat
// completions URL:
//   - empty input selects the Deepseek default;
//   - a URL already ending in "/chat/completions" is used as is;
//   - a base URL ending in "/v1" only gets "/chat/completions" appended, so
//     the version segment is not duplicated;
//   - any other base URL gets "/v1/chat/completions" appended.
//
// Surrounding whitespace and trailing slashes are ignored.
func resolveAIEndpoint(raw string) string {
	endpoint := strings.TrimRight(strings.TrimSpace(raw), "/")
	switch {
	case endpoint == "":
		return "https://api.deepseek.com/v1/chat/completions"
	case strings.HasSuffix(endpoint, "/chat/completions"):
		return endpoint
	case strings.HasSuffix(endpoint, "/v1"):
		return endpoint + "/chat/completions"
	default:
		return endpoint + "/v1/chat/completions"
	}
}

// callAIAPI sends messages to an OpenAI-compatible chat completions API
// (Deepseek, OpenAI, a local model, ...) and returns the whitespace-trimmed
// content of the first choice.
//
// Results are cached under the SHA-256 of the prompt (see hashPrompt): a
// cache hit returns before any request is made, and a successful call stores
// its result and then calls saveToFile on the cache.
//
// An error is returned when no API key is configured, when the request cannot
// be encoded, built or sent, when the response carries an error object or a
// non-2xx status, when the body is not valid JSON, or when it has no choices.
func (a *App) callAIAPI(messages []deepseekMessage) (string, error) {
	if a.apiKey == "" {
		return "", fmt.Errorf("请先设置 AI API 密钥")
	}

	// Fall back to defaults for anything the user left unset.
	endpoint := resolveAIEndpoint(a.apiEndpoint)
	model := a.apiModel
	if model == "" {
		model = deepseekModel
	}

	hash := hashPrompt(messages)

	// Look in the cache first.
	if cached, ok := a.aiCache.get(hash); ok {
		fmt.Printf("[CACHE] ✓ 命中 AI 缓存 (hash=%s)\n", hash[:12])
		return cached, nil
	}
	fmt.Printf("[CACHE] ✗ 缓存未命中 (hash=%s),调用 %s...\n", hash[:12], endpoint)

	bodyBytes, err := json.Marshal(deepseekRequest{
		Model:       model,
		Messages:    messages,
		Temperature: deepseekTemperature,
		MaxTokens:   deepseekMaxTokens,
	})
	if err != nil {
		return "", fmt.Errorf("序列化请求失败: %w", err)
	}

	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(bodyBytes))
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+a.apiKey)

	// The timeout covers the whole exchange, including reading the body.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("调用 AI API 失败: %w", err)
	}
	defer resp.Body.Close()

	respBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("读取 AI 响应失败: %w", err)
	}

	var dr deepseekResponse
	parseErr := json.Unmarshal(respBytes, &dr)
	// Prefer the server's own error message whenever the body could be decoded.
	if parseErr == nil && dr.Error != nil {
		return "", fmt.Errorf("AI API 错误: %s", dr.Error.Message)
	}
	// client.Do does not treat non-2xx statuses as errors, so check explicitly;
	// this also gives a clear message when an error page is not JSON at all.
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return "", fmt.Errorf("AI API 返回异常状态: %s", resp.Status)
	}
	if parseErr != nil {
		return "", fmt.Errorf("解析 AI 响应失败: %w", parseErr)
	}
	if len(dr.Choices) == 0 {
		return "", fmt.Errorf("AI 未返回有效结果")
	}

	result := strings.TrimSpace(dr.Choices[0].Message.Content)

	// Store the answer and ask the cache to write itself to file.
	a.aiCache.put(hash, result)
	a.aiCache.saveToFile()
	return result, nil
}
// buildGenericAIPrompt builds the chat messages for a generic AI matching
// request and returns them as a system message followed by a user message.
//
// Parameters:
//   - unmatched: A-table rows that still need a match; each is listed with its
//     zero-based position in this slice.
//   - bRows: B-table candidate rows.
//   - config: supplies the column indexes read from both tables.
//   - windowDuration: allowed time difference, printed in whole hours; only
//     used when hasTime is true.
//   - hasTime: whether the time rule and the per-row time values are included.
//
// B rows are appended until the prompt built so far exceeds maxPromptBChars
// bytes; the remaining rows are then replaced by a single note that states
// how many were omitted. The A-table section is never truncated.
func (a *App) buildGenericAIPrompt(unmatched, bRows [][]string, config MatchConfig, windowDuration time.Duration, hasTime bool) []deepseekMessage {
	var sb strings.Builder

	sb.WriteString("你是一个数据匹配专家。请根据以下 A 表记录,从 B 表数据中找出最匹配的记录。\n\n")
	sb.WriteString("匹配规则:\n")
	sb.WriteString("1. 根据文本相似度匹配(注意中文字段的核心含义,忽略字母数字前缀后缀)\n")
	// Number the rules sequentially so the list has no gap when the time rule
	// is left out.
	ruleNo := 2
	if hasTime {
		fmt.Fprintf(&sb, "%d. 时间差应在 %.0f 小时内\n", ruleNo, windowDuration.Hours())
		ruleNo++
	}
	fmt.Fprintf(&sb, "%d. 返回匹配到的 B 表记录的目标列值(第 %d 列)\n\n", ruleNo, config.ColBExtractIndex+1)
	sb.WriteString("请严格按照以下 JSON 格式返回结果:\n")
	sb.WriteString(`{"matches":[{"index":0,"value":"匹配到的目标列值"},{"index":1,"value":""}]}` + "\n")
	sb.WriteString("如果某条无法匹配value 设为空字符串。\n\n")

	fmt.Fprintf(&sb, "A 表记录(需要匹配,共 %d 条):\n", len(unmatched))
	for i, row := range unmatched {
		fmt.Fprintf(&sb, "- 索引 %d: 「%s」", i, getCell(row, config.ColAMatchIndex))
		if hasTime {
			fmt.Fprintf(&sb, ", 时间=%s", getCell(row, config.ColATimeIndex))
		}
		sb.WriteString("\n")
	}

	fmt.Fprintf(&sb, "\nB 表参考数据(共 %d 条):\n", len(bRows))
	lastB := len(bRows) - 1
	for i, row := range bRows {
		fmt.Fprintf(&sb, " 「%s」 → 目标列值: 「%s」",
			getCell(row, config.ColBMatchIndex), getCell(row, config.ColBExtractIndex))
		if hasTime {
			fmt.Fprintf(&sb, ", 时间=%s", getCell(row, config.ColBTimeIndex))
		}
		sb.WriteString("\n")

		// Keep the prompt within the size budget so it does not exceed the
		// model's token limit. Once the last row has been written there is
		// nothing left to drop, so no truncation note is emitted for it.
		if sb.Len() > maxPromptBChars && i < lastB {
			fmt.Printf("[AI-WARN] Prompt B 表数据超长 (%d 条,%d 字符),截断于第 %d 条\n", len(bRows), sb.Len(), i)
			fmt.Fprintf(&sb, " ... 已截断,省略 %d 条\n", lastB-i)
			break
		}
	}

	sb.WriteString("\n请返回 JSON 格式的匹配结果。")

	return []deepseekMessage{
		{Role: "system", Content: "你是一个数据匹配专家。请严格按照 JSON 格式返回结果,不要添加额外说明。"},
		{Role: "user", Content: sb.String()},
	}
}
// formatTimeDiff renders a time difference as a compact string such as
// "1h30m5s", "2m10s" or "45s". The value is truncated to whole seconds and
// leading zero units are omitted. Negative durations get a leading "-",
// except when they truncate to zero seconds, which is printed as "0s".
//
// The components are computed with integer arithmetic so that large
// durations cannot be rounded up by a float conversion.
func formatTimeDiff(d time.Duration) string {
	abs := d
	if abs < 0 {
		abs = -abs
		if abs < 0 {
			// Negating the minimum Duration overflows; clamp to the maximum.
			abs = time.Duration(1<<63 - 1)
		}
	}

	total := int64(abs / time.Second)
	hours := total / 3600
	mins := total / 60 % 60
	secs := total % 60

	sign := ""
	if d < 0 && total > 0 {
		sign = "-"
	}

	switch {
	case hours > 0:
		return fmt.Sprintf("%s%dh%dm%ds", sign, hours, mins, secs)
	case mins > 0:
		return fmt.Sprintf("%s%dm%ds", sign, mins, secs)
	default:
		return fmt.Sprintf("%s%ds", sign, secs)
	}
}
// parseTimeDiffDuration converts a TimeDiff string such as "1h30m" or "-45s"
// back into a time.Duration so that values can be sorted. A single leading
// "-" is stripped before parsing and applied to the result afterwards. Empty
// or unparsable input yields 0.
func parseTimeDiffDuration(s string) time.Duration {
	negative := len(s) > 0 && s[0] == '-'
	if negative {
		s = s[1:]
	}

	parsed, err := time.ParseDuration(s)
	if err != nil {
		// Also covers the empty string, which ParseDuration rejects.
		return 0
	}
	if negative {
		return -parsed
	}
	return parsed
}