Add custom bind_tools
Browse files
agents.py
CHANGED
@@ -138,9 +138,6 @@ def arvix_search(query: str) -> str:
|
|
138 |
with open("system_prompt.txt", "r", encoding="utf-8") as f:
|
139 |
system_prompt = f.read()
|
140 |
|
141 |
-
# System message
|
142 |
-
sys_msg = SystemMessage(content=system_prompt)
|
143 |
-
|
144 |
|
145 |
tools = [
|
146 |
multiply,
|
@@ -176,10 +173,47 @@ def build_graph(provider: str = "rwkv"):
|
|
176 |
title = "rwkv7-g1-0.1b-20250307-ctx4096"
|
177 |
pth = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
|
178 |
model_path = pth.replace(".pth", "")
|
179 |
-
|
180 |
model=model_path,
|
181 |
strategy="cpu fp32",
|
182 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
# --- END RWKV SETUP ---
|
184 |
else:
|
185 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|
|
|
138 |
with open("system_prompt.txt", "r", encoding="utf-8") as f:
|
139 |
system_prompt = f.read()
|
140 |
|
|
|
|
|
|
|
141 |
|
142 |
tools = [
|
143 |
multiply,
|
|
|
173 |
title = "rwkv7-g1-0.1b-20250307-ctx4096"
|
174 |
pth = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
|
175 |
model_path = pth.replace(".pth", "")
|
176 |
+
raw_llm = RWKV(
|
177 |
model=model_path,
|
178 |
strategy="cpu fp32",
|
179 |
)
|
180 |
+
|
181 |
+
class RWKVWithTools:
|
182 |
+
def __init__(self, llm, system_prompt: str):
|
183 |
+
self.llm = llm
|
184 |
+
self.system_prompt = system_prompt
|
185 |
+
self.tools = []
|
186 |
+
|
187 |
+
def bind_tools(self, tools):
|
188 |
+
self.tools = tools
|
189 |
+
return self
|
190 |
+
|
191 |
+
def invoke(self, messages):
|
192 |
+
# Build a tools spec block
|
193 |
+
specs = []
|
194 |
+
for t in self.tools:
|
195 |
+
specs.append(f"- {t.name}({getattr(t, 'args_schema', {})}): {t.description}")
|
196 |
+
|
197 |
+
header = (
|
198 |
+
f"{self.system_prompt}\n\n"
|
199 |
+
"TOOLS AVAILABLE:\n"
|
200 |
+
+ "\n".join(specs)
|
201 |
+
+ "\n\n"
|
202 |
+
"To call a tool, respond exactly with:\n"
|
203 |
+
"`<tool_name>(arg1=…,arg2=…)` and nothing else.\n\n"
|
204 |
+
)
|
205 |
+
|
206 |
+
# Reconstruct conversation
|
207 |
+
convo = "\n".join(
|
208 |
+
f"{'User:' if isinstance(m, HumanMessage) else 'Assistant:'} {m.content}"
|
209 |
+
for m in messages
|
210 |
+
)
|
211 |
+
|
212 |
+
prompt = header + convo
|
213 |
+
# delegate to LangChain’s invoke()
|
214 |
+
return self.llm.invoke(prompt)
|
215 |
+
|
216 |
+
llm = RWKVWithTools(raw_llm, system_prompt=system_prompt)
|
217 |
# --- END RWKV SETUP ---
|
218 |
else:
|
219 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|