Skip to content

个人知识库助手

构建一个能够回答私有知识问题的AI助手,支持上传文档、智能问答、来源追溯。

项目目标

输入:上传PDF/Word/TXT文档
处理:自动切分、向量化、存储
输出:根据文档内容回答问题,显示引用来源

技术架构

┌──────────────┐
│   用户界面   │  Streamlit Web界面
└──────┬───────┘

┌──────────────┐
│   文档处理   │  加载→切分→向量化
└──────┬───────┘

┌──────────────┐
│   向量数据库  │  Chroma存储
└──────┬───────┘

┌──────────────┐
│    RAG链     │  检索→生成
└──────────────┘

完整代码

项目结构

knowledge-assistant/
├── app.py              # Streamlit界面
├── knowledge_base.py   # 知识库管理
├── rag_chain.py        # RAG链
├── requirements.txt    # 依赖
└── data/               # 文档存储

requirements.txt

txt
langchain==0.3.0
langchain-openai==0.2.0
langchain-chroma==0.1.0
langchain-community==0.3.0
streamlit==1.38.0
pypdf==4.0.0
python-dotenv==1.0.0

knowledge_base.py

python
"""知识库管理模块"""
import os
from typing import List, Optional
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.documents import Document

class KnowledgeBase:
    """Manages a persistent Chroma vector store of document chunks."""

    def __init__(self, persist_directory: str = "./chroma_db"):
        # Directory where the Chroma database files live on disk.
        self.persist_directory = persist_directory
        self.embeddings = OpenAIEmbeddings()
        # Chinese-aware splitting: fall back from paragraph breaks to
        # sentence punctuation before whitespace/character splits.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", "。", "!", "?", " ", ""]
        )
        self.vectorstore: Optional[Chroma] = None
        self._load_or_create()

    def _load_or_create(self):
        """Reopen an existing store if its directory exists; otherwise defer creation."""
        if not os.path.exists(self.persist_directory):
            self.vectorstore = None
            return
        self.vectorstore = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )

    def add_documents(self, file_path: str) -> int:
        """Load, split, and index a single file.

        Args:
            file_path: Path of the document to ingest.

        Returns:
            Number of chunks added to the vector store.
        """
        # Choose a loader by extension; anything that is not a PDF is
        # read as UTF-8 text (covers .txt and .md uploads).
        is_pdf = file_path.endswith('.pdf')
        loader = PyPDFLoader(file_path) if is_pdf else TextLoader(file_path, encoding='utf-8')

        docs = loader.load()

        # Record only the base file name (not the full temp path) as the source.
        base_name = os.path.basename(file_path)
        for doc in docs:
            doc.metadata['source'] = base_name

        chunks = self.splitter.split_documents(docs)

        # First ingestion creates the persistent store; later ones append.
        if self.vectorstore is not None:
            self.vectorstore.add_documents(chunks)
        else:
            self.vectorstore = Chroma.from_documents(
                documents=chunks,
                embedding=self.embeddings,
                persist_directory=self.persist_directory
            )

        return len(chunks)

    def search(self, query: str, k: int = 5) -> List[Document]:
        """Return up to *k* chunks most similar to *query* (empty list if no store)."""
        if self.vectorstore is None:
            return []
        return self.vectorstore.similarity_search(query, k=k)

    def search_with_scores(self, query: str, k: int = 5):
        """Like search(), but each hit is paired with its similarity score."""
        if self.vectorstore is None:
            return []
        return self.vectorstore.similarity_search_with_score(query, k=k)

    def get_retriever(self):
        """Expose the store as a LangChain retriever (None when the store is empty)."""
        if self.vectorstore is None:
            return None
        return self.vectorstore.as_retriever(search_kwargs={"k": 5})

    def get_document_count(self) -> int:
        """Number of chunks currently stored (0 when no store exists)."""
        if self.vectorstore is None:
            return 0
        # NOTE(review): relies on Chroma's private `_collection` attribute —
        # may break across langchain-chroma versions.
        return self.vectorstore._collection.count()

    def clear(self):
        """Delete the on-disk store and drop the in-memory handle."""
        import shutil
        if os.path.exists(self.persist_directory):
            shutil.rmtree(self.persist_directory)
        self.vectorstore = None

rag_chain.py

python
"""RAG链模块"""
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.documents import Document

class RAGChain:
    """Retrieval-augmented question-answering chain over a KnowledgeBase."""

    def __init__(self, knowledge_base):
        self.kb = knowledge_base
        # temperature=0 for deterministic, grounded answers.
        self.model = ChatOpenAI(model="gpt-4o", temperature=0)
        self._build_chain()

    def _build_chain(self):
        """Prepare the prompt template and output parser used by every query."""
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", """你是知识库助手。根据上下文回答问题。

规则:
1. 只使用上下文中的信息回答
2. 如果上下文没有相关信息,说"根据知识库,我无法回答这个问题"
3. 回答要准确、简洁、有条理
4. 在回答末尾标注引用的来源编号

