Files
rime-context-filter/rime_context_filter.lua
sakuradairong c67747c231
Some checks failed
CI / luacheck (push) Has been cancelled
CI / Test (Lua 5.3) (push) Has been cancelled
CI / Test (LuaJIT) (push) Has been cancelled
CI / Test (Lua 5.1) (push) Has been cancelled
fix: Lua 5.1 兼容性 + CI 稳定性
- 修复 rime_context_filter.lua: load() 在 Lua 5.1 上用 loadstring
- 修复 test: load(ser) 在 Lua 5.1 上改用 loadstring
- 修复 CI: luarocks install 添加 continue-on-error
2026-06-10 16:36:40 +08:00

377 lines
11 KiB
Lua
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
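-- rime_context_filter.lua
-- Pure-learning context-aware candidate re-ranking engine (v5: cross-platform + decay + sandboxing)
--
-- Learning: automatically records the co-occurrence "previously committed text -> word chosen now"
-- Matching: boosts candidates according to the current context; the more relevant, the higher
-- Persistence: data is stored as a Lua table literal in
--   context_learned.data (the path depends on the platform)
--   and is loaded natively by the Lua VM, so no line-by-line regex parsing is needed
--
-- Configuration (optional; goes under `patch:` in rime_ice.custom.yaml):
--   context_filter:
--     save_interval: 30    # flush to disk every N commits (default 30)
--     data_path: /custom/path/context_learned.data   # custom data path, highest priority
--     decay_enabled: true  # whether decay is enabled (default true)
--     decay_rate: 0.95     # decay factor applied on every save (default 0.95)
--
-- Activation:
--   "engine/filters/@after 6":
--     lua_filter@*rime_context_filter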
----------------------------------------------------------------------
-- 跨平台路径解析
----------------------------------------------------------------------
--- Resolve the path of the learned-data file for the current platform.
--
-- Resolution order:
--   1. `<name_space>/data_path` from the schema config, when non-empty;
--   2. Windows (path separator is "\"): %APPDATA%\Rime\context_learned.data;
--   3. macOS (~/Library/Rime exists, or /usr/bin/sw_vers exists):
--      ~/Library/Rime/context_learned.data;
--   4. otherwise $XDG_DATA_HOME/rime/ when set, else ~/.local/share/rime/.
--
-- @param env filter environment; reads env.engine.schema.config and env.name_space
-- @return string path of the data file
local function get_data_path(env)
  local config = env.engine.schema.config
  local custom = config:get_string(env.name_space .. "/data_path")
  if custom and #custom > 0 then return custom end

  local sep = package.config:sub(1, 1)
  if sep == "\\" then
    return (os.getenv("APPDATA") or "") .. sep .. "Rime" .. sep .. "context_learned.data"
  end

  -- os.execute reports success as 0 on Lua 5.1 / LuaJIT but as `true` on
  -- Lua 5.2+; comparing against 0 alone never matches on the newer versions.
  local function shell_ok(command)
    local status = os.execute(command)
    return status == true or status == 0
  end

  -- Single-quote a value for POSIX sh so its content cannot alter the command.
  local function sh_quote(value)
    local escaped = value:gsub("'", "'\\''")
    return "'" .. escaped .. "'"
  end

  local home = os.getenv("HOME") or ""
  local mac_dir = home .. "/Library/Rime"
  -- On first deployment ~/Library/Rime may not exist yet, so a macOS-only
  -- binary is used as the second indicator.
  if shell_ok("test -d " .. sh_quote(mac_dir) .. " 2>/dev/null")
      or shell_ok("test -f /usr/bin/sw_vers 2>/dev/null") then
    return mac_dir .. "/context_learned.data"
  end

  -- Linux: prefer XDG_DATA_HOME, then fall back to ~/.local/share.
  local xdg = os.getenv("XDG_DATA_HOME")
  if xdg and #xdg > 0 then
    return xdg .. "/rime/context_learned.data"
  end
  return home .. "/.local/share/rime/context_learned.data"
end
----------------------------------------------------------------------
-- 持久化 — Lua 表字面量格式
----------------------------------------------------------------------
--- Quote a string as a double-quoted Lua string literal.
-- Backslash, double quote, LF and CR are backslash-escaped; every other byte
-- (multi-byte UTF-8 sequences included) is copied through unchanged.
-- @param s string to quote
-- @return string the literal, including the surrounding double quotes
local function esc(s)
  local replacements = {
    ["\\"] = "\\\\",
    ['"'] = '\\"',
    ["\n"] = "\\n",
    ["\r"] = "\\r",
  }
  -- One pass is enough: none of the replacements produce a character that a
  -- later substitution would have to touch again.
  local body = s:gsub('[\\"\n\r]', replacements)
  return '"' .. body .. '"'
end
--- Render the in-memory data as Lua source code.
-- Output shape:
--   return {
--    ["context"]={["word"]=5,["word2"]=3},
--    ["context2"]={["word"]=2},
--   }
-- Entries whose count is not greater than 1 are pruned while writing, and a
-- context that ends up with no entry is left out entirely.
-- @param data table mapping context -> { word -> count }
-- @return string text of a Lua chunk
local function serialize(data)
  local out = { "return {\n" }
  for ctx, words in pairs(data) do
    local head = " [" .. esc(ctx) .. "]={"
    local items = {}
    for word, count in pairs(words) do
      if count > 1 then
        items[#items + 1] = "[" .. esc(word) .. "]=" .. count
      end
    end
    if #items > 0 then
      out[#out + 1] = head .. table.concat(items, ",") .. "},\n"
    end
  end
  out[#out + 1] = "}\n"
  return table.concat(out)
end
--- Load the learned table by running the data file as a Lua chunk.
-- The chunk runs with an empty table as its environment, so it cannot see
-- the real globals, and precompiled (binary) chunks are refused. Entries
-- that do not have the shape string -> { string -> number } are discarded so
-- later arithmetic on the counts cannot fail on a hand-edited file.
-- NOTE(review): the sandbox hides globals only; it does not bound the time
-- or memory the chunk may consume.
-- @param data_file string path of the data file
-- @return table learned data (empty when the file is missing or invalid)
-- @return number how many entries have a count greater than 1
local function load(data_file)
  local f = io.open(data_file, "r")
  if not f then return {}, 0 end
  local content = f:read("*a")
  f:close()
  if not content or #content == 0 then return {}, 0 end
  -- Precompiled chunks start with ESC; never execute bytecode from disk.
  if content:sub(1, 1) == "\27" then return {}, 0 end

  local safe_env = {}
  local chunk_name = "@" .. data_file
  local loader
  local set_env = rawget(_G, "setfenv")
  local load_string = rawget(_G, "loadstring")
  if set_env and load_string then
    -- Lua 5.1 / LuaJIT: compile the string, then swap the environment.
    loader = load_string(content, chunk_name)
    if loader then set_env(loader, safe_env) end
  else
    -- Lua 5.2+. This function shadows the builtin name `load`, so the
    -- builtin must be fetched from _G; a bare `load(...)` here would recurse.
    local native_load = rawget(_G, "load")
    local ok, compiled = pcall(native_load, content, chunk_name, "t", safe_env)
    if ok then loader = compiled end
  end
  if not loader then return {}, 0 end

  local ok, raw = pcall(loader)
  if not ok or type(raw) ~= "table" then return {}, 0 end

  -- Keep well-formed entries only and count those above the pruning threshold.
  local data, entries = {}, 0
  for ctx, words in pairs(raw) do
    if type(ctx) == "string" and type(words) == "table" then
      local clean
      for word, count in pairs(words) do
        if type(word) == "string" and type(count) == "number" then
          clean = clean or {}
          clean[word] = count
          if count > 1 then entries = entries + 1 end
        end
      end
      if clean then data[ctx] = clean end
    end
  end
  return data, entries
end
--- Rewrite the whole data file via a temporary file.
-- The new content is written to `<data_file>.tmp` and then renamed over the
-- target. The old file is only moved out of the way when a plain rename is
-- refused (platforms where rename does not overwrite), and it is restored if
-- the replacement still fails, so a failed save never destroys existing data.
-- @param data table mapping context -> { word -> count }
-- @param data_file string path of the data file
-- @return boolean true when the new content is in place, false otherwise
local function save(data, data_file)
  -- Serialize first so an error in the data never leaves a truncated tmp file.
  local text = serialize(data)
  local tmp = data_file .. ".tmp"
  local f = io.open(tmp, "w")
  if not f then return false end
  local written = f:write(text)
  local closed = f:close()
  if not written or not closed then
    os.remove(tmp)
    return false
  end

  -- Where rename replaces an existing target, this is the only step needed.
  if os.rename(tmp, data_file) then return true end

  -- Rename refused to overwrite: park the old file, retry, restore on failure.
  local backup = data_file .. ".bak"
  os.remove(backup)
  if os.rename(data_file, backup) then
    if os.rename(tmp, data_file) then
      os.remove(backup)
      return true
    end
    os.rename(backup, data_file)
  end
  os.remove(tmp)
  return false
end
----------------------------------------------------------------------
-- 跨平台目录确保
----------------------------------------------------------------------
--- Create the parent directory of `path` (and missing ancestors) via the shell.
-- Best effort: the outcome of the shell command is not reported.
-- @param path string file path whose directory should exist
local function ensure_dir(path)
  local is_win = package.config:sub(1, 1) == "\\"
  if is_win then
    -- Accept both separators; configured paths are often written with "/".
    local dir = path:match("^(.+)[\\/][^\\/]+$")
    -- A double quote cannot be part of a Windows path and would break the
    -- quoting of the command below, so such input is skipped.
    if dir and not dir:find('"', 1, true) then
      dir = dir:gsub("/", "\\")
      os.execute('if not exist "' .. dir .. '" mkdir "' .. dir .. '"')
    end
  else
    local dir = path:match("^(.+)/[^/]+$")
    if dir then
      dir = dir:gsub("\\", "/")
      -- Close the quote, emit an escaped quote, reopen: keeps a "'" in the
      -- path from terminating the argument.
      local escaped = dir:gsub("'", "'\\''")
      os.execute("mkdir -p -- '" .. escaped .. "' 2>/dev/null")
    end
  end
end
--- Make sure the data file (and its directory) exists.
-- Opening in append mode creates a missing file without touching existing
-- content. Only when that fails is the directory created and a fresh file
-- holding an empty table written.
-- @param data_file string path of the data file
local function ensure_file(data_file)
  local handle = io.open(data_file, "a")
  if handle then
    handle:close()
    return
  end
  ensure_dir(data_file)
  handle = io.open(data_file, "w")
  if not handle then return end
  handle:write("return {}\n")
  handle:close()
end
----------------------------------------------------------------------
-- 衰减 / 遗忘机制
----------------------------------------------------------------------
--- Apply one round of decay to every learned count, in place.
-- Each count is multiplied by `rate`; a result below 1.1 removes the word,
-- and a context left without words is removed as well.
-- @param learned table mapping context -> { word -> count }; modified in place
-- @param rate number multiplicative decay factor
local function decay_learned(learned, rate)
  for context, bucket in pairs(learned) do
    for term, weight in pairs(bucket) do
      local decayed = weight * rate
      if decayed < 1.1 then
        decayed = nil -- assigning nil below deletes the entry
      end
      bucket[term] = decayed
    end
    -- Clearing fields of the table being traversed is permitted by pairs.
    if next(bucket) == nil then
      learned[context] = nil
    end
  end
end
----------------------------------------------------------------------
-- UTF-8 安全截取(按字符边界,非字节)
----------------------------------------------------------------------
--- Return the last `n` characters of a UTF-8 string without splitting one.
-- Walks backwards from the end; after each step it keeps moving left while
-- the byte under the cursor is a continuation byte (0x80-0xBF).
-- @param s string UTF-8 text
-- @param n number how many trailing characters to keep
-- @return string the tail of `s` (all of `s` when it has fewer characters)
local function utf8_last(s, n)
  local total = #s
  if total == 0 then return s end
  local start = total + 1
  local remaining = n
  while remaining >= 1 and start > 1 do
    start = start - 1
    local byte = s:byte(start)
    while start > 1 and byte >= 0x80 and byte < 0xC0 do
      start = start - 1
      byte = s:byte(start)
    end
    remaining = remaining - 1
  end
  return s:sub(start)
end
----------------------------------------------------------------------
-- 上下文评分(热路径)
----------------------------------------------------------------------
--- Score candidates against the recent commit window (hot path).
-- `scores` is emptied and then refilled with, per candidate text, the sum of
-- learned count * weight over these context keys, in this order:
--   the last commit (1.0), its last 2 characters when it is at least 6 bytes
--   long (0.5), its last character when it is at least 3 bytes long (0.25),
--   and the previous commit joined with the last one (0.4).
-- @param candidates array of objects exposing a `text` field
-- @param window array of recently committed strings, oldest first
-- @param learned table mapping context -> { word -> count }
-- @param scores table reused between calls; receives text -> score
local function score_candidates(candidates, window, learned, scores)
  -- Empty the reused table instead of allocating a new one.
  local stale = next(scores)
  while stale ~= nil do
    scores[stale] = nil
    stale = next(scores)
  end

  local last = window[#window]
  if not last or #last == 0 then return end

  -- Collect (key, weight) pairs first, then apply them in that same order.
  local keys, weights = { last }, { 1.0 }
  local function plan(key, weight)
    keys[#keys + 1] = key
    weights[#weights + 1] = weight
  end
  if #last >= 6 then plan(utf8_last(last, 2), 0.5) end  -- ~2 CJK chars
  if #last >= 3 then plan(utf8_last(last, 1), 0.25) end -- ~1 CJK char
  if #window >= 2 then plan(window[#window - 1] .. last, 0.4) end

  for i = 1, #keys do
    local entry = learned[keys[i]]
    if entry then
      local weight = weights[i]
      for _, cand in ipairs(candidates) do
        local text = cand.text
        local count = entry[text]
        if count then
          scores[text] = (scores[text] or 0) + count * weight
        end
      end
    end
  end
end
----------------------------------------------------------------------
-- 组件入口
----------------------------------------------------------------------
--- Component initializer: read configuration, load learned data, hook commits.
-- Sets on `env`: name_space (leading "*" stripped), data_file, save_interval,
-- decay_enabled, decay_rate, window, learned, entry_count, scores, pending
-- and commit_count, then connects a handler to the context's commit notifier.
-- The handler records "previous commit -> committed text" in env.learned and,
-- every `save_interval` commits, optionally decays the data and writes it out.
-- @param env filter environment supplied by the host
local function init(env)
  -- "*" is a pattern magic character; escape it to match it literally.
  env.name_space = (env.name_space:gsub("^%*", ""))
  local config = env.engine.schema.config
  local ns = env.name_space

  -- Platform-dependent data path.
  env.data_file = get_data_path(env)
  local interval = config:get_int(ns .. "/save_interval")
  env.save_interval = (interval ~= nil) and interval or 30

  -- Decay configuration; get_bool yields nil when the key is absent.
  env.decay_enabled = config:get_bool(ns .. "/decay_enabled")
  if env.decay_enabled == nil then env.decay_enabled = true end
  local rate = tonumber((config:get_string(ns .. "/decay_rate")))
  -- A factor outside (0, 1] would grow the counts or wipe them; use the default.
  if not rate or not (rate > 0 and rate <= 1) then rate = 0.95 end
  env.decay_rate = rate

  -- Context window: the 3 most recent commits.
  env.window = {}

  -- Load history.
  ensure_file(env.data_file)
  env.learned, env.entry_count = load(env.data_file)

  env.scores = {}
  -- Increments recorded since the last flush. env.learned already contains
  -- them, so this table is bookkeeping only and is never merged back.
  env.pending = {}
  env.commit_count = 0

  env.engine.context.commit_notifier:connect(function(ctx)
    local text = ctx:get_commit_text()
    if not text or #text == 0 then return end

    local window = env.window
    local prev = window[#window]
    if prev and #prev > 0 then
      local entry = env.learned[prev]
      if not entry then
        entry = {}
        env.learned[prev] = entry
      end
      entry[text] = (entry[text] or 0) + 1

      local unsaved = env.pending[prev]
      if not unsaved then
        unsaved = {}
        env.pending[prev] = unsaved
      end
      unsaved[text] = (unsaved[text] or 0) + 1
    end

    -- Slide the window.
    window[#window + 1] = text
    if #window > 3 then
      table.remove(window, 1)
    end

    -- Periodic full rewrite of the data file.
    env.commit_count = env.commit_count + 1
    if env.commit_count >= env.save_interval then
      -- Merging env.pending into env.learned here would count every commit
      -- twice, because env.learned was already updated above.
      env.pending = {}
      env.commit_count = 0
      env.entry_count = nil -- recomputed on the next load
      if env.decay_enabled then
        decay_learned(env.learned, env.decay_rate)
      end
      save(env.learned, env.data_file)
    end
  end)
end
--- Filter entry point: re-emit candidates, boosted ones first.
-- All candidates are drained from `input`, scored against the commit window,
-- and yielded again. When no score reaches 2.0 the original order is kept;
-- otherwise candidates are emitted by descending score, ties in input order.
-- `yield` is a global expected to be provided by the host.
-- @param input object whose `iter()` method returns a candidate iterator
-- @param env filter environment; reads env.window, env.learned and env.scores
local function filter(input, env)
  local pool = {}
  for cand in input:iter() do
    pool[#pool + 1] = cand
  end
  local total = #pool
  if total == 0 then return end

  local scores = env.scores
  score_candidates(pool, env.window, env.learned, scores)

  -- Threshold 2.0: a pairing must have been chosen more than once to matter.
  local best = 0
  for _, value in pairs(scores) do
    if value > best then best = value end
  end
  if best < 2.0 then
    for i = 1, total do yield(pool[i]) end
    return
  end

  -- Sort (index, score) records; the index tie-break keeps the order total,
  -- so equal scores retain their input order.
  local ranked = {}
  for i = 1, total do
    ranked[i] = { index = i, score = scores[pool[i].text] or 0 }
  end
  table.sort(ranked, function(x, y)
    if x.score ~= y.score then return x.score > y.score end
    return x.index < y.index
  end)
  for i = 1, total do
    yield(pool[ranked[i].index])
  end
end
-- Module exports: the component entry points, plus internal helpers that are
-- exposed for tests only.
local exports = {
  init = init,
  func = filter,
}
exports.utf8_last = utf8_last
exports.serialize = serialize
exports.decay_learned = decay_learned
exports.score_candidates = score_candidates
return exports