From d18a521f9cef75128721b7be0e0550c53e7a6d14 Mon Sep 17 00:00:00 2001 From: Elaina Date: Wed, 22 Apr 2026 12:21:52 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=AF=84=E5=AE=A1?= =?UTF-8?q?=E5=8F=91=E7=8E=B0=E7=9A=844=E4=B8=AA=E9=AB=98=E4=BC=98?= =?UTF-8?q?=E5=85=88=E7=BA=A7=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. sparse.py: 话题切换过滤从赋0分改为continue,真正排除旧话题候选 2. gatekeeper.py: reset() 清空IDF缓存,避免新会话状态污染 3. gatekeeper.py: 句级裁剪后重新估算token数 4. sparse.py: content_words提取纳入所有英文单词(含单字符如'pg')和2字中文词 --- src/gatekeeper.py | 9 +++++++-- src/sparse.py | 15 ++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/gatekeeper.py b/src/gatekeeper.py index 92bf873..49950c6 100644 --- a/src/gatekeeper.py +++ b/src/gatekeeper.py @@ -192,13 +192,16 @@ class ContextGatekeeper: if user_to_keep or kept_asst_sents: new_user = '。'.join(user_to_keep) + ('。' if user_to_keep and kept_asst_sents else '') new_asst = '。'.join(kept_asst_sents) + # 裁剪后重新估算 token 数,不用原始值 + new_tokens_user = Block._estimate_tokens(new_user) + new_tokens_asst = Block._estimate_tokens(new_asst) trimmed_block = Block( user_text=new_user or block.user_text, assistant_text=new_asst or block.assistant_text, turn_id=block.turn_id, anchors=block.anchors, - tokens_user=block.tokens_user, - tokens_assistant=block.tokens_assistant + tokens_user=new_tokens_user, + tokens_assistant=new_tokens_asst ) trimmed.append(trimmed_block) else: @@ -262,4 +265,6 @@ class ContextGatekeeper: self.blocks.clear() self.turn_counter = 0 self._active_topic = None + self.anchor_extractor._idf_cache.clear() + self.anchor_extractor._doc_count = 0 # constraints 保留 \ No newline at end of file diff --git a/src/sparse.py b/src/sparse.py index 484c8a3..a2efdc7 100644 --- a/src/sparse.py +++ b/src/sparse.py @@ -92,18 +92,18 @@ class SparseRetriever: q_anchors_lower = [a.lower() for a in query_anchors] # 内容词: 从 query 原文提取的 topic-discriminative 词汇 - # 只包括: 英文术语、代码标识符、版本号 - # 中文通用词(如"怎么"、"执行")不具有话题区分度,排除 + # 包括: 英文术语/标识符、版本号、2+字符中文词 + # 中文通用短词(如"怎么")不具有话题区分度,排除 content_words = set() - # 英文单词和代码标识符(长度>=2) + # 英文单词和代码标识符(所有长度 >= 2) for w in re.findall(r'[a-zA-Z_][a-zA-Z0-9_-]*', query_text): if len(w) >= 2: content_words.add(w.lower()) # 版本号 for v in re.findall(r'v?\d+(\.\d+)*', query_text): content_words.add(v.lower()) - # 完整中文术语(连续中文字符 >= 4,足够具体的术语) - for chunk in re.findall(r'[\u4e00-\u9fff]{4,}', query_text): + # 2字及以上中文词(覆盖"PostgreSQL"等专有名词) + for chunk in re.findall(r'[\u4e00-\u9fff]{2,}', query_text): content_words.add(chunk.lower()) for i, block in enumerate(blocks): @@ -111,16 +111,13 @@ class SparseRetriever: # 话题切换时: 过滤掉不包含任何内容词的块 # 这些块属于旧话题,不应参与当前查询的候选 - # 例如: 问 PostgreSQL 时,只有包含 'postgresql' 或 'explain' 等词的块才能通过 if topic_switched and content_words: block_text = (block.user_text + ' ' + block.assistant_text).lower() - # 检查 block 是否包含 query 的任意一个内容词 block_contains_content = any( cw in block_text for cw in content_words ) if not block_contains_content: - scored.append((block, 0.0)) - continue + continue # 直接跳过,不加入 scored 列表 s = self.score(block, query_anchors, recency, idf_cache) scored.append((block, s))