Multi-agent network with Snowflake tools for querying unstructured and structured data
A single agent can usually operate effectively using a handful of tools within a single domain, but even with powerful models like gpt-4, it can be less effective when it has to choose among many tools.
This notebook extends the multi-agent-collaboration notebook, showing how access to more tools — particularly tools over private data — can enhance the abilities of a data agent.
We will build up the agent's toolset step by step: first web search, then a Cortex Agent that can both search documents and query Snowflake tables with SQL via Cortex Analyst.
Prerequisites: you must create your own Cortex Search Service and Cortex Analyst semantic model to be used in the Cortex Agents REST API call.
In [ ]:
Copied!
%%capture --no-stderr
# pip install -U langchain_community langchain_openai langchain_experimental matplotlib langgraph pygraphviz google-search-results
%%capture --no-stderr
# pip install -U langchain_community langchain_openai langchain_experimental matplotlib langgraph pygraphviz google-search-results
Choose an app name
In [ ]:
Copied!
# Display name of this application (assigned once; the duplicate assignment
# was redundant).
APP_NAME = "Finance Data Agent"  # set this app name for your use case
Set the resources for the Cortex Agent
In [ ]:
Copied!
# Resources for the Cortex Agent. Replace these with your own objects
# (each constant is assigned once; the original cell repeated all three).

# Cortex Analyst semantic model YAML, referenced by its Snowflake stage path.
SEMANTIC_MODEL_FILE = "@agents_db.notebooks.semantic_models/sec_filings.yaml"
# Fully qualified name (database.schema.service) of the Cortex Search Service.
CORTEX_SEARCH_SERVICE = "CORTEX_SEARCH_TUTORIAL_DB.PUBLIC.FOMC_SEARCH_SERVICE"
# Placeholder; presumably the account URL used for the Cortex Agents REST API
# call — fill in before running.
ACCOUNT_URL = "..."
Set keys and credentials
In [ ]:
Copied!
import os

# Credentials and connection settings, exported as environment variables so
# the libraries below pick them up. Replace every "..." placeholder.
# (Each variable is set once; the original cell was duplicated verbatim.)

os.environ["OPENAI_API_KEY"] = "sk-proj-..."  # llm used by langgraph
os.environ["SERPAPI_API_KEY"] = "..."  # web search

# ai observability (TruLens -> Snowflake)
os.environ["SNOWFLAKE_ACCOUNT"] = "..."
os.environ["SNOWFLAKE_USER"] = "..."
os.environ["SNOWFLAKE_USER_PASSWORD"] = "..."
os.environ["SNOWFLAKE_DATABASE"] = "AGENTS_DB"
os.environ["SNOWFLAKE_SCHEMA"] = "NOTEBOOKS"
os.environ["SNOWFLAKE_ROLE"] = "CORTEX_USER_ROLE"
os.environ["SNOWFLAKE_WAREHOUSE"] = "CONTAINER_RUNTIME_WH"

os.environ["SNOWFLAKE_PAT"] = "..."  # cortex agent call

os.environ["TRULENS_OTEL_TRACING"] = (
    "1"  # to enable OTEL tracing -> note the Snowsight UI experience for now is limited to PuPr customers, not yet supported for OSS.
)
Import libraries
In [ ]:
Copied!
import ast
import datetime
import json
import os
import time
from typing import List, Literal
import uuid
from langchain.load.dump import dumps
from langchain.prompts import PromptTemplate
from langchain_community.utilities import SerpAPIWrapper
from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolMessage
from langchain_core.tools import StructuredTool
from langchain_core.tools import Tool
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command
from pydantic import BaseModel
from snowflake.snowpark import Session
from trulens.apps.app import TruApp
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.core.otel.instrument import instrument
from trulens.core.run import Run
from trulens.core.run import RunConfig
from trulens.otel.semconv.trace import BASE_SCOPE
from trulens.otel.semconv.trace import SpanAttributes
import ast
import datetime
import json
import os
import time
from typing import List, Literal
import uuid
from langchain.load.dump import dumps
from langchain.prompts import PromptTemplate
from langchain_community.utilities import SerpAPIWrapper
from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolMessage
from langchain_core.tools import StructuredTool
from langchain_core.tools import Tool
from langchain_core.tools import tool
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command
from pydantic import BaseModel
from snowflake.snowpark import Session
from trulens.apps.app import TruApp
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.core.otel.instrument import instrument
from trulens.core.run import Run
from trulens.core.run import RunConfig
from trulens.otel.semconv.trace import BASE_SCOPE
from trulens.otel.semconv.trace import SpanAttributes
Create the TruLens/Snowflake connection
In [ ]:
Copied!
# Snowflake account for trulens
snowflake_connection_parameters = {
"account": os.environ["SNOWFLAKE_ACCOUNT"],
"user": os.environ["SNOWFLAKE_USER"],
"password": os.environ["SNOWFLAKE_USER_PASSWORD"],
"database": os.environ["SNOWFLAKE_DATABASE"],
"schema": os.environ["SNOWFLAKE_SCHEMA"],
"role": os.environ["SNOWFLAKE_ROLE"],
"warehouse": os.environ["SNOWFLAKE_WAREHOUSE"],
}
snowpark_session_trulens = Session.builder.configs(
snowflake_connection_parameters
).create()
trulens_sf_connector = SnowflakeConnector(
snowpark_session=snowpark_session_trulens
)
# Snowflake account for trulens
snowflake_connection_parameters = {
"account": os.environ["SNOWFLAKE_ACCOUNT"],
"user": os.environ["SNOWFLAKE_USER"],
"password": os.environ["SNOWFLAKE_USER_PASSWORD"],
"database": os.environ["SNOWFLAKE_DATABASE"],
"schema": os.environ["SNOWFLAKE_SCHEMA"],
"role": os.environ["SNOWFLAKE_ROLE"],
"warehouse": os.environ["SNOWFLAKE_WAREHOUSE"],
}
snowpark_session_trulens = Session.builder.configs(
snowflake_connection_parameters
).create()
trulens_sf_connector = SnowflakeConnector(
snowpark_session=snowpark_session_trulens
)
Define the agent with web search and charting tools
In [ ]:
Copied!
# augment message state to also track selected tools and the chart path
# (defined once; the original cell declared the identical class twice)
class ToolState(MessagesState):
    """LangGraph state: the inherited ``messages`` list plus routing data."""

    # IDs (keys of the tool registry built in build_graph) chosen by the
    # tool-selection node for the research agent to bind.
    selected_tools: List[str]
    # Filesystem path of the chart PNG produced by the chart generator.
    chart_path: str
