LangChain Ensemble Retriever¶
The LangChain EnsembleRetriever
takes a list of retrievers as input, ensembles the results of their get_relevant_documents()
methods, and reranks the results based on the Reciprocal Rank Fusion (RRF) algorithm. With TruLens, we have the ability to evaluate the context of each component retriever along with the ensemble retriever, compare performance, and track context relevance across all retrievers. This example walks through that process.
Setup¶
# !pip install trulens trulens-apps-langchain trulens-providers-openai openai langchain langchain_community langchain_openai rank_bm25 faiss_cpu
from getpass import getpass
import os

# Ask for the OpenAI API key only when the environment does not already
# provide a non-empty value; store it back into the process environment so
# downstream OpenAI clients can pick it up.
if os.getenv("OPENAI_API_KEY") in (None, ""):
    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")

# Turn on OpenTelemetry-based tracing in TruLens.
os.environ["TRULENS_OTEL_TRACING"] = "1"
# Imports main tools:
# Imports from LangChain to build app
from langchain.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from trulens.apps.langchain import TruChain
from trulens.core import Feedback
from trulens.core import TruSession
# Start a TruLens session and wipe any previously logged records so this
# example begins from a clean database.
session = TruSession()
session.reset_database()
# Tiny demo corpus: plain programming facts, a lexical distractor plus two
# paraphrases (to separate lexical from semantic retrieval), and two
# error-code documents. Order is preserved from the original corpus.
_language_facts = [
    "Python is a popular programming language.",
    "JavaScript is mainly used for web development.",
    "C++ is known for its performance in system programming.",
]
_semantic_probes = [
    "The snake is a reptile found in many parts of the world.",  # Lexical distractor
    "Web pages are often made interactive with JS.",  # Paraphrase
    "Many developers love coding in Python due to its simplicity.",  # Paraphrase
]
_error_docs = [
    "A 500 error code indicates an internal server error.",
    "Internal server errors occur for a variety of reasons, including a bug in the code or a configuration error.",
]
doc_list = _language_facts + _semantic_probes + _error_docs
# initialize the bm25 retriever and faiss retriever
# BM25 is a purely lexical (term-frequency) retriever; k=1 keeps only the top hit.
bm25_retriever = BM25Retriever.from_texts(doc_list)
bm25_retriever.k = 1
# FAISS over OpenAI embeddings gives a semantic (dense-vector) retriever,
# also limited to the single best match via search_kwargs.
embedding = OpenAIEmbeddings()
faiss_vectorstore = FAISS.from_texts(doc_list, embedding)
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 1})
# Combine both retrievers with equal weight; the ensemble fuses their
# rankings (Reciprocal Rank Fusion, per the intro above).
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
)
Initialize Context Relevance checks for each component retriever + ensemble¶
This requires knowing the feedback selector for each retriever. You can find this path by logging a run of your application and examining the application traces on the Evaluations page.
Read more in our docs: Selecting Components.
import numpy as np
from trulens.core.feedback.selector import Selector
from trulens.otel.semconv.trace import SpanAttributes
from trulens.providers.openai import OpenAI
# Initialize provider class
# NOTE(review): the name `openai` shadows the `openai` package name; harmless
# here since the module itself is not referenced afterwards in this script.
openai = OpenAI()
# Selector for the BM25 retriever's spans: matches the fully qualified method
# name TruLens records when BM25Retriever fetches documents.
# collect_list=False presumably scores each retrieved chunk individually
# rather than as one combined list — TODO confirm against the Selector docs.
bm25_context = Selector(
    function_name="langchain_community.retrievers.bm25.BM25Retriever._get_relevant_documents",
    span_attribute=SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS,
    collect_list=False,
)
# Selector for the FAISS-backed vector store retriever's spans.
faiss_context = Selector(
    function_name="langchain_core.vectorstores.base.VectorStoreRetriever._get_relevant_documents",
    span_attribute=SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS,
    collect_list=False,
)
# Selector for the ensemble: matched by span *type* rather than a function
# name, so it targets the RETRIEVAL span emitted for the ensemble retriever.
ensemble_context = Selector(
    span_type=SpanAttributes.SpanType.RETRIEVAL,
    span_attribute=SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS,
    collect_list=False,
)
# Question/statement relevance between question and each context chunk.
# Each Feedback scores relevance of the retrieved context (from the matching
# selector) against the input question, then averages scores with np.mean.
f_context_relevance_bm25 = (
    Feedback(openai.context_relevance, name="BM25")
    .on_input()
    .on({"context": bm25_context})
    .aggregate(np.mean)
)
# Same metric computed on the FAISS retriever's contexts.
f_context_relevance_faiss = (
    Feedback(openai.context_relevance, name="FAISS")
    .on_input()
    .on({"context": faiss_context})
    .aggregate(np.mean)
)
# Same metric computed on the fused (ensemble) contexts, enabling a
# side-by-side comparison of all three retrievers.
f_context_relevance_ensemble = (
    Feedback(openai.context_relevance, name="Ensemble")
    .on_input()
    .on({"context": ensemble_context})
    .aggregate(np.mean)
)
Add feedbacks¶
# Wrap the ensemble retriever with TruChain so its runs are traced and the
# three context-relevance feedbacks are evaluated on the recorded spans.
tru_recorder = TruChain(
    ensemble_retriever,
    app_name="Ensemble Retriever",
    feedbacks=[
        f_context_relevance_bm25,
        f_context_relevance_faiss,
        f_context_relevance_ensemble,
    ],
)
# Queries chosen to separate lexical (BM25) from semantic (FAISS) matching.
queries = [
    "Internal server error code?",
    "A limbless animal that slithers and is widespread.",  # Should match snake (semantic only)
    "Which language is preferred for low-level, high-speed applications?",  # Should match C++ (semantic only)
]
for query in queries:
    print(f"Query: {query}")
    # `invoke` replaces the deprecated `get_relevant_documents` (deprecated
    # since LangChain 0.1.46 and removed in later releases).
    print(
        "BM25:",
        [d.page_content for d in bm25_retriever.invoke(query)],
    )
    print(
        "FAISS:",
        [d.page_content for d in faiss_retriever.invoke(query)],
    )
    print(
        "Ensemble:",
        [d.page_content for d in ensemble_retriever.invoke(query)],
    )
    print("-" * 40)
Explore in a Dashboard¶
# Launch the TruLens dashboard to explore recorded traces and feedback results.
from trulens.dashboard import run_dashboard

run_dashboard(session)  # open a local streamlit app to explore
# stop_dashboard(session)  # stop if needed
Alternatively, you can run the `trulens` command from the CLI in the same folder to start the dashboard.