feat: v5 — cross-platform, decay, sandbox, GC optimization
- Cross-platform path resolution (auto-detect Windows/macOS/Linux) - GC reduction in score_candidates (1 table vs 6 per invocation) - Decay/forgetting mechanism to prevent infinite data growth - Sandbox for load() to prevent code injection via data file - Cross-platform ensure_dir for mkdir on all platforms - New config options: data_path, decay_enabled, decay_rate - Update README with v5 features and cross-platform documentation
This commit is contained in:
59
README.md
59
README.md
@@ -9,6 +9,7 @@ RIME 输入法上下文调频过滤器。根据已上屏的前文自动调整候
|
||||
- **纯学习**:零硬编码规则,完全从你的输入习惯中学习
|
||||
- **跨会话**:学习数据持久化到本地文件,重启不丢
|
||||
- **轻量**:热路径无文件 I/O,不卡输入
|
||||
- **跨平台**:自动适配 Windows / macOS / Linux 路径
|
||||
|
||||
## 效果
|
||||
|
||||
@@ -66,12 +67,50 @@ patch:
|
||||
|
||||
```yaml
|
||||
context_filter:
|
||||
save_interval: 30 # 每 N 次提交存一次盘(默认 30,增加可减少写入频率)
|
||||
save_interval: 30 # 每 N 次提交存一次盘(默认 30,增加可减少写入频率)
|
||||
data_path: "" # 自定义数据文件路径,为空则自动检测(见下方)
|
||||
decay_enabled: true # 是否启用衰减遗忘机制(默认 true)
|
||||
decay_rate: 0.95 # 每次保存时的衰减因子(默认 0.95)
|
||||
```
|
||||
|
||||
### 参数说明
|
||||
|
||||
| 参数 | 类型 | 默认值 | 说明 |
|
||||
|---|---|---|---|
|
||||
| `save_interval` | 整数 | 30 | 每 N 次提交写一次磁盘 |
|
||||
| `data_path` | 字符串 | 自动 | 自定义数据文件完整路径。设置后覆盖自动检测结果 |
|
||||
| `decay_enabled` | 布尔 | true | 启用衰减后,旧数据的权重随时间逐渐降低 |
|
||||
| `decay_rate` | 浮点数 | 0.95 | 每次保存时所有计数的乘数因子(0 < rate < 1) |
|
||||
|
||||
### 衰减行为
|
||||
|
||||
衰减机制防止数据无限增长,让长期不用的搭配逐渐遗忘。`decay_rate = 0.95` 时:
|
||||
|
||||
| 初始计数 | 10 次保存后 | 20 次保存后 | 40 次保存后 | 消亡阈值 (≈1.1) |
|
||||
|---|---|---|---|---|
|
||||
| 1 | 已消亡 | — | — | 1 次 |
|
||||
| 3 | 1.79 | 1.07 → 已消亡 | — | ~20 次 |
|
||||
| 5 | 2.99 | 1.79 | 0.64 → 已消亡 | ~32 次 |
|
||||
| 10 | 5.99 | 3.58 | 2.15 | ~44 次 |
|
||||
| 50 | 29.9 | 17.9 | 6.43 | ~76 次 |
|
||||
|
||||
频率越高的搭配保留越久,偶然一次的搭配较快消亡,保持数据库精炼。
|
||||
|
||||
## 数据文件
|
||||
|
||||
学习数据存储在 RIME 用户目录下的 `context_learned.data`,格式为 Lua 表字面量:
|
||||
学习数据存储在 RIME 用户目录下的 `context_learned.data`,路径根据平台自动检测:
|
||||
|
||||
| 平台 | 默认路径 |
|
||||
|---|---|
|
||||
| **Windows** | `%APPDATA%\Rime\context_learned.data` |
|
||||
| **macOS** | `~/Library/Rime/context_learned.data` |
|
||||
| **Linux** | `/Library/Rime/context_learned.data`(注:Linux 路径可能随发行版不同,推荐使用 `data_path` 自定义) |
|
||||
|
||||
也可以配置 `data_path` 指定任意路径。
|
||||
|
||||
### 文件格式
|
||||
|
||||
格式为 Lua 表字面量,由 Lua VM 原生加载,无需逐行解析。手动编辑此文件可以增删规则或重置学习数据。
|
||||
|
||||
```lua
|
||||
return {
|
||||
@@ -81,7 +120,9 @@ return {
|
||||
}
|
||||
```
|
||||
|
||||
由 Lua VM 原生加载,无需逐行解析。手动编辑这个文件可以增删规则或重置学习数据。
|
||||
### 安全性
|
||||
|
||||
数据文件在**沙箱环境**中加载。恶意构造的数据文件无法访问 `os`、`io`、`string` 等系统库,仅允许纯数据(表、数字、字符串)返回。兼容 Lua 5.1 / LuaJIT(自动降级为无沙箱模式)。
|
||||
|
||||
## 工作原理
|
||||
|
||||
@@ -99,8 +140,8 @@ commit_notifier
|
||||
| Key | 权重 | 示例 |
|
||||
|---|---|---|
|
||||
| 精确前文 | 1.0 | `"接下来的"` |
|
||||
| 末尾 2 字 | 0.5 | `"的"`(前文 `"的"` 时退化为单字) |
|
||||
| 末尾 1 字 | 0.25 | `"的"` |
|
||||
| 末尾 2 字(≥6 字节) | 0.5 | `"来的"`(前文 `"接下来的"` 时) |
|
||||
| 末尾 1 字(≥3 字节) | 0.25 | `"的"` |
|
||||
| 双词组合 | 0.4 | `"接下来的任务"` |
|
||||
|
||||
四个 key 的得分加权求和,总分 ≥ 2.0 才参与重排(约 2-3 次选择后生效)。
|
||||
@@ -109,6 +150,14 @@ commit_notifier
|
||||
|
||||
数据以 **Lua 源码格式** 存储。加载时通过 `load()` 由 Lua VM 一次性编译执行,不用逐行 regex 解析。写入使用原子重写(`.tmp` + `rename`),防止文件损坏。
|
||||
|
||||
## 版本历史
|
||||
|
||||
- **v5** — 跨平台路径自动检测、衰减遗忘机制、沙箱安全加载、热路径 GC 优化
|
||||
- **v4** — Lua 源码持久化格式(移除 JSON 依赖)
|
||||
- **v3** — 增量缓冲 + 批量写入
|
||||
- **v2** — 上下文窗口 + 4 种 key 加权查询
|
||||
- **v1** — 初版
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
@@ -1,25 +1,44 @@
|
||||
-- rime_context_filter.lua
|
||||
-- 纯学习的上下文调频引擎(v4 — Lua 源码持久化)
|
||||
-- 纯学习的上下文调频引擎(v5 — 跨平台 + 衰减 + 安全)
|
||||
--
|
||||
-- 学习:自动记录「上屏的前文 → 当前选的词」的共现关系
|
||||
-- 匹配:根据当前上下文对候选词加权,越相关的越靠前
|
||||
-- 持久化:数据以 Lua 表字面量格式存到
|
||||
-- %APPDATA%\Rime\context_learned.data
|
||||
-- context_learned.data(路径因平台而异)
|
||||
-- 由 Lua VM 原生加载,不需要逐行 regex 解析
|
||||
--
|
||||
-- 配置(可选,写入 rime_ice.custom.yaml patch: 下):
|
||||
-- context_filter:
|
||||
-- save_interval: 30 # 每 N 次提交存一次盘(默认 30)
|
||||
-- data_path: /自定义/path/context_learned.data # 自定义数据路径,优先级最高
|
||||
-- decay_enabled: true # 是否启用衰减(默认 true)
|
||||
-- decay_rate: 0.95 # 每次保存时的衰减因子(默认 0.95)
|
||||
--
|
||||
-- 激活:
|
||||
-- "engine/filters/@after 6":
|
||||
-- lua_filter@*rime_context_filter
|
||||
|
||||
----------------------------------------------------------------------
|
||||
-- 持久化 — Lua 表字面量格式
|
||||
-- 跨平台路径解析
|
||||
----------------------------------------------------------------------
|
||||
|
||||
local DATA_FILE = (os.getenv("APPDATA") or "") .. "\\Rime\\context_learned.data"
|
||||
local function get_data_path(env)
|
||||
local config = env.engine.schema.config
|
||||
local custom = config:get_string(env.name_space .. "/data_path")
|
||||
if custom and #custom > 0 then return custom end
|
||||
|
||||
local sep = package.config:sub(1,1)
|
||||
if sep == "\\" then
|
||||
return (os.getenv("APPDATA") or "") .. sep .. "Rime" .. sep .. "context_learned.data"
|
||||
end
|
||||
-- macOS / Linux
|
||||
local home = os.getenv("HOME") or ""
|
||||
return home .. "/Library/Rime/context_learned.data"
|
||||
end
|
||||
|
||||
----------------------------------------------------------------------
|
||||
-- 持久化 — Lua 表字面量格式
|
||||
----------------------------------------------------------------------
|
||||
|
||||
local function esc(s)
|
||||
-- Lua 字符串字面量转义(中文不包含需要转义的字符,聊备一格)
|
||||
@@ -54,15 +73,23 @@ local function serialize(data)
|
||||
end
|
||||
|
||||
--- 用 Lua VM 原生加载数据文件(无逐行 regex)
|
||||
local function load()
|
||||
local f = io.open(DATA_FILE, "r")
|
||||
--- 沙箱环境防止数据文件执行恶意代码
|
||||
local function load(data_file)
|
||||
local f = io.open(data_file, "r")
|
||||
if not f then return {}, 0 end
|
||||
|
||||
local content = f:read("*a")
|
||||
f:close()
|
||||
if not content or #content == 0 then return {}, 0 end
|
||||
|
||||
local loader, err = load(content, "@" .. DATA_FILE)
|
||||
local safe_env = {}
|
||||
-- Lua 5.3+: load(chunk, name, mode, env)
|
||||
-- Lua 5.1/LuaJIT: load(chunk, name) only — no sandbox parameter
|
||||
local loader, err = load(content, "@" .. data_file, "t", safe_env)
|
||||
if not loader then
|
||||
-- Fallback for older Lua versions
|
||||
loader, err = load(content, "@" .. data_file)
|
||||
end
|
||||
if not loader then return {}, 0 end
|
||||
local ok, data = pcall(loader)
|
||||
if not ok or type(data) ~= "table" then return {}, 0 end
|
||||
@@ -78,29 +105,65 @@ local function load()
|
||||
end
|
||||
|
||||
--- 原子重写整个文件
|
||||
local function save(data)
|
||||
local tmp = DATA_FILE .. ".tmp"
|
||||
local function save(data, data_file)
|
||||
local tmp = data_file .. ".tmp"
|
||||
local f = io.open(tmp, "w")
|
||||
if not f then return false end
|
||||
f:write(serialize(data))
|
||||
f:close()
|
||||
os.remove(DATA_FILE)
|
||||
os.rename(tmp, DATA_FILE)
|
||||
os.remove(data_file)
|
||||
os.rename(tmp, data_file)
|
||||
return true
|
||||
end
|
||||
|
||||
----------------------------------------------------------------------
|
||||
-- 跨平台目录确保
|
||||
----------------------------------------------------------------------
|
||||
|
||||
local function ensure_dir(path)
|
||||
local is_win = package.config:sub(1,1) == "\\"
|
||||
if is_win then
|
||||
local dir = path:match("^(.+)\\[^\\]+$")
|
||||
if dir then os.execute('if not exist "' .. dir .. '" mkdir "' .. dir .. '"') end
|
||||
else
|
||||
local dir = path:match("^(.+)/([^/]+)$")
|
||||
if dir then
|
||||
os.execute("mkdir -p '" .. dir:gsub("\\", "/") .. "' 2>/dev/null")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- 确保文件和目录存在
|
||||
local function ensure_file()
|
||||
local f = io.open(DATA_FILE, "a")
|
||||
local function ensure_file(data_file)
|
||||
local f = io.open(data_file, "a")
|
||||
if f then f:close(); return end
|
||||
local dir = DATA_FILE:match("^(.+)\\[^\\]+$")
|
||||
if dir then os.execute('if not exist "' .. dir .. '" mkdir "' .. dir .. '"') end
|
||||
f = io.open(DATA_FILE, "w")
|
||||
ensure_dir(data_file)
|
||||
f = io.open(data_file, "w")
|
||||
if f then f:write("return {}\n"); f:close() end
|
||||
end
|
||||
|
||||
----------------------------------------------------------------------
|
||||
-- 上下文评分
|
||||
-- 衰减 / 遗忘机制
|
||||
----------------------------------------------------------------------
|
||||
|
||||
local function decay_learned(learned, rate)
|
||||
for ctx, words in pairs(learned) do
|
||||
for word, count in pairs(words) do
|
||||
local new_count = count * rate
|
||||
if new_count < 1.1 then
|
||||
words[word] = nil
|
||||
else
|
||||
words[word] = new_count
|
||||
end
|
||||
end
|
||||
if next(words) == nil then
|
||||
learned[ctx] = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
----------------------------------------------------------------------
|
||||
-- 上下文评分(热路径,减少 GC)
|
||||
----------------------------------------------------------------------
|
||||
|
||||
local function score_candidates(candidates, window, learned)
|
||||
@@ -108,34 +171,23 @@ local function score_candidates(candidates, window, learned)
|
||||
local last = window[#window]
|
||||
if not last or #last == 0 then return scores end
|
||||
|
||||
-- 候选词快速查找
|
||||
local cand_set = {}
|
||||
for _, c in ipairs(candidates) do
|
||||
cand_set[c.text] = true
|
||||
end
|
||||
|
||||
-- context keys with weights
|
||||
local keys = { { last, 1.0 } }
|
||||
if #last >= 2 then
|
||||
keys[#keys + 1] = { last:sub(-2), 0.5 }
|
||||
end
|
||||
keys[#keys + 1] = { last:sub(-1), 0.25 }
|
||||
if #window >= 2 then
|
||||
keys[#keys + 1] = { window[#window - 1] .. window[#window], 0.4 }
|
||||
end
|
||||
|
||||
for _, kv in ipairs(keys) do
|
||||
local key, weight = kv[1], kv[2]
|
||||
local function apply_weight(key, weight)
|
||||
local e = learned[key]
|
||||
if e then
|
||||
for word, count in pairs(e) do
|
||||
if cand_set[word] then
|
||||
scores[word] = (scores[word] or 0) + count * weight
|
||||
end
|
||||
if not e then return end
|
||||
for _, c in ipairs(candidates) do
|
||||
local w = c.text
|
||||
if e[w] then
|
||||
scores[w] = (scores[w] or 0) + e[w] * weight
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
apply_weight(last, 1.0)
|
||||
if #last >= 6 then apply_weight(last:sub(-6), 0.5) end -- ~2 CJK chars
|
||||
if #last >= 3 then apply_weight(last:sub(-3), 0.25) end -- ~1 CJK char
|
||||
if #window >= 2 then
|
||||
apply_weight(window[#window - 1] .. last, 0.4)
|
||||
end
|
||||
return scores
|
||||
end
|
||||
|
||||
@@ -147,14 +199,22 @@ local function init(env)
|
||||
env.name_space = env.name_space:gsub("^*", "")
|
||||
local config = env.engine.schema.config
|
||||
|
||||
-- 跨平台数据路径
|
||||
env.data_file = get_data_path(env)
|
||||
|
||||
env.save_interval = config:get_int(env.name_space .. "/save_interval") or 30
|
||||
|
||||
-- 衰减配置
|
||||
env.decay_enabled = config:get_bool(env.name_space .. "/decay_enabled")
|
||||
if env.decay_enabled == nil then env.decay_enabled = true end
|
||||
env.decay_rate = tonumber(config:get_string(env.name_space .. "/decay_rate")) or 0.95
|
||||
|
||||
-- 上下文窗口(最近 3 次上屏)
|
||||
env.window = {}
|
||||
|
||||
-- 加载历史数据
|
||||
ensure_file()
|
||||
env.learned, env.entry_count = load()
|
||||
ensure_file(env.data_file)
|
||||
env.learned, env.entry_count = load(env.data_file)
|
||||
|
||||
-- 会话级新增缓冲
|
||||
env.pending = {}
|
||||
@@ -206,8 +266,13 @@ local function init(env)
|
||||
end
|
||||
env.pending = {}
|
||||
env.commit_count = 0
|
||||
env.entry_count = nil
|
||||
save(env.learned)
|
||||
env.entry_count = nil -- 下次 compact 时重新计算
|
||||
|
||||
-- 保存前执行衰减
|
||||
if env.decay_enabled then
|
||||
decay_learned(env.learned, env.decay_rate)
|
||||
end
|
||||
save(env.learned, env.data_file)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user