Update agent_v2.py
Browse files- agent_v2.py +26 -0
agent_v2.py
CHANGED
@@ -13,6 +13,8 @@ import pandas as pd
|
|
13 |
import uuid
|
14 |
import numpy as np
|
15 |
from code_interpreter import CodeInterpreter
|
|
|
|
|
16 |
|
17 |
interpreter_instance = CodeInterpreter()
|
18 |
|
@@ -747,6 +749,17 @@ def build_graph(provider: str = "huggingface"):
|
|
747 |
"""Assistant node"""
|
748 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
750 |
"""
|
751 |
def retriever(state: MessagesState):
|
752 |
"Retriever node"
|
@@ -762,6 +775,7 @@ def build_graph(provider: str = "huggingface"):
|
|
762 |
return {"messages": [sys_msg] + state["messages"]}
|
763 |
"""
|
764 |
|
|
|
765 |
builder = StateGraph(MessagesState)
|
766 |
#builder.add_node("retriever", retriever)
|
767 |
builder.add_node("assistant", assistant)
|
@@ -774,6 +788,18 @@ def build_graph(provider: str = "huggingface"):
|
|
774 |
tools_condition,
|
775 |
)
|
776 |
builder.add_edge("tools", "assistant")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
777 |
|
778 |
# Compile graph
|
779 |
return builder.compile()
|
|
|
13 |
import uuid
|
14 |
import numpy as np
|
15 |
from code_interpreter import CodeInterpreter
|
16 |
+
import re
|
17 |
+
from langchain_core.messages import AIMessage
|
18 |
|
19 |
interpreter_instance = CodeInterpreter()
|
20 |
|
|
|
749 |
"""Assistant node"""
|
750 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
751 |
|
752 |
+
# Define extractor node
|
753 |
+
def extractor(state: MessagesState):
|
754 |
+
last_message = state["messages"][-1].content
|
755 |
+
match = re.search(r"FINAL ANSWER:\s*(.*)", last_message, re.IGNORECASE)
|
756 |
+
if match:
|
757 |
+
answer = match.group(1).strip().split("\n")[0].rstrip('.')
|
758 |
+
return {"messages": [AIMessage(content=f"FINAL ANSWER: {answer}")]}
|
759 |
+
else:
|
760 |
+
return {"messages": [AIMessage(content="FINAL ANSWER: No valid answer found")]}
|
761 |
+
|
762 |
+
|
763 |
"""
|
764 |
def retriever(state: MessagesState):
|
765 |
"Retriever node"
|
|
|
775 |
return {"messages": [sys_msg] + state["messages"]}
|
776 |
"""
|
777 |
|
778 |
+
"""
|
779 |
builder = StateGraph(MessagesState)
|
780 |
#builder.add_node("retriever", retriever)
|
781 |
builder.add_node("assistant", assistant)
|
|
|
788 |
tools_condition,
|
789 |
)
|
790 |
builder.add_edge("tools", "assistant")
|
791 |
+
"""
|
792 |
+
|
793 |
+
builder = StateGraph(MessagesState)
|
794 |
+
builder.add_node("assistant", assistant)
|
795 |
+
builder.add_node("tools", ToolNode(tools))
|
796 |
+
builder.add_node("extractor", extractor)
|
797 |
+
builder.add_edge(START, "assistant")
|
798 |
+
builder.add_conditional_edges("assistant", tools_condition)
|
799 |
+
builder.add_edge("tools", "assistant")
|
800 |
+
builder.add_edge("assistant", "extractor")
|
801 |
+
builder.set_finish_point("extractor")
|
802 |
+
|
803 |
|
804 |
# Compile graph
|
805 |
return builder.compile()
|