In [ ]:
Copied!
def build_graph():
    """Build and compile the multi-agent research/charting LangGraph workflow.

    Nodes (each wrapped with TruLens ``@instrument`` so inputs and outputs are
    recorded as span attributes):

    * ``select_tools``     - embeds the latest message and picks the single
      most similar tool from ``tool_registry``; ends the run when the best
      score is below ``MIN_SCORE``.
    * ``research_agent``   - ReAct agent bound to the selected tool(s); ends
      the run when its answer contains ``FINAL ANSWER``.
    * ``chart_generator``  - ReAct agent that runs matplotlib code in a local
      Python REPL and saves the current figure as a PNG.
    * ``chart_summarizer`` - sends the saved PNG to the LLM for a description.
    * ``reflection``       - asks the LLM whether the summary answers the
      original question; ends the run or loops back to ``select_tools``.

    Only ``START -> select_tools`` is a static edge; every other transition is
    chosen by the node itself via ``Command(goto=...)``.

    Returns:
        The compiled graph over ``ToolState``.
    """

    def make_system_prompt(suffix: str) -> str:
        """Return the shared collaboration preamble followed by ``suffix``."""
        return (
            "You are a helpful AI assistant, collaborating with other assistants."
            " Use the provided tools to progress towards answering the question."
            " If you are unable to fully answer, that's OK, another assistant with different tools "
            " will help where you left off. Execute what you can to make progress."
            " If you or any of the other assistants have the final answer or deliverable,"
            " prefix your response with FINAL ANSWER so the team knows to stop."
            f"\n{suffix}"
        )

    # ------------------------------------------------------------------ tools
    search = SerpAPIWrapper()
    search_tool = Tool(
        name="web_search",
        description="Search the web for current information, such as weather or news",
        func=search.run,
    )

    # Each tool gets a random ID that doubles as its vector-store document ID,
    # so a similarity hit maps straight back to the tool object.
    tool_registry = {
        str(uuid.uuid4()): search_tool,
    }

    tool_documents = [
        Document(
            page_content=f"{tool_obj.name}\n\n{tool_obj.description}",
            id=tool_id,
            metadata={
                "tool_name": tool_obj.name,
                "tool_description": tool_obj.description,
            },
        )
        for tool_id, tool_obj in tool_registry.items()
    ]
    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    vector_store.add_documents(tool_documents)

    llm = ChatOpenAI(model="gpt-4o")

    # -------------------------------------------- instrumentation helpers
    # Used inside TruLens attribute lambdas. They tolerate ``ret`` not being a
    # Command (e.g. None when the wrapped node raised).

    def _tool_name(tool_id: str) -> str:
        """Name of registered tool ``tool_id``, or a marker if it is unknown."""
        tool_obj = tool_registry.get(tool_id)
        if tool_obj is None:
            return f"<unknown tool {tool_id}>"
        return tool_obj.name

    def _selected_tool_ids(ret) -> list:
        """``selected_tools`` from a Command's update dict, or ``[]``."""
        update = getattr(ret, "update", None) or {}
        return list(update.get("selected_tools", []))

    def _new_messages(state, ret) -> list:
        """Messages a node appended on top of its input ``state``.

        Nodes here return ``Command(update={"messages": <input> + <new>})``, so
        slicing off the input length isolates this invocation's output.
        """
        update = getattr(ret, "update", None)
        if not isinstance(update, dict):
            return []
        return list(update.get("messages", []))[len(state.get("messages", [])):]

    def _planned_tool_calls(state, ret) -> list:
        """Raw OpenAI-format tool calls issued by the AI messages a node added."""
        return [
            call
            for message in _new_messages(state, ret)
            if isinstance(message, AIMessage)
            for call in message.additional_kwargs.get("tool_calls", [])
        ]

    def _web_search_contents(state, ret):
        """Content of each ``web_search`` ToolMessage a node added.

        Returns the string ``"No tool call"`` when ``ret`` is not a Command.
        """
        if not hasattr(ret, "update"):
            return "No tool call"
        return [
            json.loads(dumps(message)).get("kwargs", {}).get("content", "")
            for message in _new_messages(state, ret)
            if isinstance(message, ToolMessage) and message.name == "web_search"
        ]

    # ---------------------------------------------------------- select_tools
    @instrument(
        span_type="SELECT_TOOLS",
        attributes=lambda ret, exc, *args, **kw: {
            # state as JSON text (OTLP attribute values must be scalars)
            f"{BASE_SCOPE}.select_tools_input_state": json.dumps(
                {
                    **{k: v for k, v in args[0].items() if k != "messages"},
                    "messages": [
                        {"type": m.__class__.__name__, "content": m.content}
                        if hasattr(m, "content")  # BaseMessage subclasses
                        else m  # already JSON-friendly
                        for m in args[0].get("messages", [])
                    ],
                }
            ),
            # selected tool IDs / names as simple comma-separated strings
            f"{BASE_SCOPE}.selected_tool_ids": ", ".join(_selected_tool_ids(ret)),
            f"{BASE_SCOPE}.selected_tool_names": ", ".join(
                _tool_name(tool_id) for tool_id in _selected_tool_ids(ret)
            ),
        },
    )
    def select_tools(
        state: ToolState,
    ) -> Command[Literal["research_agent", END]]:
        """Pick the registered tool most similar to the latest message.

        Routes to ``research_agent`` with ``selected_tools=[<tool id>]``, or
        appends an apology message and ends the run when even the best match
        scores below ``MIN_SCORE``.
        """
        query = state["messages"][-1].content

        # 1) Score every registered tool against the query.
        results = vector_store.similarity_search_with_score(
            query,
            k=len(tool_documents),
        )

        # 2) Pick the single best match (scores are treated as similarities:
        #    higher means closer).
        best_doc, best_score = max(results, key=lambda x: x[1])

        # 3) If even the best match is too weak, bail out.
        MIN_SCORE = 0.5
        if best_score < MIN_SCORE:
            return Command(
                update={
                    "messages": state["messages"]
                    + [
                        HumanMessage(
                            content="Sorry, I don’t have a tool that’s relevant enough to answer that.",
                            name="assistant",
                        )
                    ]
                },
                goto=END,
            )

        # 4) Otherwise select that one.
        return Command(
            update={"selected_tools": [best_doc.id]},
            goto="research_agent",
        )

    def get_next_node(last_message: BaseMessage, goto: str):
        """Return END if ``last_message`` contains FINAL ANSWER, else ``goto``."""
        if "FINAL ANSWER" in last_message.content:
            # Any agent decided the work is done
            return END
        return goto

    # -------------------------------------------------------- research_agent
    @instrument(
        span_type="RESEARCH_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.research_node_input_content": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.research_node_selected_tool_names": (
                ", ".join(_tool_name(tool_id) for tool_id in args[0]["selected_tools"])
                if args[0].get("selected_tools")
                else "No tools selected"
            ),
            # tool calls planned by the AI messages this invocation produced
            f"{BASE_SCOPE}.planned_tool_call_names": [
                call.get("function", {}).get("name", "")
                for call in _planned_tool_calls(args[0], ret)
            ],
            f"{BASE_SCOPE}.planned_tool_call_args": [
                call.get("function", {}).get("arguments", "")
                for call in _planned_tool_calls(args[0], ret)
            ],
            f"{BASE_SCOPE}.agent_response": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update")
                else json.dumps(ret, indent=4, sort_keys=True)
            ),
            f"{BASE_SCOPE}.web_search_results": _web_search_contents(args[0], ret),
        },
    )
    @instrument(
        span_type=SpanAttributes.SpanType.RETRIEVAL,
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][-1].content,
            SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: _web_search_contents(
                args[0], ret
            ),
        },
    )
    def research_agent_node(
        state: ToolState,
    ) -> Command[Literal["chart_generator", END]]:
        """
        Always binds the selected tools and invokes the bound agent.
        Stops on FINAL ANSWER or moves to chart_generator.
        """
        # grab (non-empty) list of selected tool IDs
        selected_ids = state["selected_tools"]

        # bind only those tools
        selected_tools = [tool_registry[tid] for tid in selected_ids]
        bound_llm = llm.bind_tools(selected_tools)
        bound_agent = create_react_agent(
            bound_llm,
            tools=selected_tools,  # already bound
            prompt=make_system_prompt(
                "You can only do research. You are working with both a chart generator and a chart summarizer colleagues."
            ),
        )

        # run it
        result = bound_agent.invoke(state)

        # decide if we’re done
        last = result["messages"][-1]
        goto = get_next_node(last, "chart_generator")

        # tag the origin of the final message
        result["messages"][-1] = HumanMessage(
            content=last.content,
            name="research_agent",
        )

        return Command(
            update={"messages": result["messages"]},
            goto=goto,
        )

    # -------------------------------------------------------- chart tooling
    # Warning: This executes code locally, which can be unsafe when not sandboxed
    repl = PythonREPL()

    @tool
    @instrument(
        span_type="PYTHON_REPL_TOOL",
        attributes={
            f"{BASE_SCOPE}.python_tool_input_code": "code",
        },
    )
    def python_repl_tool(code: str):
        """
        Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
        save it to ./langgraph_saved_images_snowflaketools/v1/chart_<uuid>.png,
        and return a first-line `CHART_PATH=…` (`CHART_PATH=NONE` when nothing
        was plotted), followed by whatever the code printed, including any
        error message, so failing code can be fixed.
        """
        import matplotlib

        matplotlib.use("Agg")  # headless safety
        import os
        import uuid

        import matplotlib.pyplot as plt

        # ------------------ run user code & capture stdout ------------------
        repl_output = repl.run(code)

        # ------------------ locate a figure (if generated) ------------------
        # plt.gcf() creates an empty figure when none exists, so always close
        # it afterwards to avoid accumulating figures across calls.
        fig = plt.gcf()
        chart_path = ""
        try:
            if fig.axes:  # something was plotted
                target_dir = "./langgraph_saved_images_snowflaketools/v1"
                os.makedirs(target_dir, exist_ok=True)
                chart_path = os.path.join(
                    target_dir, f"chart_{uuid.uuid4().hex}.png"
                )
                fig.savefig(chart_path, format="png")
        finally:
            plt.close(fig)

        # ------------------ tool result (1st line = CHART_PATH) -------------
        header = f"CHART_PATH={chart_path if chart_path else 'NONE'}\n"
        return header + (repl_output or "")

    # Chart generator agent and node
    # NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
    # 1) Define the chart‐agent: it only returns JSON with a "code" field
    chart_agent = create_react_agent(
        llm,
        [python_repl_tool],
        prompt=make_system_prompt(
            """You can only generate charts by returning a single JSON object, for example:
{
"code": "<your python plotting code here>"
}
—where <your python plotting code> uses matplotlib to create exactly one figure.
The plot should always include axis titles and relevant labels at the minimum.
Do NOT include any prose or tool‐call wrappers."""
        ),
    )

    def extract_chart_path(text: str) -> str | None:
        """
        Scan every line of tool stdout for 'CHART_PATH=' and return
        whatever follows, trimmed. Returns None if no such line exists.
        """
        for line in text.splitlines():
            if "CHART_PATH=" in line:
                # split on the first marker, strip whitespace
                return line.split("CHART_PATH=", 1)[1].strip()
        return None

    @instrument(
        span_type="CHART_GENERATOR_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
            f"{BASE_SCOPE}.chart_node_response": (
                ret.update["messages"][-1].content
                if ret and hasattr(ret, "update") and ret.update
                else "No update response"
            ),
        },
    )
    def chart_node(
        state: ToolState,
    ) -> Command[Literal["chart_summarizer", "research_agent"]]:
        """Run the chart agent once and record the saved chart's path.

        Routes to ``chart_summarizer`` with ``chart_path`` set (``""`` when the
        tool reported no figure), or back to ``research_agent`` when the
        agent never called the plotting tool.
        """
        current_query = state["messages"][-1].content

        # NOTE(review): ``last_query`` is not declared on ToolState and is never
        # written through a Command update, so this short-circuit only fires
        # if a caller injects it. Kept for compatibility.
        if state.get("last_query") == current_query and state.get("chart_path"):
            return Command(
                update={"messages": state["messages"]}, goto="chart_summarizer"
            )

        # 1) Remember how many messages we had
        len_before = len(state["messages"])

        # 2) Run the agent exactly once (without mutating the node's input)
        agent_out = chart_agent.invoke({"messages": state["messages"]})
        new_segment = agent_out["messages"][len_before:]

        # 3) Look at only the brand-new messages for our chart tool output
        tool_msgs = [
            m
            for m in new_segment
            if isinstance(m, ToolMessage) and "CHART_PATH=" in m.content
        ]
        if not tool_msgs:
            return Command(
                update={"messages": state["messages"]},
                goto="research_agent",
            )

        # 4) Parse the last one in case there are multiples. The tool reports
        #    "NONE" when nothing was plotted; store "" so the summarizer takes
        #    its "no valid chart" path instead of describing a missing file.
        chart_path = extract_chart_path(tool_msgs[-1].content) or ""
        if chart_path == "NONE":
            chart_path = ""

        # 5) Keep the whole new transcript: a ToolMessage must follow the
        #    AIMessage whose tool_calls it answers, or OpenAI rejects the
        #    history when a later node replays it.
        new_msgs = list(state["messages"]) + list(new_segment)

        # 6) Append the CHART_PATH marker and stash the path into state
        new_msgs.append(
            HumanMessage(
                content=f"CHART_PATH={chart_path or 'NONE'}",
                name="chart_generator",
            )
        )
        return Command(
            update={"messages": new_msgs, "chart_path": chart_path},
            goto="chart_summarizer",
        )

    # ------------------------------------------------------ chart_summarizer
    @instrument(
        span_type="CHART_SUMMARY_NODE",
        attributes=lambda ret, exception, *args, **kwargs: {
            # grab the state dict (kwarg wins, else first arg)
            f"{BASE_SCOPE}.summary_node_input": (
                (kwargs.get("state") or args[0])["messages"][-1].content
            ),
            f"{BASE_SCOPE}.summary_node_output": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update")
                else "NO SUMMARY GENERATED"
            ),
        },
    )
    def chart_summary_node(
        state: ToolState,
    ) -> Command[Literal["reflection", "select_tools"]]:
        """Ask the LLM to describe the saved chart, then go to reflection.

        Without a chart path, appends a retry message and loops back to
        ``select_tools``. The PNG is sent to the model as a base64 image;
        if the file cannot be read, only the text history is sent.
        """
        import base64

        # 1) find the chart_path in state
        chart_path = state.get("chart_path", "")
        if not chart_path:
            return Command(
                update={
                    "messages": state["messages"]
                    + [
                        HumanMessage(
                            "No valid chart was generated. Please try again.",
                            name="chart_summarizer",
                        )
                    ]
                },
                goto="select_tools",
            )

        # 2) strip *everything* except human utterances
        human_history = [
            m for m in state["messages"] if isinstance(m, HumanMessage)
        ]
        # ensure our CHART_PATH marker is last
        if not human_history or not human_history[-1].content.startswith(
            "CHART_PATH="
        ):
            human_history.append(
                HumanMessage(f"CHART_PATH={chart_path}", name="chart_summarizer")
            )

        # 3) build the prompt; attach the image itself, since the model cannot
        #    open a local file path on its own
        system = SystemMessage(
            content=make_system_prompt(
                "You are an AI assistant whose *only* job is to describe a chart image. "
                "Input is a message CHART_PATH=… pointing at a saved PNG. "
                "Include complete details and specifics of the chart image."
            )
        )
        describe_request = "Please describe the above chart."
        try:
            with open(chart_path, "rb") as image_file:
                image_b64 = base64.b64encode(image_file.read()).decode("ascii")
        except OSError:
            image_b64 = ""
        if image_b64:
            describe_msg = HumanMessage(
                content=[
                    {"type": "text", "text": describe_request},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ]
            )
        else:
            describe_msg = HumanMessage(describe_request)

        messages_for_llm = [system] + human_history + [describe_msg]

        # 4) call the LLM directly—no tools, no React agent
        #    (.invoke replaces the deprecated BaseChatModel.__call__)
        ai_msg: AIMessage = llm.invoke(messages_for_llm)
        summary = ai_msg.content

        return Command(
            update={
                "messages": state["messages"]
                + [
                    HumanMessage(summary, name="chart_summarizer"),
                ]
            },
            goto="reflection",
        )

    # ------------------------------------------------------------ reflection
    @instrument(
        span_type="CHART_SUMMARY_REFLECTION",
        attributes=lambda ret, exception, *args, **kwargs: {
            f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": (
                (kwargs.get("state") or args[0])["messages"][0].content
            ),
            f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": (
                (kwargs.get("state") or args[0])["messages"][-1].content
            ),
            # extract the summary string rather than returning the Command object
            f"{BASE_SCOPE}.chart_summary_reflection_response": (
                ret.update["messages"][-1].content
                if hasattr(ret, "update")
                else ""
            ),
        },
    )
    def reflection_node(state: ToolState) -> Command[Literal["select_tools", END]]:
        """
        This function uses an LLM to reflect on the quality of a chart summary
        and determine if the task is complete or requires further refinement.

        Ends the run (appending the approved summary) when the LLM replies
        'Task complete' (case-insensitive); otherwise loops to select_tools.
        """
        reflection_prompt_template = PromptTemplate(
            input_variables=["user_query", "chart_summary"],
            template="""\
You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
"{user_query}"
You are given the following chart summary:
"{chart_summary}"
Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
- Does it describe a chart that will be relevant for answering the user's query?
If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with specific details on how the answer should be improved and what information is needed. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.
Please provide your answer in a **concise and encouraging** manner.
""",
        )

        # Create the chain using the prompt template and the LLM (ChatOpenAI)
        reflection_chain = reflection_prompt_template | llm

        user_query = state["messages"][0].content
        chart_summary = state["messages"][-1].content

        # Call the chain with the user query and chart summary
        reflection_result = reflection_chain.invoke(
            {
                "user_query": user_query,
                "chart_summary": chart_summary,
            }
        )
        reflection_msg = HumanMessage(reflection_result.content, name="reflection")

        # Case-insensitive so "Task Complete." from the model still counts.
        if "task complete" in reflection_result.content.lower():
            return Command(
                update={
                    "messages": state["messages"]
                    + [reflection_msg]
                    + [
                        # NOTE(review): OpenAI rejects message names containing
                        # spaces; this message is terminal here, but it would
                        # fail if the history were ever replayed to the model.
                        HumanMessage(
                            f"Chart saved at {state['chart_path']}. \n The chart summary is: \n {chart_summary}",
                            name="approved chart summary",
                        )
                    ]
                },
                goto=END,
            )
        return Command(
            update={"messages": state["messages"] + [reflection_msg]},
            goto="select_tools",
        )

    # ----------------------------------------------------------------- graph
    workflow = StateGraph(ToolState)
    workflow.add_node("select_tools", select_tools)
    workflow.add_node("research_agent", research_agent_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("chart_summarizer", chart_summary_node)
    workflow.add_node("reflection", reflection_node)

    # Only the entry edge is static. LangGraph follows static edges IN ADDITION
    # to a node's Command(goto=...), so static edges between these nodes would
    # override early exits (FINAL ANSWER, low tool score) and fork retries.
    workflow.add_edge(START, "select_tools")

    compiled_graph = workflow.compile()
    return compiled_graph
def build_graph():
def make_system_prompt(suffix: str) -> str:
return (
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK, another assistant with different tools "
" will help where you left off. Execute what you can to make progress."
" If you or any of the other assistants have the final answer or deliverable,"
" prefix your response with FINAL ANSWER so the team knows to stop."
f"\n{suffix}"
)
search = SerpAPIWrapper()
search_tool = Tool(
name="web_search",
description="Search the web for current information, such as weather or news",
func=search.run,
)
tool_registry = {
str(uuid.uuid4()): search_tool,
}
# Update your tool documents indexing accordingly
tool_documents = [
Document(
page_content=f"{tool.name}\n\n{tool.description}",
id=tool_id,
metadata={
"tool_name": tool.name,
"tool_description": tool.description,
},
)
for tool_id, tool in tool_registry.items()
]
vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
vector_store.add_documents(tool_documents)
llm = ChatOpenAI(model="gpt-4o")
@instrument(
span_type="SELECT_TOOLS",
attributes=lambda ret, exc, *args, **kw: {
# ---- state as JSON-text (OTLP needs a scalar) -----------------
f"{BASE_SCOPE}.select_tools_input_state": json.dumps( # ← turns dict → str
{
**{k: v for k, v in args[0].items() if k != "messages"},
"messages": [
{"type": m.__class__.__name__, "content": m.content}
if hasattr(m, "content") # BaseMessage subclasses
else m # already JSON-friendly
for m in args[0].get("messages", [])
],
}
),
# ---- selected tool IDs as a simple comma-separated string -----
f"{BASE_SCOPE}.selected_tool_ids": ", ".join(
ret.update.get("selected_tools", [])
)
if "selected_tools" in ret.update
else "",
f"{BASE_SCOPE}.selected_tool_names": ", ".join(
tool_registry[tool_id].name
for tool_id in ret.update.get("selected_tools", [])
)
if "selected_tools" in ret.update
else "",
},
)
def select_tools(
state: ToolState,
) -> Command[Literal["research_agent", END]]:
query = state["messages"][-1].content
# 1) Do a full similarity search over all tools
results = vector_store.similarity_search_with_score(
query,
k=len(tool_documents),
)
# 3) Pick the single best match
best_doc, best_score = max(results, key=lambda x: x[1])
# 4) If it’s truly too low, bail out (optional)
MIN_SCORE = 0.5
if best_score < MIN_SCORE:
return Command(
update={
"messages": state["messages"]
+ [
HumanMessage(
content="Sorry, I don’t have a tool that’s relevant enough to answer that.",
name="assistant",
)
]
},
goto=END,
)
# 5) Otherwise select that one
return Command(
update={"selected_tools": [best_doc.id]},
goto="research_agent",
)
def get_next_node(last_message: BaseMessage, goto: str):
if "FINAL ANSWER" in last_message.content:
# Any agent decided the work is done
return END
return goto
@instrument(
span_type="RESEARCH_NODE",
attributes=lambda ret, exception, *args, **kwargs: {
f"{BASE_SCOPE}.research_node_input_content": args[0]["messages"][
-1
].content,
f"{BASE_SCOPE}.research_node_selected_tool_names": (
", ".join(
tool_registry.get(tool_id, "").name
for tool_id in args[0].get("selected_tools", [])
)
if "selected_tools" in args[0]
and len(args[0]["selected_tools"]) > 0
else "No tools selected"
),
f"{BASE_SCOPE}.planned_tool_call_names": (
[
call.get("function", {}).get("name", "")
for call in (
# if ret is a tuple, msg[1] is the AIMessage
ret[0]
.update["messages"][1]
.additional_kwargs.get("tool_calls", [])
if isinstance(ret, tuple)
# otherwise it's ret.update["messages"][1]
else ret.update["messages"][1].additional_kwargs.get(
"tool_calls", []
)
)
]
if (isinstance(ret, tuple) or hasattr(ret, "update"))
else []
),
f"{BASE_SCOPE}.planned_tool_call_args": (
[
call.get("function", {}).get("arguments", "")
for call in (
ret[0]
.update["messages"][1]
.additional_kwargs.get("tool_calls", [])
if isinstance(ret, tuple)
else ret.update["messages"][1].additional_kwargs.get(
"tool_calls", []
)
)
]
if (isinstance(ret, tuple) or hasattr(ret, "update"))
else []
),
f"{BASE_SCOPE}.agent_response": ret.update["messages"][-1].content
if hasattr(ret, "update")
else json.dumps(ret, indent=4, sort_keys=True),
f"{BASE_SCOPE}.web_search_results": [
json.loads(dumps(message)).get("kwargs", {}).get("content", "")
for message in ret.update["messages"]
if isinstance(message, ToolMessage)
and message.name == "web_search"
]
if hasattr(ret, "update")
else "No tool call",
},
)
@instrument(
span_type=SpanAttributes.SpanType.RETRIEVAL,
attributes=lambda ret, exception, *args, **kwargs: {
SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
-1
].content,
SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS: [
json.loads(dumps(message)).get("kwargs", {}).get("content", "")
for message in ret.update["messages"]
if isinstance(message, ToolMessage)
and message.name == "web_search"
]
if hasattr(ret, "update")
else "No tool call",
},
)
def research_agent_node(
state: ToolState,
) -> Command[Literal["chart_generator"]]:
"""
Always binds the selected tools and invokes the bound agent.
Stops on FINAL ANSWER or moves to chart_generator.
"""
# grab (non-empty) list of selected tool IDs
selected_ids = state["selected_tools"]
# bind only those tools
selected_tools = [tool_registry[tid] for tid in selected_ids]
bound_llm = llm.bind_tools(selected_tools)
bound_agent = create_react_agent(
bound_llm,
tools=selected_tools, # already bound
prompt=make_system_prompt(
"You can only do research. You are working with both a chart generator and a chart summarizer colleagues."
),
)
# run it
result = bound_agent.invoke(state)
# decide if we’re done
last = result["messages"][-1]
goto = get_next_node(last, "chart_generator")
# tag the origin of the final message
result["messages"][-1] = HumanMessage(
content=last.content,
name="research_agent",
)
return Command(
update={"messages": result["messages"]},
goto=goto,
)
# Warning: This executes code locally, which can be unsafe when not sandboxed
repl = PythonREPL()
@tool
@instrument(
span_type="PYTHON_REPL_TOOL",
attributes={
f"{BASE_SCOPE}.python_tool_input_code": "code",
},
)
def python_repl_tool(code: str):
"""
Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
save it to ./langgraph_saved_images_snowflaketools/v1/chart_.png,
and return a first-line `CHART_PATH=…`.
"""
import matplotlib
matplotlib.use("Agg") # headless safety
import os
import uuid
import matplotlib.pyplot as plt
# ------------------ run user code & capture stdout ------------------
repl.run(code)
# ------------------ locate a figure (if generated) ------------------
fig = plt.gcf()
has_axes = bool(fig.axes) # True if something was plotted
# ------------------ always save if we have a figure -----------------
chart_path = ""
if has_axes:
target_dir = "./langgraph_saved_images_snowflaketools/v1"
os.makedirs(target_dir, exist_ok=True)
chart_path = os.path.join(
target_dir, f"chart_{uuid.uuid4().hex}.png"
)
fig.savefig(chart_path, format="png")
plt.close(fig)
# ------------------ tool result (1st line = CHART_PATH) -------------
return f"CHART_PATH={chart_path if chart_path else 'NONE'}\n"
# Chart generator agent and node
# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
# 1) Define the chart‐agent: it only returns JSON with a "code" field
chart_agent = create_react_agent(
llm,
[python_repl_tool],
prompt=make_system_prompt(
"""You can only generate charts by returning a single JSON object, for example:
{
"code": ""
}
—where uses matplotlib to create exactly one figure.
The plot should always include axis titles and relevant labels at the minimum.
Do NOT include any prose or tool‐call wrappers."""
),
)
def extract_chart_path(text: str) -> str | None:
"""
Scan every line of tool stdout for 'CHART_PATH=' and return
whatever follows, trimmed. Returns None if no such line exists.
"""
for line in text.splitlines():
if "CHART_PATH=" in line:
# split on the first '=', strip whitespace
return line.split("CHART_PATH=", 1)[1].strip()
return None
@instrument(
span_type="CHART_GENERATOR_NODE",
attributes=lambda ret, exception, *args, **kwargs: {
f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
f"{BASE_SCOPE}.chart_node_response": (
ret.update["messages"][-1].content
if ret and hasattr(ret, "update") and ret.update
else "No update response"
),
},
)
def chart_node(state: ToolState) -> Command[Literal["chart_summarizer"]]:
# 0) If a path is already in state, skip
# extract the current human query
current_query = state["messages"][-1].content
# if we already generated a chart for _this_ query, skip
if state.get("last_query") == current_query and state.get("chart_path"):
return Command(
update={"messages": state["messages"]}, goto="chart_summarizer"
)
# it's a new query (or first run) → clear any old chart_path and remember this query
state.pop("chart_path", None)
state["last_query"] = current_query
# 1) Remember how many messages we had
len_before = len(state["messages"])
# 2) Run the agent exactly once
agent_out = chart_agent.invoke(state)
all_msgs = agent_out["messages"]
# 3) Look at only the brand-new messages for our chart tool output
new_segment = all_msgs[len_before:]
tool_msgs = [
m
for m in new_segment
if isinstance(m, ToolMessage) and "CHART_PATH=" in m.content
]
if not tool_msgs:
return Command(
update={"messages": state["messages"]},
goto="research_agent",
)
# 4) Parse the last one in case there are multiples
tool_msg = tool_msgs[-1]
tool_stdout = tool_msg.content
chart_path = extract_chart_path(tool_stdout)
# 5) Build your new messages list: include only that new ToolMessage
new_msgs = state["messages"][:] + [tool_msg]
# 6) Success! stash path into state and append the CHART_PATH marker
new_msgs.append(
HumanMessage(
content=f"CHART_PATH={chart_path}", name="chart_generator"
)
)
return Command(
update={"messages": new_msgs, "chart_path": chart_path},
goto="chart_summarizer",
)
@instrument(
    span_type="CHART_SUMMARY_NODE",
    attributes=lambda ret, exception, *args, **kwargs: {
        # grab the state dict (kwarg wins, else first arg)
        f"{BASE_SCOPE}.summary_node_input": (
            (kwargs.get("state") or args[0])["messages"][-1].content
        ),
        f"{BASE_SCOPE}.summary_node_output": (
            ret.update["messages"][-1].content
            if hasattr(ret, "update")
            else "NO SUMMARY GENERATED"
        ),
    },
)
def chart_summary_node(state: ToolState) -> Command:
    """Ask the LLM to describe the chart saved at ``state["chart_path"]``.

    Without a chart path, a "No valid chart was generated" message is appended
    and control goes back to ``select_tools``. Otherwise the LLM receives:
    - a system prompt;
    - the HumanMessage turns of the conversation, ending with a
      ``CHART_PATH=...`` marker;
    - a request to describe the chart, with the PNG attached as an image
      when the file can be read.

    The description is appended as a ``chart_summarizer`` message and control
    goes to ``reflection``.

    Args:
        state: Graph state; reads ``messages`` and ``chart_path``.

    Returns:
        A ``Command`` routing to ``reflection`` or ``select_tools``.
    """
    import base64

    chart_path = state.get("chart_path", "")
    if not chart_path:
        return Command(
            update={
                "messages": state["messages"]
                + [
                    HumanMessage(
                        "No valid chart was generated. Please try again.",
                        name="chart_summarizer",
                    )
                ]
            },
            goto="select_tools",
        )

    # Keep only HumanMessage turns; AI/tool-call traffic is dropped.
    human_history = [
        m for m in state["messages"] if isinstance(m, HumanMessage)
    ]
    # Ensure the CHART_PATH marker is the last human turn.
    if not human_history or not human_history[-1].content.startswith(
        "CHART_PATH="
    ):
        human_history.append(
            HumanMessage(f"CHART_PATH={chart_path}", name="chart_summarizer")
        )

    system = SystemMessage(
        content=make_system_prompt(
            "You are an AI assistant whose *only* job is to describe a chart image. "
            "Input is a message CHART_PATH=… pointing at a saved PNG. "
            "Include complete details and specifics of the chart image."
        )
    )

    # A bare file path gives the model nothing to look at, so the description
    # would be guessed from the surrounding text. Attach the PNG itself as an
    # image content block when it can be read; otherwise fall back to the
    # text-only request.
    request_text = "Please describe the above chart."
    try:
        with open(chart_path, "rb") as chart_file:
            encoded_png = base64.b64encode(chart_file.read()).decode("ascii")
    except OSError:
        describe_request = HumanMessage(request_text)
    else:
        describe_request = HumanMessage(
            content=[
                {"type": "text", "text": request_text},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encoded_png}"
                    },
                },
            ]
        )

    messages_for_llm = [system] + human_history + [describe_request]
    # .invoke() is the supported chat-model entry point; calling the model
    # object directly (__call__) is deprecated in LangChain.
    ai_msg: AIMessage = llm.invoke(messages_for_llm)
    summary = ai_msg.content
    return Command(
        update={
            "messages": state["messages"]
            + [
                HumanMessage(summary, name="chart_summarizer"),
            ]
        },
        goto="reflection",
    )
