ama2aifusion commited on
Commit
af7366a
·
verified ·
1 Parent(s): 1e16da9

Update agent_v2.py

Browse files
Files changed (1) hide show
  1. agent_v2.py +26 -0
agent_v2.py CHANGED
@@ -13,6 +13,8 @@ import pandas as pd
13
  import uuid
14
  import numpy as np
15
  from code_interpreter import CodeInterpreter
 
 
16
 
17
  interpreter_instance = CodeInterpreter()
18
 
@@ -747,6 +749,17 @@ def build_graph(provider: str = "huggingface"):
747
  """Assistant node"""
748
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
749
 
 
 
 
 
 
 
 
 
 
 
 
750
  """
751
  def retriever(state: MessagesState):
752
  "Retriever node"
@@ -762,6 +775,7 @@ def build_graph(provider: str = "huggingface"):
762
  return {"messages": [sys_msg] + state["messages"]}
763
  """
764
 
 
765
  builder = StateGraph(MessagesState)
766
  #builder.add_node("retriever", retriever)
767
  builder.add_node("assistant", assistant)
@@ -774,6 +788,18 @@ def build_graph(provider: str = "huggingface"):
774
  tools_condition,
775
  )
776
  builder.add_edge("tools", "assistant")
 
 
 
 
 
 
 
 
 
 
 
 
777
 
778
  # Compile graph
779
  return builder.compile()
 
13
  import uuid
14
  import numpy as np
15
  from code_interpreter import CodeInterpreter
16
+ import re
17
+ from langchain_core.messages import AIMessage
18
 
19
  interpreter_instance = CodeInterpreter()
20
 
 
749
  """Assistant node"""
750
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
751
 
752
+ # Define extractor node
753
+ def extractor(state: MessagesState):
754
+ last_message = state["messages"][-1].content
755
+ match = re.search(r"FINAL ANSWER:\s*(.*)", last_message, re.IGNORECASE)
756
+ if match:
757
+ answer = match.group(1).strip().split("\n")[0].rstrip('.')
758
+ return {"messages": [AIMessage(content=f"FINAL ANSWER: {answer}")]}
759
+ else:
760
+ return {"messages": [AIMessage(content="FINAL ANSWER: No valid answer found")]}
761
+
762
+
763
  """
764
  def retriever(state: MessagesState):
765
  "Retriever node"
 
775
  return {"messages": [sys_msg] + state["messages"]}
776
  """
777
 
778
+ """
779
  builder = StateGraph(MessagesState)
780
  #builder.add_node("retriever", retriever)
781
  builder.add_node("assistant", assistant)
 
788
  tools_condition,
789
  )
790
  builder.add_edge("tools", "assistant")
791
+ """
792
+
793
+ builder = StateGraph(MessagesState)
794
+ builder.add_node("assistant", assistant)
795
+ builder.add_node("tools", ToolNode(tools))
796
+ builder.add_node("extractor", extractor)
797
+ builder.add_edge(START, "assistant")
798
+ builder.add_conditional_edges("assistant", tools_condition)
799
+ builder.add_edge("tools", "assistant")
800
+ builder.add_edge("assistant", "extractor")
801
+ builder.set_finish_point("extractor")
802
+
803
 
804
  # Compile graph
805
  return builder.compile()