Skip to content

检索增强生成(RAG)

RAG(Retrieval-Augmented Generation)是让大模型使用外部知识的技术。本节介绍RAG的原理和实现。

什么是RAG?

问题背景

大模型的局限性:
1. 知识截止:只训练到某个时间点
2. 不懂私有数据:不知道公司内部文档
3. 会编造:不知道的问题可能胡编答案

RAG的解决方案:
用户提问 → 检索相关知识 → 把知识+问题一起发给模型 → 生成答案

RAG流程

┌─────────────┐
│   用户提问   │
└──────┬──────┘

┌─────────────┐
│  向量检索   │ ←── 查询知识库
└──────┬──────┘

┌─────────────┐
│  组装提示词  │ ←── 问题 + 检索结果
└──────┬──────┘

┌─────────────┐
│   LLM生成   │
└──────┬──────┘

┌─────────────┐
│   返回答案   │
└─────────────┘

基本实现

使用LangChain

python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# 1. 准备向量库
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_texts(
    [
        "LangChain是一个开发AI应用的框架",
        "RAG是检索增强生成的缩写",
        "向量数据库用于存储和检索向量"
    ],
    embeddings
)
retriever = vectorstore.as_retriever()

# 2. 定义提示词
prompt = ChatPromptTemplate.from_messages([
    ("system", "根据以下上下文回答问题。如果上下文中没有相关信息,请说不知道。\n\n上下文:{context}"),
    ("user", "{question}")
])

# 3. 构建RAG链
model = ChatOpenAI(model="gpt-4o")

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

# 4. 查询
answer = rag_chain.invoke("什么是LangChain?")
print(answer)

手动实现

python
def rag_query(question: str, vectorstore, model):
    """手动实现RAG"""
    # 1. 检索相关文档
    docs = vectorstore.similarity_search(question, k=3)
    context = "\n\n".join(doc.page_content for doc in docs)
    
    # 2. 构建提示词
    prompt = f"""根据以下上下文回答问题。如果上下文中没有相关信息,请说不知道。

上下文:
{context}

问题:{question}

答案:"""
    
    # 3. 调用模型
    response = model.invoke(prompt)
    return response.content

提示词设计

基础提示词

python
prompt = ChatPromptTemplate.from_messages([
    ("system", "根据上下文回答问题"),
    ("user", "上下文:{context}\n\n问题:{question}")
])

带引用的提示词

python
prompt = ChatPromptTemplate.from_messages([
    ("system", """根据上下文回答问题,并标注信息来源。

要求:
1. 只使用上下文中的信息
2. 在答案后列出引用的来源编号
3. 如果上下文没有相关信息,说"根据现有知识库无法回答"

上下文:
{context}"""),
    ("user", "{question}")
])

多轮对话提示词

python
prompt = ChatPromptTemplate.from_messages([
    ("system", "根据上下文回答问题"),
    ("placeholder", "{chat_history}"),
    ("user", "上下文:{context}\n\n问题:{question}")
])

from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_history_aware_retriever, create_retrieval_chain

# 创建带历史的RAG
chain = create_retrieval_chain(retriever, prompt | model)

检索器配置

基本配置

python
# 返回最相似的3个文档
retriever = vectorstore.as_retriever(
    search_kwargs={"k": 3}
)

相似度阈值

python
# 只返回相似度高于阈值的文档
retriever = vectorstore.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "k": 5,
        "score_threshold": 0.7  # 相似度阈值
    }
)

MMR检索

python
# 最大边际相关性,增加结果多样性
retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={
        "k": 5,
        "fetch_k": 20,
        "lambda_mult": 0.5  # 0=最大多样性,1=最大相关性
    }
)

返回引用来源

python
from langchain_core.runnables import RunnableParallel

def format_docs_with_source(docs):
    """格式化文档并保留来源信息"""
    formatted = []
    sources = []
    for i, doc in enumerate(docs):
        formatted.append(f"[{i+1}] {doc.page_content}")
        sources.append(doc.metadata.get("source", "未知"))
    return "\n\n".join(formatted), sources

# 构建带来源的链
rag_chain_with_source = RunnableParallel(
    answer=(
        {"context": retriever | format_docs_with_source, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    ),
    sources=retriever
)

result = rag_chain_with_source.invoke("问题")
print("答案:", result["answer"])
print("来源:", [doc.metadata for doc in result["sources"]])

完整RAG应用示例

python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

class RAGApplication:
    def __init__(self, docs_path: str = None, persist_directory: str = "./rag_db"):
        self.embeddings = OpenAIEmbeddings()
        self.model = ChatOpenAI(model="gpt-4o")
        self.persist_directory = persist_directory
        
        if docs_path:
            self.build_index(docs_path)
        else:
            self.load_index()
        
        self._build_chain()
    
    def build_index(self, docs_path: str):
        """构建向量索引"""
        # 加载文档
        loader = PyPDFLoader(docs_path)
        documents = loader.load()
        
        # 切分
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        chunks = splitter.split_documents(documents)
        
        # 创建向量库
        self.vectorstore = Chroma.from_documents(
            documents=chunks,
            embedding=self.embeddings,
            persist_directory=self.persist_directory
        )
        self.retriever = self.vectorstore.as_retriever(
            search_kwargs={"k": 4}
        )
    
    def load_index(self):
        """加载已有索引"""
        self.vectorstore = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings
        )
        self.retriever = self.vectorstore.as_retriever()
    
    def _build_chain(self):
        """构建RAG链"""
        prompt = ChatPromptTemplate.from_messages([
            ("system", """你是知识库助手。根据上下文回答问题。

规则:
1. 只使用上下文中的信息回答
2. 如果上下文没有相关信息,说"知识库中没有相关信息"
3. 回答要准确、简洁

上下文:
{context}"""),
            ("user", "{question}")
])
        
        def format_docs(docs):
            return "\n\n---\n\n".join(
                f"[来源: {d.metadata.get('source', '未知')}]\n{d.page_content}"
                for d in docs
            )
        
        self.chain = (
            {
                "context": self.retriever | format_docs,
                "question": RunnablePassthrough()
            }
            | prompt
            | self.model
            | StrOutputParser()
        )
    
    def query(self, question: str) -> str:
        """查询"""
        return self.chain.invoke(question)
    
    def query_with_sources(self, question: str):
        """查询并返回来源"""
        docs = self.retriever.invoke(question)
        answer = self.chain.invoke(question)
        return {
            "answer": answer,
            "sources": [
                {"content": d.page_content[:200], "metadata": d.metadata}
                for d in docs
            ]
        }

# 使用
rag = RAGApplication("./knowledge.pdf")
result = rag.query_with_sources("什么是RAG?")
print("答案:", result["answer"])
print("\n来源:")
for s in result["sources"]:
    print(f"- {s['metadata'].get('source')}: {s['content'][:100]}...")

小结

概念说明
RAG检索增强生成
检索器从向量库获取相关文档
提示词组装上下文和问题
引用来源返回答案的来源信息

下一步

继续学习 RAG优化技巧,提升RAG应用的效果。