@instrument(
    span_type="CHART_SUMMARY_REFLECTION",
    attributes=lambda ret, exception, *args, **kwargs: {
        f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": (
            (kwargs.get("state") or args[0])["messages"][0].content
        ),
        f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": (
            (kwargs.get("state") or args[0])["messages"][-1].content
        ),
        # extract the summary string rather than returning the Command object
        f"{BASE_SCOPE}.chart_summary_reflection_response": (
            ret.update["messages"][-1].content
            if hasattr(ret, "update")
            else ""
        ),
    },
)
def reflection_node(state: ToolState) -> Command:
    """Judge whether the chart summary answers the user's question.

    The first message is treated as the user's question and the last message
    as the chart summary. When the LLM's verdict contains "task complete"
    (any letter case), the verdict and an "approved chart summary" message are
    appended and the graph ends. Otherwise the verdict is appended as feedback
    and control returns to ``select_tools`` for another attempt.

    Args:
        state: Graph state; reads ``messages`` and, optionally, ``chart_path``.

    Returns:
        A ``Command`` routing to ``END`` or ``select_tools``.
    """
    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
"{user_query}"
You are given the following chart summary:
"{chart_summary}"
Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
- Does it describe a chart that will be relevant for answering the user's query?
If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with specific details on how the answer should be improved and what information is needed. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.
Please provide your answer in a **concise and encouraging** manner.
""",
    )
    reflection_chain = reflection_prompt_template | llm

    user_query = state["messages"][0].content
    chart_summary = state["messages"][-1].content
    reflection_result = reflection_chain.invoke({
        "user_query": user_query,
        "chart_summary": chart_summary,
    })
    verdict = reflection_result.content
    reflection_msg = HumanMessage(verdict, name="reflection")

    # The model may answer "Task Complete" or "task complete."; match without
    # regard to case so an approval is not mistaken for feedback.
    if "task complete" in verdict.lower():
        # .get(): reflection can be reached without a chart (e.g. via a
        # static edge after the summarizer bailed out); don't raise KeyError.
        chart_path = state.get("chart_path", "")
        return Command(
            update={
                "messages": state["messages"]
                + [reflection_msg]
                + [
                    HumanMessage(
                        f"Chart saved at {chart_path}. \n The chart summary is: \n {chart_summary}",
                        name="approved chart summary",
                    )
                ]
            },
            goto=END,
        )
    return Command(
        update={"messages": state["messages"] + [reflection_msg]},
        goto="select_tools",
    )
# Assemble the graph: register each node, wire the edges, and compile.
# Nodes also route themselves by returning Command(goto=...).
# NOTE(review): LangGraph documents that Command(goto=...) does not override
# static edges — both are followed. Each add_edge below other than the START
# edge therefore presumably triggers its target in addition to the node's own
# Command choice. For example, chart_generator would run even after
# research_agent routes to END, and chart_summarizer would run even when
# chart_generator routes back to research_agent. Consider keeping only the
# START edge once every node defined earlier in build_graph is confirmed to
# return a Command with goto.
workflow = StateGraph(ToolState)
workflow.add_node("select_tools", select_tools)
workflow.add_node("research_agent", research_agent_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("chart_summarizer", chart_summary_node)
workflow.add_node("reflection", reflection_node)
workflow.add_edge(START, "select_tools")
workflow.add_edge("select_tools", "research_agent")
workflow.add_edge("research_agent", "chart_generator")
workflow.add_edge("chart_generator", "chart_summarizer")
workflow.add_edge("chart_summarizer", "reflection")
workflow.add_edge("reflection", END)
compiled_graph = workflow.compile()
return compiled_graph
Register the agent and create a run
In [ ]:
Copied!
class TruAgent:
    """Wrap the LangGraph workflow so TruLens can trace each query as a record."""

    def __init__(self):
        self.graph = build_graph()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """Run one query through a freshly built graph and return the final text.

        Args:
            query: The user's question; it becomes the graph's first message.

        Returns:
            The content of the last message emitted in any node update, ``""``
            if no messages were emitted, or a fixed apology string if running
            the graph raised.
        """
        try:
            # rebuild the graph for each query
            self.graph = build_graph()
            state = {"messages": [HumanMessage(content=query)]}
            # Stream node updates; the recursion limit bounds reflection loops.
            events = self.graph.stream(
                state,
                {"recursion_limit": 100},
            )
            all_messages = []
            for event in events:
                # Each event maps one node name to that node's update.
                _, payload = next(iter(event.items()))
                if not payload:  # Skip empty payloads
                    continue
                messages = payload.get("messages")
                if not messages:
                    continue
                all_messages.extend(messages)
            return (
                all_messages[-1].content
                if all_messages and hasattr(all_messages[-1], "content")
                else ""
            )
        except Exception:
            # Keep the user-facing fallback, but record the traceback:
            # swallowing it silently made failures (API errors, recursion
            # limits, bad state) impossible to diagnose.
            import logging

            logging.getLogger(__name__).exception(
                "Agent graph failed for query: %r", query
            )
            return "I ran into an issue, and cannot answer your question."


tru_agent = TruAgent()
class TruAgent:
    """Wrap the LangGraph workflow so TruLens can trace each query as a record."""

    def __init__(self):
        self.graph = build_graph()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """Run one query through a freshly built graph and return the final text.

        Args:
            query: The user's question; it becomes the graph's first message.

        Returns:
            The content of the last message emitted in any node update, ``""``
            if no messages were emitted, or a fixed apology string if running
            the graph raised.
        """
        try:
            # rebuild the graph for each query
            self.graph = build_graph()
            state = {"messages": [HumanMessage(content=query)]}
            # Stream node updates; the recursion limit bounds reflection loops.
            events = self.graph.stream(
                state,
                {"recursion_limit": 100},
            )
            all_messages = []
            for event in events:
                # Each event maps one node name to that node's update.
                _, payload = next(iter(event.items()))
                if not payload:  # Skip empty payloads
                    continue
                messages = payload.get("messages")
                if not messages:
                    continue
                all_messages.extend(messages)
            return (
                all_messages[-1].content
                if all_messages and hasattr(all_messages[-1], "content")
                else ""
            )
        except Exception:
            # Keep the user-facing fallback, but record the traceback:
            # swallowing it silently made failures (API errors, recursion
            # limits, bad state) impossible to diagnose.
            import logging

            logging.getLogger(__name__).exception(
                "Agent graph failed for query: %r", query
            )
            return "I ran into an issue, and cannot answer your question."


tru_agent = TruAgent()
In [ ]:
Copied!
# Register the agent with TruLens; `invoke_agent_graph` is the traced entry
# point and records are written through the Snowflake connector.
tru_agent_app = TruApp(
    tru_agent,
    main_method=tru_agent.invoke_agent_graph,
    connector=trulens_sf_connector,
    app_name=APP_NAME,
    app_version="web search",
)
# Register the agent with TruLens; `invoke_agent_graph` is the traced entry
# point and records are written through the Snowflake connector.
tru_agent_app = TruApp(
    tru_agent,
    main_method=tru_agent.invoke_agent_graph,
    connector=trulens_sf_connector,
    app_name=APP_NAME,
    app_version="web search",
)
In [ ]:
Copied!
# Timestamp the run name (local time) so repeated executions create distinct runs.
st_1 = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
run_config = RunConfig(
    run_name=f"Multi-agent demo run {st_1}",
    description="this is a run with access to web search and charting capabilities",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    # The DataFrame column "query" feeds the record root's input.
    dataset_spec={"RECORD_ROOT.INPUT": "query"},
)
run: Run = tru_agent_app.add_run(run_config)
# Timestamp the run name (local time) so repeated executions create distinct runs.
st_1 = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
run_config = RunConfig(
    run_name=f"Multi-agent demo run {st_1}",
    description="this is a run with access to web search and charting capabilities",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    # The DataFrame column "query" feeds the record root's input.
    dataset_spec={"RECORD_ROOT.INPUT": "query"},
)
run: Run = tru_agent_app.add_run(run_config)
Start the run
This runs the agent in batch over the queries passed as `input_df` (here, `user_queries_df`).
In [ ]:
Copied!
import pandas as pd

# Evaluation prompts; each row becomes one traced invocation of the agent.
user_queries = [
    "Compare the summer high and low temperatures in London versus Atlanta. Create a simple bar chart showing the highs and lows side by side between the two cities.",
    "Compare the 12-month PCE inflation rate to the interest rate paid on federal reserve balances. Create a simple bar chart showing both rates side by side with the unit as basis points rather than percentage points.",
    "what were the top 10 funds in terms of holding values in their most recent 13F-HR filings? Create a bar chart to illustrate, and use billions as the unit.",
]
user_queries_df = pd.DataFrame({"query": user_queries})
import pandas as pd

# Evaluation prompts; each row becomes one traced invocation of the agent.
user_queries = [
    "Compare the summer high and low temperatures in London versus Atlanta. Create a simple bar chart showing the highs and lows side by side between the two cities.",
    "Compare the 12-month PCE inflation rate to the interest rate paid on federal reserve balances. Create a simple bar chart showing both rates side by side with the unit as basis points rather than percentage points.",
    "what were the top 10 funds in terms of holding values in their most recent 13F-HR filings? Create a bar chart to illustrate, and use billions as the unit.",
]
user_queries_df = pd.DataFrame({"query": user_queries})
In [ ]:
Copied!
# Invoke the traced app once per row of the query DataFrame.
run.start(input_df=user_queries_df)
# Invoke the traced app once per row of the query DataFrame.
run.start(input_df=user_queries_df)
Compute metrics
In [ ]:
Copied!
import time

# Wait until TruLens has finished invoking the app on every input row before
# computing metrics. The wait is bounded so a run that never leaves the
# in-progress state fails loudly instead of polling forever.
POLL_INTERVAL_S = 3
MAX_WAIT_S = 30 * 60
deadline = time.monotonic() + MAX_WAIT_S
while run.get_status() == "INVOCATION_IN_PROGRESS":
    if time.monotonic() >= deadline:
        raise TimeoutError(
            f"Run still INVOCATION_IN_PROGRESS after {MAX_WAIT_S} s; "
            "not computing metrics."
        )
    time.sleep(POLL_INTERVAL_S)
run.compute_metrics(["context_relevance", "answer_relevance", "groundedness"])
import time

# Wait until TruLens has finished invoking the app on every input row before
# computing metrics. The wait is bounded so a run that never leaves the
# in-progress state fails loudly instead of polling forever.
POLL_INTERVAL_S = 3
MAX_WAIT_S = 30 * 60
deadline = time.monotonic() + MAX_WAIT_S
while run.get_status() == "INVOCATION_IN_PROGRESS":
    if time.monotonic() >= deadline:
        raise TimeoutError(
            f"Run still INVOCATION_IN_PROGRESS after {MAX_WAIT_S} s; "
            "not computing metrics."
        )
    time.sleep(POLL_INTERVAL_S)
run.compute_metrics(["context_relevance", "answer_relevance", "groundedness"])
Use Cortex Agent as a sub-agent to gain access to structured and unstructured data
In [ ]:
Copied!
from typing import Type
import requests
from snowflake.snowpark import Session
# Argument schema for CortexAgentTool: a single natural-language question.
class CortexAgentArgs(BaseModel):
# The question forwarded to the Cortex Agent.
query: str
class CortexAgentTool(StructuredTool):
    """Answer questions through the Snowflake Cortex Agents REST API.

    Each call posts the question to ``{account_url}/api/v2/cortex/agent:run``
    with three agent tools enabled:
    - Cortex Analyst text-to-SQL, using semantic model ``SEMANTIC_MODEL_FILE``;
    - Cortex Search, using service ``CORTEX_SEARCH_SERVICE``;
    - SQL execution.

    It then parses the streamed (SSE) reply and runs any returned SQL through
    the Snowpark ``session``.

    Authentication uses a programmatic access token read from the
    ``SNOWFLAKE_PAT`` environment variable at construction time.

    NOTE(review): ``run(query)`` shadows ``BaseTool.run`` with a different
    signature, so this object is meant to be wrapped (e.g. in a ``Tool`` with
    ``func=tool.run``) rather than invoked through the standard tool API.
    """

    name: str = "CortexAgent"
    description: str = "answers questions using Fed minutes + SEC data"
    # Override of the StructuredTool field; pydantic requires the annotation.
    args_schema: Type[CortexAgentArgs] = CortexAgentArgs
    # Snowpark session used to execute SQL produced by Cortex Analyst.
    session: Session
    # Full Cortex Agents endpoint, derived from the account URL in __init__.
    api_url: str
    # HTTP headers (auth + content type), populated in __init__.
    headers: dict
    # (connect, read) timeouts in seconds for the REST call. With stream=True
    # the read timeout bounds the wait between chunks, not the whole reply.
    request_timeout: tuple[float, float] = (10.0, 300.0)
    # allow extra attributes (optional if you declare all fields above)
    model_config = {"extra": "allow"}

    def __init__(self, session: Session, account_url: str):
        """Build the endpoint URL and auth headers.

        Raises:
            RuntimeError: if ``SNOWFLAKE_PAT`` is not set.
        """
        super().__init__(
            session=session,
            api_url=f"{account_url}/api/v2/cortex/agent:run",
            headers={},  # populated below
        )
        pat = os.getenv("SNOWFLAKE_PAT")
        if not pat:
            raise RuntimeError("Set SNOWFLAKE_PAT")
        self.headers.update({
            "Authorization": f"Bearer {pat}",
            "X-Snowflake-Authorization-Token-Type": "PROGRAMMATIC_ACCESS_TOKEN",
            "Content-Type": "application/json",
        })

    def process_sse_response(self, resp):
        """Parse a Cortex Agents SSE stream into answer text, SQL and citations.

        Each ``data:`` line is decoded as JSON. Its ``delta`` (top-level or
        nested under ``data``) is scanned, whether or not the JSON has an
        ``event`` field:
        - ``text`` items are concatenated;
        - ``tool_results`` JSON items contribute their ``text``, their ``sql``
          (the last one wins) and their ``searchResults`` citations.

        Lines that are not JSON objects are skipped.

        Returns:
            ``(text, sql, citations)``, where ``citations`` is the ``str()`` of
            a list of ``{"source_id": ..., "doc_id": ...}`` dicts.
        """
        text, sql, citations = "", "", []
        for raw_line in resp.iter_lines(decode_unicode=True):
            if not raw_line:
                continue
            raw_line = raw_line.strip()
            # only handle data lines
            if not raw_line.startswith("data:"):
                continue
            payload = raw_line[len("data:") :].strip()
            if payload in ("", "[DONE]"):
                continue
            try:
                evt = json.loads(payload)
            except json.JSONDecodeError:
                continue
            # Only JSON objects can carry a delta; skip arrays/scalars rather
            # than failing on .get().
            if not isinstance(evt, dict):
                continue
            nested = evt.get("data")
            delta = evt.get("delta") or (
                nested.get("delta") if isinstance(nested, dict) else None
            )
            if not isinstance(delta, dict):
                continue
            for item in delta.get("content", []):
                if not isinstance(item, dict):
                    continue
                kind = item.get("type")
                if kind == "text":
                    text += item.get("text", "")
                elif kind == "tool_results":
                    tool_results = item.get("tool_results") or {}
                    for result in tool_results.get("content", []):
                        if result.get("type") != "json":
                            continue
                        j = result.get("json") or {}
                        text += j.get("text", "")
                        # capture SQL if present
                        if "sql" in j:
                            sql = j["sql"]
                        # capture any citations
                        for s in j.get("searchResults", []):
                            citations.append({
                                "source_id": s.get("source_id"),
                                "doc_id": s.get("doc_id"),
                            })
        return text, sql, str(citations)

    def run(self, query: str):
        """Ask the Cortex Agent ``query`` and execute any SQL it generates.

        Returns:
            ``(text, citations, sql, results_str)``:
            - ``text``: the answer text;
            - ``citations``: the stringified citation list;
            - ``sql``: the generated SQL, or ``""`` if none;
            - ``results_str``: the SQL results rendered with
              ``DataFrame.to_string()``, or an error message.

            If the HTTP request fails, ``text`` carries the error and the
            other elements are ``"[]"``, ``""`` and ``""``, so callers that
            unpack or parse the 4-tuple keep working.
        """
        payload = {
            "model": "claude-3-5-sonnet",
            "response_instruction": "You are a helpful AI assistant.",
            "experimental": {},
            "tools": [
                {
                    "tool_spec": {
                        "type": "cortex_analyst_text_to_sql",
                        "name": "Analyst1",
                    }
                },
                {"tool_spec": {"type": "cortex_search", "name": "Search1"}},
                {
                    "tool_spec": {
                        "type": "sql_exec",
                        "name": "sql_execution_tool",
                    }
                },
            ],
            "tool_resources": {
                "Analyst1": {"semantic_model_file": SEMANTIC_MODEL_FILE},
                "Search1": {"name": CORTEX_SEARCH_SERVICE},
            },
            "tool_choice": {"type": "auto"},
            "messages": [
                {"role": "user", "content": [{"type": "text", "text": query}]}
            ],
        }
        try:
            # The context manager releases the streamed connection even if
            # parsing stops early or raises.
            with requests.post(
                self.api_url,
                json=payload,
                headers=self.headers,
                stream=True,
                timeout=self.request_timeout,
            ) as resp:
                # Without this check an HTTP error body (e.g. auth failure)
                # would be parsed as an empty stream and silently yield "".
                if not resp.ok:
                    return (
                        f"Cortex Agent request failed with HTTP "
                        f"{resp.status_code}: {resp.text[:500]}",
                        "[]",
                        "",
                        "",
                    )
                text, sql, citations = self.process_sse_response(resp)
        except requests.RequestException as e:
            return (f"Cortex Agent request failed: {e}", "[]", "", "")

        results_str = ""
        if sql:
            try:
                # Drop only trailing statement terminators; removing every
                # ";" would also corrupt semicolons inside string literals.
                rows = self.session.sql(sql.strip().rstrip(";")).collect()
                results_str = pd.DataFrame(rows).to_string()
            except Exception as e:
                results_str = f"SQL execution error: {e}"
        return text, citations, sql, results_str
# Build the LangGraph workflow that offers the Cortex Agent tool alongside
# web search. The names defined in this function body (the tools,
# tool_registry, tool_documents, vector_store, llm and the node functions)
# are shared by closure with the nodes defined later in the same function.
def build_graph_with_agent():
# Shared system-prompt prefix for every agent; `suffix` adds role-specific
# instructions. The "FINAL ANSWER" convention is what get_next_node checks.
def make_system_prompt(suffix: str) -> str:
return (
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK, another assistant with different tools "
" will help where you left off. Execute what you can to make progress."
" If you or any of the other assistants have the final answer or deliverable,"
" prefix your response with FINAL ANSWER so the team knows to stop."
f"\n{suffix}"
)
# Web search tool backed by SerpAPI (presumably reads SERPAPI_API_KEY, which
# the notebook sets in its key cell — confirm).
search = SerpAPIWrapper()
search_tool = Tool(
name="web_search",
description="Search the web for current information, such as weather or news",
func=search.run,
)
# Instantiate CortexAgentTool; the Snowpark session is a notebook global
# created elsewhere and is used to execute SQL returned by Cortex Analyst.
account_url = ACCOUNT_URL
cortex_agent_tool = CortexAgentTool(
session=snowpark_session_trulens,
account_url=account_url,
)
# Wrap the tool's custom run(query) in a plain Tool so the agent calls it
# with a single string argument; the tool returns a 4-tuple
# (text, citations, sql, sql_results).
wrapped_cortex_agent_tool = Tool(
name=cortex_agent_tool.name,
description=cortex_agent_tool.description,
func=cortex_agent_tool.run,
return_direct=False, # set to True only if you want the agent to stop after using it
)
# Tool ids are random UUIDs, regenerated on every build; select_tools stores
# the chosen ids in state and research_agent_node looks them up here.
tool_registry = {
str(uuid.uuid4()): search_tool,
str(uuid.uuid4()): wrapped_cortex_agent_tool,
}
# One document per tool ("name\n\ndescription"), keyed by the same id, so a
# similarity hit maps straight back to tool_registry.
tool_documents = [
Document(
page_content=f"{tool.name}\n\n{tool.description}",
id=tool_id,
metadata={
"tool_name": tool.name,
"tool_description": tool.description,
},
)
for tool_id, tool in tool_registry.items()
]
# In-memory embedding index over the tool documents, used by select_tools.
vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
vector_store.add_documents(tool_documents)
# Chat model shared by every node in this graph.
llm = ChatOpenAI(model="gpt-4o")
@instrument(
    span_type="SELECT_TOOLS",
    attributes=lambda ret, exc, *args, **kw: {
        # Input state serialized to JSON text (OTLP attributes must be scalars).
        f"{BASE_SCOPE}.select_tools_input_state": json.dumps(
            {
                **{k: v for k, v in args[0].items() if k != "messages"},
                "messages": [
                    {"type": m.__class__.__name__, "content": m.content}
                    if hasattr(m, "content")  # BaseMessage subclasses
                    else m  # already JSON-friendly
                    for m in args[0].get("messages", [])
                ],
            }
        ),
        # Selected tool ids / names as comma-separated strings.
        f"{BASE_SCOPE}.selected_tool_ids": ", ".join(
            ret.update.get("selected_tools", [])
        )
        if "selected_tools" in ret.update
        else "",
        f"{BASE_SCOPE}.selected_tool_names": ", ".join(
            tool_registry[tool_id].name
            for tool_id in ret.update.get("selected_tools", [])
        )
        if "selected_tools" in ret.update
        else "",
    },
)
def select_tools(
    state: ToolState,
) -> Command[Literal["research_agent", END]]:
    """Pick the single tool whose description best matches the latest message.

    Every registered tool is scored against the content of the last message.
    If even the best similarity score is below 0.5, an apology is appended and
    the graph ends. Otherwise the winning tool id is stored under
    ``selected_tools`` and control moves to ``research_agent``.
    """
    latest_text = state["messages"][-1].content
    scored = vector_store.similarity_search_with_score(
        latest_text,
        k=len(tool_documents),
    )
    top_doc, top_score = max(scored, key=lambda pair: pair[1])

    MIN_SCORE = 0.5
    too_weak = top_score < MIN_SCORE
    if not too_weak:
        return Command(
            update={"selected_tools": [top_doc.id]},
            goto="research_agent",
        )

    apology = HumanMessage(
        content="Sorry, I don’t have a tool that’s relevant enough to answer that.",
        name="assistant",
    )
    return Command(
        update={"messages": state["messages"] + [apology]},
        goto=END,
    )
# Warning: This executes code locally, which can be unsafe when not sandboxed
repl = PythonREPL()


@tool
@instrument(
    span_type="PYTHON_REPL_TOOL",
    attributes={
        f"{BASE_SCOPE}.python_tool_input_code": "code",
    },
)
def python_repl_tool(code: str):
    """
    Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
    save it to ./langgraph_saved_images_snowflaketools/v2/chart_<uuid>.png,
    and return a first-line `CHART_PATH=…` (CHART_PATH=NONE if nothing was
    plotted), followed by any output or error text printed by the code.
    """
    import os
    import uuid

    import matplotlib

    matplotlib.use("Agg")  # headless safety
    import matplotlib.pyplot as plt

    # Keep the REPL's stdout / error text so the agent can see failures and
    # retry, instead of only getting CHART_PATH=NONE.
    output = repl.run(code)

    chart_path = ""
    try:
        # Check for open figures before calling gcf(): gcf() would create a
        # new empty figure when none exists.
        if plt.get_fignums():
            fig = plt.gcf()
            if fig.axes:  # something was actually plotted
                target_dir = "./langgraph_saved_images_snowflaketools/v2"
                os.makedirs(target_dir, exist_ok=True)
                chart_path = os.path.join(
                    target_dir, f"chart_{uuid.uuid4().hex}.png"
                )
                fig.savefig(chart_path, format="png")
    finally:
        # Close every figure so extra or empty figures cannot leak into (and
        # be mistaken for the chart of) the next call.
        plt.close("all")

    # The first line must stay CHART_PATH=...; chart_node parses it.
    result = f"CHART_PATH={chart_path if chart_path else 'NONE'}\n"
    if output:
        result += output
    return result
def get_next_node(last_message: BaseMessage, goto: str):
    """Route to END once any agent has declared "FINAL ANSWER".

    Args:
        last_message: Most recent message produced by an agent.
        goto: Node to continue with when the work is not finished.

    Returns:
        ``END`` if the message content contains "FINAL ANSWER", else ``goto``.
    """
    if "FINAL ANSWER" not in last_message.content:
        return goto
    # Any agent decided the work is done
    return END
# Research node: binds only the tool(s) chosen by select_tools to a ReAct
# agent, runs it on the full message history, and routes to END (when the
# reply contains "FINAL ANSWER") or to chart_generator.
# The first @instrument records a custom RESEARCH_NODE span. The second
# records a RETRIEVAL span whose contexts are the web_search outputs and
# element [3] (SQL results) of each CortexAgent tool output.
# NOTE(review): several attribute lambdas below are fragile. Each case needs
# confirmation, since it is not visible whether TruLens tolerates exceptions
# raised while computing attributes:
#   - `tool_registry.get(tool_id, "").name` raises AttributeError for an
#     unknown id;
#   - `ret.update["messages"][1]` is treated as the AIMessage that planned
#     the tool calls, which only holds on the first pass.
@instrument(
span_type="RESEARCH_NODE",
attributes=lambda ret, exception, *args, **kwargs: {
f"{BASE_SCOPE}.research_node_input_content": args[0]["messages"][
-1
].content,
f"{BASE_SCOPE}.research_node_selected_tool_names": (
", ".join(
tool_registry.get(tool_id, "").name
for tool_id in args[0].get("selected_tools", [])
)
if "selected_tools" in args[0]
and len(args[0]["selected_tools"]) > 0
else "No tools selected"
),
# Assumes messages[1] is the agent's first (tool-planning) AIMessage — see NOTE above.
f"{BASE_SCOPE}.planned_tool_call_names": (
[
call.get("function", {}).get("name", "")
for call in (
# if ret is a tuple, msg[1] is the AIMessage
ret[0]
.update["messages"][1]
.additional_kwargs.get("tool_calls", [])
if isinstance(ret, tuple)
# otherwise it's ret.update["messages"][1]
else ret.update["messages"][1].additional_kwargs.get(
"tool_calls", []
)
)
]
if (isinstance(ret, tuple) or hasattr(ret, "update"))
else []
),
f"{BASE_SCOPE}.planned_tool_call_args": (
[
call.get("function", {}).get("arguments", "")
for call in (
ret[0]
.update["messages"][1]
.additional_kwargs.get("tool_calls", [])
if isinstance(ret, tuple)
else ret.update["messages"][1].additional_kwargs.get(
"tool_calls", []
)
)
]
if (isinstance(ret, tuple) or hasattr(ret, "update"))
else []
),
f"{BASE_SCOPE}.agent_response": ret.update["messages"][-1].content
if hasattr(ret, "update")
else json.dumps(ret, indent=4, sort_keys=True),
f"{BASE_SCOPE}.web_search_results": [
json.loads(dumps(message)).get("kwargs", {}).get("content", "")
for message in ret.update["messages"]
if isinstance(message, ToolMessage)
and message.name == "web_search"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_results": [
json.loads(dumps(message)).get("kwargs", {}).get("content", "")
for message in ret.update["messages"]
if isinstance(message, ToolMessage)
and message.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
# The four CortexAgent entries below parse each ToolMessage's content with
# ast.literal_eval and index it as (text, citations, sql, sql_results), the
# order CortexAgentTool.run returns. Assumes the tuple was serialized into a
# Python-literal-compatible sequence string — verify.
f"{BASE_SCOPE}.cortex_agent_text": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[0]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_citations": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[1]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_sql": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[2]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_sql_results": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[3]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
},
)
# Second span: RETRIEVAL. The query text is the node's last input message.
# For CortexAgent results only element [3] (the SQL results) is recorded as a
# context; web_search results are recorded verbatim.
@instrument(
span_type=SpanAttributes.SpanType.RETRIEVAL,
attributes=lambda ret, exception, *args, **kwargs: {
SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
-1
].content,
SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS:
[
(
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[3]
if m.name == "CortexAgent"
else
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)
for m in ret.update["messages"]
if isinstance(m, ToolMessage)
and m.name in ("CortexAgent", "web_search")
]
if hasattr(ret, "update")
else "No tool call",
},
)
def research_agent_node(
state: ToolState,
) -> Command[Literal["chart_generator"]]:
"""
Always binds the selected tools and invokes the bound agent.
Stops on FINAL ANSWER or moves to chart_generator.
Returns a Command whose update replaces `messages` with the agent's full
output, with the last message re-tagged as a HumanMessage named
"research_agent".
"""
# grab (non-empty) list of selected tool IDs
# NOTE(review): raises KeyError if select_tools never stored a selection.
# This can happen if the static edge select_tools -> research_agent fires
# alongside select_tools' Command(goto=END) — confirm the graph wiring.
selected_ids = state["selected_tools"]
# bind only those tools
selected_tools = [tool_registry[tid] for tid in selected_ids]
bound_llm = llm.bind_tools(selected_tools)
# The ReAct agent gets the pre-bound model plus the same tool list it executes.
bound_agent = create_react_agent(
bound_llm,
tools=selected_tools, # already bound
prompt=make_system_prompt(
"You can only do research. You are working with both a chart generator and a chart summarizer colleagues."
),
)
# run it
result = bound_agent.invoke(state)
# decide if we’re done
last = result["messages"][-1]
goto = get_next_node(last, "chart_generator")
# tag the origin of the final message
result["messages"][-1] = HumanMessage(
content=last.content,
name="research_agent",
)
return Command(
update={"messages": result["messages"]},
goto=goto,
)
# Chart generator agent and node
# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
# The chart agent's only tool is python_repl_tool, and chart_node succeeds
# only when that tool actually runs and reports CHART_PATH=... . The prompt
# therefore asks the model to *call* the tool with a `code` argument, rather
# than reply with a JSON object as plain text (which produces no chart).
chart_agent = create_react_agent(
    llm,
    [python_repl_tool],
    prompt=make_system_prompt(
        """You can only generate charts. Do this by calling python_repl_tool with a single `code` argument, for example:
{
"code": "<your python plotting code here>"
}
—where <your python plotting code> uses matplotlib to create exactly one figure.
The plot should always include axis titles and relevant labels at the minimum.
Do not call plt.show() or plt.savefig(); the tool saves the figure itself.
Do NOT reply with the code as plain text and do not add any prose."""
    ),
)
def extract_chart_path(text: str) -> str | None:
"""
Scan every line of tool stdout for 'CHART_PATH=' and return
whatever follows, trimmed. Returns None if no such line exists.
"""
for line in text.splitlines():
if "CHART_PATH=" in line:
# split on the first '=', strip whitespace
return line.split("CHART_PATH=", 1)[1].strip()
return None
@instrument(
    span_type="CHART_GENERATOR_NODE",
    attributes=lambda ret, exception, *args, **kwargs: {
        f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
        f"{BASE_SCOPE}.chart_node_response": (
            ret.update["messages"][-1].content
            if ret and hasattr(ret, "update") and ret.update
            else "No update response"
        ),
    },
)
def chart_node(
    state: ToolState,
) -> Command[Literal["chart_summarizer", "research_agent"]]:
    """Run the chart agent once and hand a saved chart to the summarizer.

    The chart agent is invoked on the current state, and only the messages it
    appends are scanned for a tool result carrying a ``CHART_PATH=`` line.
    When a usable path is found, a ``CHART_PATH=<path>`` marker message is
    appended and control goes to ``chart_summarizer``; otherwise control goes
    back to ``research_agent``.

    Args:
        state: Graph state; reads ``messages`` and, optionally,
            ``last_query`` / ``chart_path``.

    Returns:
        A ``Command`` routing to ``chart_summarizer`` (with ``chart_path`` set)
        or back to ``research_agent`` (messages unchanged).
    """
    # Skip regeneration if a chart was already produced for this exact message.
    current_query = state["messages"][-1].content
    if state.get("last_query") == current_query and state.get("chart_path"):
        return Command(
            update={"messages": state["messages"]}, goto="chart_summarizer"
        )

    # NOTE(review): these in-place edits of `state` are presumably not
    # persisted by LangGraph (only the returned Command update is), so the
    # skip check above may never fire — confirm whether `last_query` should
    # be part of the update instead.
    state.pop("chart_path", None)
    state["last_query"] = current_query

    len_before = len(state["messages"])
    agent_out = chart_agent.invoke(state)
    # Only the messages the chart agent appended can hold this run's output.
    new_segment = agent_out["messages"][len_before:]
    tool_msgs = [
        m
        for m in new_segment
        if isinstance(m, ToolMessage) and "CHART_PATH=" in m.content
    ]

    # The last tool result wins if the agent plotted more than once.
    chart_path = (
        extract_chart_path(tool_msgs[-1].content) if tool_msgs else None
    )

    # python_repl_tool reports "CHART_PATH=NONE" when nothing was plotted;
    # treat that (or an empty path) as "no chart" instead of handing a
    # nonexistent file to the summarizer.
    if not chart_path or chart_path == "NONE":
        return Command(
            update={"messages": state["messages"]},
            goto="research_agent",
        )

    # Forward only a HumanMessage marker. The raw ToolMessage is not appended:
    # without the AIMessage that issued its tool call it would be an orphaned
    # "tool" message, which the OpenAI chat API rejects if this history is
    # sent to the model again (e.g. after a reflection loop).
    new_msgs = state["messages"] + [
        HumanMessage(
            content=f"CHART_PATH={chart_path}", name="chart_generator"
        )
    ]
    return Command(
        update={"messages": new_msgs, "chart_path": chart_path},
        goto="chart_summarizer",
    )
@instrument(
    span_type="CHART_SUMMARY_NODE",
    attributes=lambda ret, exception, *args, **kwargs: {
        # grab the state dict (kwarg wins, else first arg)
        f"{BASE_SCOPE}.summary_node_input": (
            (kwargs.get("state") or args[0])["messages"][-1].content
        ),
        f"{BASE_SCOPE}.summary_node_output": (
            ret.update["messages"][-1].content
            if hasattr(ret, "update")
            else "NO SUMMARY GENERATED"
        ),
    },
)
def chart_summary_node(state: ToolState) -> Command:
    """Ask the LLM to describe the chart saved at ``state["chart_path"]``.

    Without a chart path, a "No valid chart was generated" message is appended
    and control goes back to ``select_tools``. Otherwise the LLM receives:
    - a system prompt;
    - the HumanMessage turns of the conversation, ending with a
      ``CHART_PATH=...`` marker;
    - a request to describe the chart, with the PNG attached as an image
      when the file can be read.

    The description is appended as a ``chart_summarizer`` message and control
    goes to ``reflection``.

    Args:
        state: Graph state; reads ``messages`` and ``chart_path``.

    Returns:
        A ``Command`` routing to ``reflection`` or ``select_tools``.
    """
    import base64

    chart_path = state.get("chart_path", "")
    if not chart_path:
        return Command(
            update={
                "messages": state["messages"]
                + [
                    HumanMessage(
                        "No valid chart was generated. Please try again.",
                        name="chart_summarizer",
                    )
                ]
            },
            goto="select_tools",
        )

    # Keep only HumanMessage turns; AI/tool-call traffic is dropped.
    human_history = [
        m for m in state["messages"] if isinstance(m, HumanMessage)
    ]
    # Ensure the CHART_PATH marker is the last human turn.
    if not human_history or not human_history[-1].content.startswith(
        "CHART_PATH="
    ):
        human_history.append(
            HumanMessage(f"CHART_PATH={chart_path}", name="chart_summarizer")
        )

    system = SystemMessage(
        content=make_system_prompt(
            "You are an AI assistant whose *only* job is to describe a chart image. "
            "Input is a message CHART_PATH=… pointing at a saved PNG. "
            "Include complete details and specifics of the chart image."
        )
    )

    # A bare file path gives the model nothing to look at, so the description
    # would be guessed from the surrounding text. Attach the PNG itself as an
    # image content block when it can be read (llm is gpt-4o, which accepts
    # image input); otherwise fall back to the text-only request.
    request_text = "Please describe the above chart."
    try:
        with open(chart_path, "rb") as chart_file:
            encoded_png = base64.b64encode(chart_file.read()).decode("ascii")
    except OSError:
        describe_request = HumanMessage(request_text)
    else:
        describe_request = HumanMessage(
            content=[
                {"type": "text", "text": request_text},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encoded_png}"
                    },
                },
            ]
        )

    messages_for_llm = [system] + human_history + [describe_request]
    # .invoke() is the supported chat-model entry point; calling the model
    # object directly (__call__) is deprecated in LangChain.
    ai_msg: AIMessage = llm.invoke(messages_for_llm)
    summary = ai_msg.content
    return Command(
        update={
            "messages": state["messages"]
            + [
                HumanMessage(summary, name="chart_summarizer"),
            ]
        },
        goto="reflection",
    )
@instrument(
    span_type="CHART_SUMMARY_REFLECTION",
    attributes=lambda ret, exception, *args, **kwargs: {
        f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": (
            (kwargs.get("state") or args[0])["messages"][0].content
        ),
        f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": (
            (kwargs.get("state") or args[0])["messages"][-1].content
        ),
        # extract the summary string rather than returning the Command object
        f"{BASE_SCOPE}.chart_summary_reflection_response": (
            ret.update["messages"][-1].content
            if hasattr(ret, "update")
            else ""
        ),
    },
)
def reflection_node(state: ToolState) -> Command:
    """Judge whether the chart summary answers the user's question.

    The first message is treated as the user's question and the last message
    as the chart summary. When the LLM's verdict contains "task complete"
    (any letter case), the verdict and an "approved chart summary" message are
    appended and the graph ends. Otherwise the verdict is appended as feedback
    and control returns to ``select_tools`` for another attempt.

    Args:
        state: Graph state; reads ``messages`` and, optionally, ``chart_path``.

    Returns:
        A ``Command`` routing to ``END`` or ``select_tools``.
    """
    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
"{user_query}"
You are given the following chart summary:
"{chart_summary}"
Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
- Does it describe a chart that will be relevant for answering the user's query?
If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with specific details on how the answer should be improved and what information is needed. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.
Please provide your answer in a **concise and encouraging** manner.
""",
    )
    reflection_chain = reflection_prompt_template | llm

    user_query = state["messages"][0].content
    chart_summary = state["messages"][-1].content
    reflection_result = reflection_chain.invoke({
        "user_query": user_query,
        "chart_summary": chart_summary,
    })
    verdict = reflection_result.content
    reflection_msg = HumanMessage(verdict, name="reflection")

    # The model may answer "Task Complete" or "task complete."; match without
    # regard to case so an approval is not mistaken for feedback.
    if "task complete" in verdict.lower():
        # .get(): reflection can be reached without a chart (e.g. via a
        # static edge after the summarizer bailed out); don't raise KeyError.
        chart_path = state.get("chart_path", "")
        return Command(
            update={
                "messages": state["messages"]
                + [reflection_msg]
                + [
                    HumanMessage(
                        f"Chart saved at {chart_path}. \n The chart summary is: \n {chart_summary}",
                        name="approved chart summary",
                    )
                ]
            },
            goto=END,
        )
    return Command(
        update={"messages": state["messages"] + [reflection_msg]},
        goto="select_tools",
    )
