检索增强生成(RAG)
RAG(Retrieval-Augmented Generation)是让大模型使用外部知识的技术。本节介绍RAG的原理和实现。
什么是RAG?
问题背景
大模型的局限性:
1. 知识截止:只训练到某个时间点
2. 不懂私有数据:不知道公司内部文档
3. 会编造:不知道的问题可能胡编答案
RAG的解决方案:
用户提问 → 检索相关知识 → 把知识+问题一起发给模型 → 生成答案RAG流程
┌─────────────┐
│ 用户提问 │
└──────┬──────┘
↓
┌─────────────┐
│ 向量检索 │ ←── 查询知识库
└──────┬──────┘
↓
┌─────────────┐
│ 组装提示词 │ ←── 问题 + 检索结果
└──────┬──────┘
↓
┌─────────────┐
│ LLM生成 │
└──────┬──────┘
↓
┌─────────────┐
│ 返回答案 │
└─────────────┘基本实现
使用LangChain
python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# 1. 准备向量库
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_texts(
[
"LangChain是一个开发AI应用的框架",
"RAG是检索增强生成的缩写",
"向量数据库用于存储和检索向量"
],
embeddings
)
retriever = vectorstore.as_retriever()
# 2. 定义提示词
prompt = ChatPromptTemplate.from_messages([
("system", "根据以下上下文回答问题。如果上下文中没有相关信息,请说不知道。\n\n上下文:{context}"),
("user", "{question}")
])
# 3. 构建RAG链
model = ChatOpenAI(model="gpt-4o")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| model
| StrOutputParser()
)
# 4. 查询
answer = rag_chain.invoke("什么是LangChain?")
print(answer)手动实现
python
def rag_query(question: str, vectorstore, model):
"""手动实现RAG"""
# 1. 检索相关文档
docs = vectorstore.similarity_search(question, k=3)
context = "\n\n".join(doc.page_content for doc in docs)
# 2. 构建提示词
prompt = f"""根据以下上下文回答问题。如果上下文中没有相关信息,请说不知道。
上下文:
{context}
问题:{question}
答案:"""
# 3. 调用模型
response = model.invoke(prompt)
return response.content提示词设计
基础提示词
python
prompt = ChatPromptTemplate.from_messages([
("system", "根据上下文回答问题"),
("user", "上下文:{context}\n\n问题:{question}")
])带引用的提示词
python
prompt = ChatPromptTemplate.from_messages([
("system", """根据上下文回答问题,并标注信息来源。
要求:
1. 只使用上下文中的信息
2. 在答案后列出引用的来源编号
3. 如果上下文没有相关信息,说"根据现有知识库无法回答"
上下文:
{context}"""),
("user", "{question}")
])多轮对话提示词
python
prompt = ChatPromptTemplate.from_messages([
("system", "根据上下文回答问题"),
("placeholder", "{chat_history}"),
("user", "上下文:{context}\n\n问题:{question}")
])
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
# 创建带历史的RAG
chain = create_retrieval_chain(retriever, prompt | model)检索器配置
基本配置
python
# 返回最相似的3个文档
retriever = vectorstore.as_retriever(
search_kwargs={"k": 3}
)相似度阈值
python
# 只返回相似度高于阈值的文档
retriever = vectorstore.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"k": 5,
"score_threshold": 0.7 # 相似度阈值
}
)MMR检索
python
# 最大边际相关性,增加结果多样性
retriever = vectorstore.as_retriever(
search_type="mmr",
search_kwargs={
"k": 5,
"fetch_k": 20,
"lambda_mult": 0.5 # 0=最大多样性,1=最大相关性
}
)返回引用来源
python
from langchain_core.runnables import RunnableParallel
def format_docs_with_source(docs):
"""格式化文档并保留来源信息"""
formatted = []
sources = []
for i, doc in enumerate(docs):
formatted.append(f"[{i+1}] {doc.page_content}")
sources.append(doc.metadata.get("source", "未知"))
return "\n\n".join(formatted), sources
# 构建带来源的链
rag_chain_with_source = RunnableParallel(
answer=(
{"context": retriever | format_docs_with_source, "question": RunnablePassthrough()}
| prompt
| model
| StrOutputParser()
),
sources=retriever
)
result = rag_chain_with_source.invoke("问题")
print("答案:", result["answer"])
print("来源:", [doc.metadata for doc in result["sources"]])完整RAG应用示例
python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
class RAGApplication:
def __init__(self, docs_path: str = None, persist_directory: str = "./rag_db"):
self.embeddings = OpenAIEmbeddings()
self.model = ChatOpenAI(model="gpt-4o")
self.persist_directory = persist_directory
if docs_path:
self.build_index(docs_path)
else:
self.load_index()
self._build_chain()
def build_index(self, docs_path: str):
"""构建向量索引"""
# 加载文档
loader = PyPDFLoader(docs_path)
documents = loader.load()
# 切分
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
chunks = splitter.split_documents(documents)
# 创建向量库
self.vectorstore = Chroma.from_documents(
documents=chunks,
embedding=self.embeddings,
persist_directory=self.persist_directory
)
self.retriever = self.vectorstore.as_retriever(
search_kwargs={"k": 4}
)
def load_index(self):
"""加载已有索引"""
self.vectorstore = Chroma(
persist_directory=self.persist_directory,
embedding_function=self.embeddings
)
self.retriever = self.vectorstore.as_retriever()
def _build_chain(self):
"""构建RAG链"""
prompt = ChatPromptTemplate.from_messages([
("system", """你是知识库助手。根据上下文回答问题。
规则:
1. 只使用上下文中的信息回答
2. 如果上下文没有相关信息,说"知识库中没有相关信息"
3. 回答要准确、简洁
上下文:
{context}"""),
("user", "{question}")
])
def format_docs(docs):
return "\n\n---\n\n".join(
f"[来源: {d.metadata.get('source', '未知')}]\n{d.page_content}"
for d in docs
)
self.chain = (
{
"context": self.retriever | format_docs,
"question": RunnablePassthrough()
}
| prompt
| self.model
| StrOutputParser()
)
def query(self, question: str) -> str:
"""查询"""
return self.chain.invoke(question)
def query_with_sources(self, question: str):
"""查询并返回来源"""
docs = self.retriever.invoke(question)
answer = self.chain.invoke(question)
return {
"answer": answer,
"sources": [
{"content": d.page_content[:200], "metadata": d.metadata}
for d in docs
]
}
# 使用
rag = RAGApplication("./knowledge.pdf")
result = rag.query_with_sources("什么是RAG?")
print("答案:", result["answer"])
print("\n来源:")
for s in result["sources"]:
print(f"- {s['metadata'].get('source')}: {s['content'][:100]}...")小结
| 概念 | 说明 |
|---|---|
| RAG | 检索增强生成 |
| 检索器 | 从向量库获取相关文档 |
| 提示词 | 组装上下文和问题 |
| 引用来源 | 返回答案的来源信息 |
下一步
继续学习 RAG优化技巧,提升RAG应用的效果。