上下文:
{context}"""),
            ("user", "{question}")
        ])

        self.parser = StrOutputParser()

    def _format_docs(self, docs: list) -> str:
        """Join retrieved chunks into a single numbered context string."""
        labelled = [
            f"[{idx}] 来源:{doc.metadata.get('source', '未知')}\n{doc.page_content}"
            for idx, doc in enumerate(docs, start=1)
        ]
        return "\n\n---\n\n".join(labelled)

    def query(self, question: str):
        """Answer *question* from the knowledge base.

        Returns a dict with the generated ``answer`` and up to three
        supporting ``sources`` (truncated snippets with file names).
        """
        docs = self.kb.search(question, k=5)

        if not docs:
            return {
                "answer": "知识库中没有相关内容,请先上传文档。",
                "sources": []
            }

        pipeline = self.prompt | self.model | self.parser
        answer = pipeline.invoke({
            "context": self._format_docs(docs),
            "question": question
        })

        # Surface only the top-3 snippets, truncated for display.
        sources = [
            {
                "content": doc.page_content[:200] + "...",
                "source": doc.metadata.get('source', '未知')
            }
            for doc in docs[:3]
        ]
        return {"answer": answer, "sources": sources}

    def query_stream(self, question: str):
        """Yield answer fragments incrementally for streaming UIs."""
        docs = self.kb.search(question, k=5)

        if not docs:
            yield "知识库中没有相关内容。"
            return

        pipeline = self.prompt | self.model | self.parser
        payload = {
            "context": self._format_docs(docs),
            "question": question
        }
        yield from pipeline.stream(payload)

app.py

python
"""Streamlit界面"""
import os
import tempfile
import streamlit as st
from knowledge_base import KnowledgeBase
from rag_chain import RAGChain

# Page configuration (must run before any other Streamlit call).
st.set_page_config(
    page_title="知识库助手",
    page_icon="📚",
    layout="wide"
)

# Cached singletons: Streamlit re-executes this script on every
# interaction, so the knowledge base and RAG chain are built once
# per process and reused across reruns.
@st.cache_resource
def init_knowledge_base():
    """Create (or reopen) the on-disk knowledge base."""
    return KnowledgeBase("./data/chroma_db")

@st.cache_resource
def init_rag_chain(kb):
    """Build the question-answering chain on top of the knowledge base."""
    return RAGChain(kb)

kb = init_knowledge_base()
rag = init_rag_chain(kb)

# Sidebar: document upload and knowledge-base management.
with st.sidebar:
    st.title("📚 知识库管理")

    uploaded_file = st.file_uploader(
        "上传文档",
        type=['pdf', 'txt', 'md'],
        help="支持PDF、TXT、MD格式"
    )

    # The button only renders once a file has been selected.
    if uploaded_file and st.button("添加到知识库"):
        # Persist the upload to a temp file so the loaders can read it
        # by path; keep the original extension so the right loader runs.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name

        with st.spinner("处理中..."):
            count = kb.add_documents(tmp_path)

        os.unlink(tmp_path)
        st.success(f"已添加 {count} 个文档块")
        st.rerun()

    # Knowledge-base stats.
    st.divider()
    st.metric("知识库文档块数", kb.get_document_count())

    # Wipe the store entirely.
    if st.button("清空知识库", type="secondary"):
        kb.clear()
        st.success("已清空")
        st.rerun()

# Main area: title and guidance.
st.title("📚 个人知识库助手")
st.markdown("上传文档,让AI帮你理解和检索知识")

# Nothing to query until at least one document has been indexed.
if kb.get_document_count() == 0:
    st.info("👈 请先在侧边栏上传文档")
    st.stop()

# Chat input box at the bottom of the page.
question = st.chat_input("输入你的问题...")

# Chat history lives in session state so it survives reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

        # Assistant turns may carry the snippets they were grounded on.
        if msg["role"] == "assistant" and "sources" in msg:
            with st.expander("查看来源"):
                for idx, source in enumerate(msg["sources"], start=1):
                    st.caption(f"**来源 {idx}:** {source['source']}")
                    st.text(source['content'])

# Handle a newly submitted question.
if question:
    # Echo the user's message and record it in the history.
    with st.chat_message("user"):
        st.markdown(question)
    st.session_state.messages.append({"role": "user", "content": question})

    # Produce and render the assistant's answer.
    with st.chat_message("assistant"):
        with st.spinner("思考中..."):
            result = rag.query(question)

        # Render the full answer (rag.query is non-streaming).
        st.markdown(result["answer"])

        # Show the supporting snippets behind an expander.
        if result["sources"]:
            with st.expander("查看来源"):
                for i, source in enumerate(result["sources"]):
                    st.caption(f"**来源 {i+1}:** {source['source']}")
                    # BUG FIX: was source['content"] (mismatched quotes),
                    # a SyntaxError that made app.py unrunnable.
                    st.text(source['content'])

        # Persist the assistant turn (answer + sources) for replay on rerun.
        st.session_state.messages.append({
            "role": "assistant",
            "content": result["answer"],
            "sources": result["sources"]
        })

运行项目

bash
# 安装依赖
pip install -r requirements.txt

# 配置环境变量
export OPENAI_API_KEY=your-key

# 运行
streamlit run app.py

功能演示

1. 上传文档

**上传文档**:支持拖拽上传PDF、TXT、MD文件(截图略)

2. 智能问答

用户: 这份文档主要讲了什么?

助手: 根据文档内容,主要讲述了以下要点:
1. ...
2. ...
[来源: document.pdf 第1章]

3. 来源追溯

每个回答都显示引用的原文位置,方便验证。

扩展方向

  1. 支持更多文档格式 - Word、Excel、网页
  2. 添加对话记忆 - 支持多轮对话
  3. 优化检索效果 - 使用重排序、混合检索
  4. 添加API接口 - FastAPI后端服务

小结

本项目实现了:

  • 文档上传和处理
  • 向量化存储和检索
  • RAG问答链
  • Streamlit Web界面

下一步

继续学习 智能客服机器人,了解Agent和工具调用。