workflow = StateGraph(ToolState)
workflow.add_node("select_tools", select_tools)
workflow.add_node("research_agent", research_agent_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("chart_summarizer", chart_summary_node)
workflow.add_node("reflection", reflection_node)
workflow.add_edge(START, "select_tools")
workflow.add_edge("select_tools", "research_agent")
workflow.add_edge("research_agent", "chart_generator")
workflow.add_edge("chart_generator", "chart_summarizer")
workflow.add_edge("chart_summarizer", "reflection")
workflow.add_edge("reflection", END)
compiled_graph = workflow.compile()
return compiled_graph
from typing import Type
import requests
from snowflake.snowpark import Session
# Argument schema for CortexAgentTool: a single natural-language question.
class CortexAgentArgs(BaseModel):
# The question forwarded to the Cortex Agent.
query: str
class CortexAgentTool(StructuredTool):
name: str = "CortexAgent"
description: str = "answers questions using Fed minutes + SEC data"
# ← annotate this override
args_schema: Type[CortexAgentArgs] = CortexAgentArgs
# now declare your extra fields, too:
session: Session
api_url: str
headers: dict
# allow extra attributes (optional if you declare all fields above)
model_config = {"extra": "allow"}
def __init__(self, session: Session, account_url: str):
# pass the declared fields into super()
super().__init__(
session=session,
api_url=f"{account_url}/api/v2/cortex/agent:run",
headers={}, # we'll populate it next
)
pat = os.getenv("SNOWFLAKE_PAT")
if not pat:
raise RuntimeError("Set SNOWFLAKE_PAT")
self.headers.update({
"Authorization": f"Bearer {pat}",
"X-Snowflake-Authorization-Token-Type": "PROGRAMMATIC_ACCESS_TOKEN",
"Content-Type": "application/json",
})
def process_sse_response(self, resp):
"""
Process SSE stream lines, extracting any 'delta' payloads,
regardless of whether the JSON contains an 'event' field.
"""
text, sql, citations = "", "", []
for raw_line in resp.iter_lines(decode_unicode=True):
if not raw_line:
continue
raw_line = raw_line.strip()
# only handle data lines
if not raw_line.startswith("data:"):
continue
payload = raw_line[len("data:") :].strip()
if payload in ("", "[DONE]"):
continue
try:
evt = json.loads(payload)
except json.JSONDecodeError:
continue
# Grab the 'delta' section, whether top-level or nested in 'data'
delta = evt.get("delta") or evt.get("data", {}).get("delta")
if not isinstance(delta, dict):
continue
for item in delta.get("content", []):
t = item.get("type")
if t == "text":
text += item.get("text", "")
elif t == "tool_results":
for result in item["tool_results"].get("content", []):
if result.get("type") == "json":
j = result["json"]
text += j.get("text", "")
# capture SQL if present
if "sql" in j:
sql = j["sql"]
# capture any citations
for s in j.get("searchResults", []):
citations.append({
"source_id": s.get("source_id"),
"doc_id": s.get("doc_id"),
})
return text, sql, str(citations)
def run(self, query: str, *, timeout: float = 300.0):
    """
    Ask the Cortex Agents REST API a question and execute any SQL it returns.

    The request enables Cortex Analyst (text-to-SQL over SEMANTIC_MODEL_FILE),
    Cortex Search (over CORTEX_SEARCH_SERVICE) and SQL execution. The streamed
    answer is parsed by ``process_sse_response``. If SQL was generated, it is
    run through the Snowpark session and rendered with pandas.

    Args:
        query: natural-language question.
        timeout: seconds ``requests`` waits to connect and between streamed
            chunks. Without it, a stalled stream would hang forever.

    Returns:
        ``(text, citations, sql, results_str)``. On a non-2xx HTTP response,
        ``text`` describes the failure and the other fields are empty
        (``"[]"`` for citations). Callers always receive the same 4-tuple
        shape, so downstream parsing of the tuple keeps working.
    """
    payload = {
        "model": "claude-3-5-sonnet",
        "response_instruction": "You are a helpful AI assistant.",
        "experimental": {},
        "tools": [
            {
                "tool_spec": {
                    "type": "cortex_analyst_text_to_sql",
                    "name": "Analyst1",
                }
            },
            {"tool_spec": {"type": "cortex_search", "name": "Search1"}},
            {
                "tool_spec": {
                    "type": "sql_exec",
                    "name": "sql_execution_tool",
                }
            },
        ],
        "tool_resources": {
            "Analyst1": {"semantic_model_file": SEMANTIC_MODEL_FILE},
            "Search1": {"name": CORTEX_SEARCH_SERVICE},
        },
        "tool_choice": {"type": "auto"},
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": query}]}
        ],
    }
    # `with` releases the streamed connection even if parsing raises.
    with requests.post(
        self.api_url,
        json=payload,
        headers=self.headers,
        stream=True,
        timeout=timeout,
    ) as resp:
        if not resp.ok:
            return (
                f"Cortex Agent request failed with HTTP {resp.status_code}: "
                f"{resp.text[:1000]}",
                "[]",
                "",
                "",
            )
        # SSE is UTF-8. Without a charset, requests assumes ISO-8859-1 for text/*.
        resp.encoding = "utf-8"
        text, sql, citations = self.process_sse_response(resp)
    # execute SQL if returned
    results_str = ""
    if sql:
        # Drop only trailing terminators; semicolons inside literals must survive.
        statement = sql.strip().rstrip(";").strip()
        try:
            rows = self.session.sql(statement).collect()
            results_str = pd.DataFrame(rows).to_string()
        except Exception as e:
            results_str = f"SQL execution error: {e}"
    return text, citations, sql, results_str
def build_graph_with_agent():
# Builds the tools, agents and node functions that follow, and returns the
# compiled LangGraph workflow over the nodes select_tools, research_agent,
# chart_generator, chart_summarizer and reflection. Each node returns a
# Command(goto=...) naming its successor.
# Shared system prompt for every agent; `suffix` appends role-specific instructions.
def make_system_prompt(suffix: str) -> str:
return (
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK, another assistant with different tools "
" will help where you left off. Execute what you can to make progress."
" If you or any of the other assistants have the final answer or deliverable,"
" prefix your response with FINAL ANSWER so the team knows to stop."
f"\n{suffix}"
)
# --- Tools available to the research agent -------------------------------
# Web search through SerpAPI.
serp = SerpAPIWrapper()
search_tool = Tool(
    name="web_search",
    func=serp.run,
    description="Search the web for current information, such as weather or news",
)
# Cortex Agent (Cortex Search + Cortex Analyst), exposed as a plain Tool.
cortex_tool = CortexAgentTool(
    account_url=ACCOUNT_URL,
    session=snowpark_session_trulens,
)
wrapped_cortex_agent_tool = Tool(
    name=cortex_tool.name,
    description=cortex_tool.description,
    func=cortex_tool.run,
    return_direct=False,  # True would make the agent stop right after this tool
)

# Random ids key the registry; select_tools maps retrieved documents back
# through these ids.
tool_registry = {
    str(uuid.uuid4()): registered
    for registered in (search_tool, wrapped_cortex_agent_tool)
}


def _tool_document(doc_id, registered):
    """Index one registered tool by its name and description."""
    return Document(
        id=doc_id,
        page_content=f"{registered.name}\n\n{registered.description}",
        metadata={
            "tool_name": registered.name,
            "tool_description": registered.description,
        },
    )


tool_documents = [_tool_document(tid, t) for tid, t in tool_registry.items()]
# Embed the tool descriptions so tools can be chosen by semantic similarity.
vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
vector_store.add_documents(tool_documents)
llm = ChatOpenAI(model="gpt-4o")
def _select_tools_span_attributes(ret, exc, *args, **kw):
    """Build SELECT_TOOLS span attributes; every value is a string (OTLP needs scalars)."""
    incoming = args[0]
    flattened_messages = [
        {"type": type(msg).__name__, "content": msg.content}
        if hasattr(msg, "content")  # BaseMessage subclasses
        else msg  # already JSON-friendly
        for msg in incoming.get("messages", [])
    ]
    other_fields = {key: val for key, val in incoming.items() if key != "messages"}
    input_state_json = json.dumps({**other_fields, "messages": flattened_messages})
    # An absent "selected_tools" key yields an empty id list, i.e. "" below.
    chosen_ids = ret.update.get("selected_tools", [])
    return {
        f"{BASE_SCOPE}.select_tools_input_state": input_state_json,
        f"{BASE_SCOPE}.selected_tool_ids": ", ".join(chosen_ids),
        f"{BASE_SCOPE}.selected_tool_names": ", ".join(
            tool_registry[chosen].name for chosen in chosen_ids
        ),
    }


@instrument(span_type="SELECT_TOOLS", attributes=_select_tools_span_attributes)
def select_tools(
    state: ToolState,
) -> Command[Literal["research_agent", END]]:
    """
    Route the latest message to the single best-matching tool.

    Every registered tool description is scored against the latest message
    text. The top match's id is stored as ``selected_tools`` and control
    passes to research_agent. If even the best score is below 0.5, an
    apology message is appended and the Command routes to END.
    """
    MIN_SCORE = 0.5
    latest_text = state["messages"][-1].content
    # Score every tool, then keep the highest-similarity one.
    scored_docs = vector_store.similarity_search_with_score(
        latest_text,
        k=len(tool_documents),
    )
    top_doc, top_score = max(scored_docs, key=lambda pair: pair[1])
    if top_score < MIN_SCORE:
        apology = HumanMessage(
            content="Sorry, I don’t have a tool that’s relevant enough to answer that.",
            name="assistant",
        )
        return Command(
            update={"messages": state["messages"] + [apology]},
            goto=END,
        )
    return Command(
        update={"selected_tools": [top_doc.id]},
        goto="research_agent",
    )
# Warning: This executes code locally, which can be unsafe when not sandboxed
repl = PythonREPL()


@tool
@instrument(
    span_type="PYTHON_REPL_TOOL",
    attributes={
        f"{BASE_SCOPE}.python_tool_input_code": "code",
    },
)
def python_repl_tool(code: str):
    """
    Run arbitrary Python, grab the CURRENT matplotlib figure (if any),
    save it to ./langgraph_saved_images_snowflaketools/v2/chart_<uuid>.png,
    and return a first-line `CHART_PATH=…` (`CHART_PATH=NONE` when nothing
    was plotted), followed by any printed output or error message from the code.
    """
    # NOTE: this docstring is the tool description the LLM sees (via @tool).
    import os
    import uuid

    import matplotlib

    matplotlib.use("Agg")  # headless safety
    import matplotlib.pyplot as plt

    # ------------------ run user code ------------------------------------
    # The REPL returns captured stdout (or the error text). Returning it lets
    # the agent see and fix failures instead of getting only CHART_PATH=NONE.
    output = repl.run(code)
    # ------------------ save the current figure, if anything was plotted --
    fig = plt.gcf()
    chart_path = ""
    if fig.axes:
        target_dir = "./langgraph_saved_images_snowflaketools/v2"
        os.makedirs(target_dir, exist_ok=True)
        chart_path = os.path.join(target_dir, f"chart_{uuid.uuid4().hex}.png")
        fig.savefig(chart_path, format="png")
    # Close every figure: gcf() creates an empty one when none exists, and any
    # figure left open would be picked up by gcf() on the next call.
    plt.close("all")
    # ------------------ tool result (1st line = CHART_PATH) ---------------
    result = f"CHART_PATH={chart_path if chart_path else 'NONE'}\n"
    if output:
        result += str(output)
    return result
def get_next_node(last_message: BaseMessage, goto: str):
    """
    Pick the successor node for an agent's latest message.

    Returns END as soon as the message contains the "FINAL ANSWER" marker
    (any agent may declare the work done); otherwise returns ``goto``.
    """
    work_is_done = "FINAL ANSWER" in last_message.content
    return END if work_is_done else goto
# Research node: runs a ReAct agent restricted to the tool(s) chosen by
# select_tools, then routes to chart_generator (or END on FINAL ANSWER).
# First decorator: custom RESEARCH_NODE span attributes describing the run.
@instrument(
span_type="RESEARCH_NODE",
attributes=lambda ret, exception, *args, **kwargs: {
f"{BASE_SCOPE}.research_node_input_content": args[0]["messages"][
-1
].content,
# NOTE(review): tool_registry.get(tool_id, "") falls back to "" for an
# unknown id, and "".name then raises AttributeError.
f"{BASE_SCOPE}.research_node_selected_tool_names": (
", ".join(
tool_registry.get(tool_id, "").name
for tool_id in args[0].get("selected_tools", [])
)
if "selected_tools" in args[0]
and len(args[0]["selected_tools"]) > 0
else "No tools selected"
),
# NOTE(review): ret.update["messages"] is the whole conversation, so index 1
# is its second message. On later refinement loops it still points at the
# first research round, not the current one.
f"{BASE_SCOPE}.planned_tool_call_names": (
[
call.get("function", {}).get("name", "")
for call in (
# if ret is a tuple, msg[1] is the AIMessage
ret[0]
.update["messages"][1]
.additional_kwargs.get("tool_calls", [])
if isinstance(ret, tuple)
# otherwise it's ret.update["messages"][1]
else ret.update["messages"][1].additional_kwargs.get(
"tool_calls", []
)
)
]
if (isinstance(ret, tuple) or hasattr(ret, "update"))
else []
),
f"{BASE_SCOPE}.planned_tool_call_args": (
[
call.get("function", {}).get("arguments", "")
for call in (
ret[0]
.update["messages"][1]
.additional_kwargs.get("tool_calls", [])
if isinstance(ret, tuple)
else ret.update["messages"][1].additional_kwargs.get(
"tool_calls", []
)
)
]
if (isinstance(ret, tuple) or hasattr(ret, "update"))
else []
),
f"{BASE_SCOPE}.agent_response": ret.update["messages"][-1].content
if hasattr(ret, "update")
else json.dumps(ret, indent=4, sort_keys=True),
# Raw ToolMessage contents, split by which tool produced them.
f"{BASE_SCOPE}.web_search_results": [
json.loads(dumps(message)).get("kwargs", {}).get("content", "")
for message in ret.update["messages"]
if isinstance(message, ToolMessage)
and message.name == "web_search"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_results": [
json.loads(dumps(message)).get("kwargs", {}).get("content", "")
for message in ret.update["messages"]
if isinstance(message, ToolMessage)
and message.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
# CortexAgent ToolMessage content is presumably the str() of
# CortexAgentTool.run's 4-tuple (text, citations, sql, results_str).
# Indices 0-3 below read those fields back out. ast.literal_eval raises if
# the content is not such a tuple literal (e.g. a tool error string).
f"{BASE_SCOPE}.cortex_agent_text": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[0]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_citations": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[1]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_sql": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[2]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
f"{BASE_SCOPE}.cortex_agent_sql_results": [
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[3]
for m in ret.update["messages"]
if isinstance(m, ToolMessage) and m.name == "CortexAgent"
]
if hasattr(ret, "update")
else "No tool call",
},
)
# Second decorator: standard RETRIEVAL span. The query is the latest message.
# Retrieved contexts are the SQL results (tuple index 3) for CortexAgent calls,
# or the raw content for web_search calls.
@instrument(
span_type=SpanAttributes.SpanType.RETRIEVAL,
attributes=lambda ret, exception, *args, **kwargs: {
SpanAttributes.RETRIEVAL.QUERY_TEXT: args[0]["messages"][
-1
].content,
SpanAttributes.RETRIEVAL.RETRIEVED_CONTEXTS:
[
(
ast.literal_eval(
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)[3]
if m.name == "CortexAgent"
else
json.loads(dumps(m)).get("kwargs", {}).get("content", "")
)
for m in ret.update["messages"]
if isinstance(m, ToolMessage)
and m.name in ("CortexAgent", "web_search")
]
if hasattr(ret, "update")
else "No tool call",
},
)
def research_agent_node(
state: ToolState,
) -> Command[Literal["chart_generator", END]]:
"""
Always binds the selected tools and invokes the bound agent.
Stops on FINAL ANSWER or moves to chart_generator.

The agent's final message is re-tagged as a HumanMessage named
"research_agent" before being returned in the Command update.
"""
# grab (non-empty) list of selected tool IDs
# (raises KeyError if select_tools has not stored "selected_tools")
selected_ids = state["selected_tools"]
# bind only those tools
selected_tools = [tool_registry[tid] for tid in selected_ids]
bound_llm = llm.bind_tools(selected_tools)
bound_agent = create_react_agent(
bound_llm,
tools=selected_tools, # already bound
prompt=make_system_prompt(
"You can only do research. You are working with both a chart generator and a chart summarizer colleagues."
),
)
# run it (the result's "messages" holds the input history plus the agent's new turns)
result = bound_agent.invoke(state)
# decide if we’re done
last = result["messages"][-1]
goto = get_next_node(last, "chart_generator")
# tag the origin of the final message
result["messages"][-1] = HumanMessage(
content=last.content,
name="research_agent",
)
return Command(
update={"messages": result["messages"]},
goto=goto,
)
# Chart generator agent and node
# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
# 1) Define the chart agent. Its only tool is python_repl_tool, whose single
#    argument is `code`. chart_node looks for the CHART_PATH=… ToolMessage that
#    this call produces, so the prompt asks for a tool call, not plain text.
chart_agent = create_react_agent(
    llm,
    [python_repl_tool],
    prompt=make_system_prompt(
        """You can only generate charts by calling python_repl_tool with a single JSON object, for example:
{
"code": "<python code>"
}
—where <python code> uses matplotlib to create exactly one figure.
The plot should always include axis titles and relevant labels at the minimum.
Do NOT include any prose or markdown wrappers around the code."""
    ),
)
def extract_chart_path(text: str) -> str | None:
"""
Scan every line of tool stdout for 'CHART_PATH=' and return
whatever follows, trimmed. Returns None if no such line exists.
"""
for line in text.splitlines():
if "CHART_PATH=" in line:
# split on the first '=', strip whitespace
return line.split("CHART_PATH=", 1)[1].strip()
return None
@instrument(
    span_type="CHART_GENERATOR_NODE",
    attributes=lambda ret, exception, *args, **kwargs: {
        f"{BASE_SCOPE}.chart_node_input": args[0]["messages"][-1].content,
        f"{BASE_SCOPE}.chart_node_response": (
            ret.update["messages"][-1].content
            if ret and hasattr(ret, "update") and ret.update
            else "No update response"
        ),
    },
)
def chart_node(
    state: ToolState,
) -> Command[Literal["chart_summarizer", "research_agent"]]:
    """
    Run the chart agent once and record the chart it saved.

    On success, the chart agent's new messages are appended. This keeps each
    tool-calling AIMessage next to its ToolMessage replies, which the OpenAI
    API requires when the history is sent again. A ``CHART_PATH=…`` marker is
    then added, ``chart_path`` is stored, and control goes to chart_summarizer.

    If no chart tool output is found, or the tool reported no figure
    (``CHART_PATH=NONE``), control goes back to research_agent.
    """
    current_query = state["messages"][-1].content
    # Skip when a chart already exists for this exact message.
    # NOTE(review): no Command update visible here persists "last_query", so
    # this shortcut only fires if ToolState gets it from elsewhere.
    if state.get("last_query") == current_query and state.get("chart_path"):
        return Command(
            update={"messages": state["messages"]}, goto="chart_summarizer"
        )

    history = state["messages"]
    # Run the agent exactly once. The input state is left untouched: in-place
    # mutations of a node's input are not persisted by LangGraph anyway.
    agent_out = chart_agent.invoke({"messages": history})
    # The agent echoes the input history first; keep only its new messages.
    new_segment = agent_out["messages"][len(history):]
    chart_tool_msgs = [
        m
        for m in new_segment
        if isinstance(m, ToolMessage)
        and isinstance(m.content, str)
        and "CHART_PATH=" in m.content
    ]
    # Parse the last one in case there are multiples
    chart_path = (
        extract_chart_path(chart_tool_msgs[-1].content) if chart_tool_msgs else None
    )
    if not chart_path or chart_path == "NONE":
        return Command(
            update={"messages": history},
            goto="research_agent",
        )
    new_msgs = history + new_segment + [
        HumanMessage(content=f"CHART_PATH={chart_path}", name="chart_generator")
    ]
    return Command(
        update={"messages": new_msgs, "chart_path": chart_path},
        goto="chart_summarizer",
    )
@instrument(
    span_type="CHART_SUMMARY_NODE",
    attributes=lambda ret, exception, *args, **kwargs: {
        # grab the state dict (kwarg wins, else first arg)
        f"{BASE_SCOPE}.summary_node_input": (
            (kwargs.get("state") or args[0])["messages"][-1].content
        ),
        f"{BASE_SCOPE}.summary_node_output": (
            ret.update["messages"][-1].content
            if hasattr(ret, "update")
            else "NO SUMMARY GENERATED"
        ),
    },
)
def chart_summary_node(
    state: ToolState,
) -> Command[Literal["reflection", "select_tools"]]:
    """
    Describe the saved chart so the reflection step can judge it.

    If ``chart_path`` is missing or falsy, a retry notice is appended and
    control goes back to select_tools. Otherwise the LLM receives:

    * the human-side message history;
    * a ``CHART_PATH=…`` marker;
    * the PNG itself, embedded as a base64 data URL.

    Its description is appended as a "chart_summarizer" message and control
    goes to reflection. If the PNG cannot be read, the request falls back to
    text only.
    """
    import base64

    # 1) find the chart_path in state
    chart_path = state.get("chart_path", "")
    if not chart_path:
        return Command(
            update={
                "messages": state["messages"]
                + [
                    HumanMessage(
                        "No valid chart was generated. Please try again.",
                        name="chart_summarizer",
                    )
                ]
            },
            goto="select_tools",
        )
    # 2) keep only human utterances (drops AI/tool-call messages)
    human_history = [m for m in state["messages"] if isinstance(m, HumanMessage)]
    last_content = human_history[-1].content if human_history else ""
    # ensure our CHART_PATH marker is last
    if not (isinstance(last_content, str) and last_content.startswith("CHART_PATH=")):
        human_history.append(
            HumanMessage(f"CHART_PATH={chart_path}", name="chart_summarizer")
        )
    # 3) build the prompt
    system = SystemMessage(
        content=make_system_prompt(
            "You are an AI assistant whose *only* job is to describe a chart image. "
            "When the chart image is attached, describe what it shows; a CHART_PATH=… "
            "message gives the location of the saved PNG. "
            "Include complete details and specifics of the chart image."
        )
    )
    # The LLM only sees message content, not local files, so embed the PNG.
    try:
        with open(chart_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("ascii")
    except OSError:
        encoded = ""
    if encoded:
        request = HumanMessage(
            content=[
                {"type": "text", "text": "Please describe the above chart."},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encoded}"},
                },
            ]
        )
    else:
        request = HumanMessage("Please describe the above chart.")
    # 4) call the LLM directly (no tools, no ReAct agent); invoke() replaces
    #    the deprecated llm(...) call form.
    ai_msg: AIMessage = llm.invoke([system] + human_history + [request])
    summary = ai_msg.content
    return Command(
        update={
            "messages": state["messages"]
            + [
                HumanMessage(summary, name="chart_summarizer"),
            ]
        },
        goto="reflection",
    )
@instrument(
    span_type="CHART_SUMMARY_REFLECTION",
    attributes=lambda ret, exception, *args, **kwargs: {
        f"{BASE_SCOPE}.chart_summary_reflection_input_user_query": (
            (kwargs.get("state") or args[0])["messages"][0].content
        ),
        f"{BASE_SCOPE}.chart_summary_reflection_input_chart_summary": (
            (kwargs.get("state") or args[0])["messages"][-1].content
        ),
        # extract the summary string rather than returning the Command object
        f"{BASE_SCOPE}.chart_summary_reflection_response": (
            ret.update["messages"][-1].content
            if hasattr(ret, "update")
            else ""
        ),
    },
)
def reflection_node(state: ToolState) -> Command[Literal["select_tools", END]]:
    """
    This function uses an LLM to reflect on the quality of a chart summary
    and determine if the task is complete or requires further refinement.

    The user query is the first message; the chart summary is the last one.
    If the LLM's verdict contains "Task complete", the verdict and an approved
    summary (including the saved chart path) are appended and the graph ends.
    Otherwise the verdict is appended and control returns to select_tools.
    """
    reflection_prompt_template = PromptTemplate(
        input_variables=["user_query", "chart_summary"],
        template="""\
You are an AI assistant tasked with reflecting on the quality of a chart summary. The user has asked the following question:
"{user_query}"
You are given the following chart summary:
"{chart_summary}"
Your task is to evaluate how well the chart summary answers the user's question. Consider the following:
- Does it describe a chart that will be relevant for answering the user's query?
If the summary **generally** addresses the question, respond with 'Task complete'. If the summary **lacks significant** details or clarity, then respond with specific details on how the answer should be improved and what information is needed. Avoid being overly critical unless the summary completely misses key elements necessary to answer the query.
Please provide your answer in a **concise and encouraging** manner.
""",
    )
    # Create the chain using the prompt template and the LLM (ChatOpenAI)
    reflection_chain = reflection_prompt_template | llm
    user_query = state["messages"][0].content
    chart_summary = state["messages"][-1].content
    reflection_result = reflection_chain.invoke({
        "user_query": user_query,
        "chart_summary": chart_summary,
    })
    verdict = HumanMessage(reflection_result.content, name="reflection")
    if "Task complete" in reflection_result.content:
        approved = HumanMessage(
            # .get: a missing chart_path must not crash the final step.
            f"Chart saved at {state.get('chart_path')}. \n The chart summary is: \n {chart_summary}",
            # OpenAI rejects message names containing whitespace.
            name="approved_chart_summary",
        )
        return Command(
            update={"messages": state["messages"] + [verdict, approved]},
            goto=END,
        )
    return Command(
        update={"messages": state["messages"] + [verdict]},
        goto="select_tools",
    )
workflow = StateGraph(ToolState)
workflow.add_node("select_tools", select_tools)
workflow.add_node("research_agent", research_agent_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("chart_summarizer", chart_summary_node)
workflow.add_node("reflection", reflection_node)
workflow.add_edge(START, "select_tools")
workflow.add_edge("select_tools", "research_agent")
workflow.add_edge("research_agent", "chart_generator")
workflow.add_edge("chart_generator", "chart_summarizer")
workflow.add_edge("chart_summarizer", "reflection")
workflow.add_edge("reflection", END)
compiled_graph = workflow.compile()
return compiled_graph
In [ ]:
Copied!
class TruAgent:
    """
    TruLens-instrumented wrapper around the LangGraph multi-agent workflow.

    ``invoke_agent_graph`` is recorded as the RECORD_ROOT span, with the
    query as input and the returned string as output.
    """

    def __init__(self):
        self.graph = build_graph_with_agent()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """
        Run one query through a freshly built graph and return the final text.

        Returns the content of the last message emitted by any node, or "" if
        no node emitted messages. If the graph raises, the exception is logged
        with its traceback and a fixed apology string is returned.
        """
        import logging

        try:
            # rebuild the graph for each query
            self.graph = build_graph_with_agent()
            state = {"messages": [HumanMessage(content=query)]}
            last_message = None
            for event in self.graph.stream(state, {"recursion_limit": 60}):
                # Each event maps node name -> that node's update. Read every
                # entry: several nodes can report in the same step.
                for payload in event.values():
                    if not isinstance(payload, dict):
                        continue
                    messages = payload.get("messages")
                    if messages:
                        last_message = messages[-1]
            return getattr(last_message, "content", "")
        except Exception:
            logging.getLogger(__name__).exception(
                "Agent graph invocation failed for query %r", query
            )
            return "I ran into an issue, and cannot answer your question."
# Register the agent with TruLens; invoke_agent_graph is the traced entry point.
tru_agent = TruAgent()
tru_agent_app = TruApp(
    tru_agent,
    main_method=tru_agent.invoke_agent_graph,
    app_name=APP_NAME,
    app_version="doc, sql and web search",
    connector=trulens_sf_connector,
)
# Local wall-clock timestamp appended to the run name.
st_1 = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
run_config = RunConfig(
    run_name=f"Multi-agent demo run - document, sql and web search {st_1}",
    description="this is a run with access to cortex agent (search and analyst) and web search",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    # presumably maps the input DataFrame's "query" column to RECORD_ROOT.INPUT
    dataset_spec={"RECORD_ROOT.INPUT": "query"},
)
run: Run = tru_agent_app.add_run(run_config)
class TruAgent:
    """
    TruLens-instrumented wrapper around the LangGraph multi-agent workflow.

    ``invoke_agent_graph`` is recorded as the RECORD_ROOT span, with the
    query as input and the returned string as output.
    """

    def __init__(self):
        self.graph = build_graph_with_agent()

    @instrument(
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        attributes={
            SpanAttributes.RECORD_ROOT.INPUT: "query",
            SpanAttributes.RECORD_ROOT.OUTPUT: "return",
        },
    )
    def invoke_agent_graph(self, query: str) -> str:
        """
        Run one query through a freshly built graph and return the final text.

        Returns the content of the last message emitted by any node, or "" if
        no node emitted messages. If the graph raises, the exception is logged
        with its traceback and a fixed apology string is returned.
        """
        import logging

        try:
            # rebuild the graph for each query
            self.graph = build_graph_with_agent()
            state = {"messages": [HumanMessage(content=query)]}
            last_message = None
            for event in self.graph.stream(state, {"recursion_limit": 60}):
                # Each event maps node name -> that node's update. Read every
                # entry: several nodes can report in the same step.
                for payload in event.values():
                    if not isinstance(payload, dict):
                        continue
                    messages = payload.get("messages")
                    if messages:
                        last_message = messages[-1]
            return getattr(last_message, "content", "")
        except Exception:
            logging.getLogger(__name__).exception(
                "Agent graph invocation failed for query %r", query
            )
            return "I ran into an issue, and cannot answer your question."
# Register the agent with TruLens; invoke_agent_graph is the traced entry point.
tru_agent = TruAgent()
tru_agent_app = TruApp(
    tru_agent,
    main_method=tru_agent.invoke_agent_graph,
    app_name=APP_NAME,
    app_version="doc, sql and web search",
    connector=trulens_sf_connector,
)
# Local wall-clock timestamp appended to the run name.
st_1 = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
run_config = RunConfig(
    run_name=f"Multi-agent demo run - document, sql and web search {st_1}",
    description="this is a run with access to cortex agent (search and analyst) and web search",
    dataset_name="Research test dataset",
    source_type="DATAFRAME",
    label="langgraph demo",
    # presumably maps the input DataFrame's "query" column to RECORD_ROOT.INPUT
    dataset_spec={"RECORD_ROOT.INPUT": "query"},
)
run: Run = tru_agent_app.add_run(run_config)
In [ ]:
Copied!
# Start the run exactly once; presumably TruLens then invokes the app for each
# row of user_queries_df.
run.start(input_df=user_queries_df)
In [ ]:
Copied!
import time

# Wait until the run leaves the INVOCATION_IN_PROGRESS state, then compute the
# evaluation metrics once for this run.
while run.get_status() == "INVOCATION_IN_PROGRESS":
    time.sleep(3)
run.compute_metrics(["context_relevance", "answer_relevance", "groundedness"])