budivoy commited on
Commit
60e2f92
·
verified ·
1 Parent(s): dd3b428

Update agents.py

Browse files
Files changed (1) hide show
  1. agents.py +10 -10
agents.py CHANGED
@@ -6,7 +6,6 @@ from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
- from langchain_community.llms import RWKV
10
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
11
  from langchain_community.tools.tavily_search import TavilySearchResults
12
  from langchain_community.document_loaders import WikipediaLoader
@@ -14,7 +13,10 @@ from langchain_community.document_loaders import ArxivLoader
14
  from langchain_community.vectorstores import SupabaseVectorStore
15
  from langchain_core.messages import SystemMessage, HumanMessage
16
  from langchain_core.tools import tool
 
17
  from huggingface_hub import hf_hub_download
 
 
18
 
19
  load_dotenv()
20
 
@@ -166,15 +168,12 @@ def build_graph(provider: str = "rwkv"):
166
  pth = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
167
  model_path = pth.replace(".pth", "")
168
 
169
- raw_llm = RWKV(
170
- model=model_path,
171
- strategy="cpu fp32",
172
- tokens_path="./20B_tokenizer.json"
173
- )
174
 
175
  class RWKVWithTools:
176
- def __init__(self, llm, system_prompt: str):
177
- self.llm = llm
178
  self.system_prompt = system_prompt
179
  self.tools = []
180
 
@@ -204,8 +203,9 @@ def build_graph(provider: str = "rwkv"):
204
  )
205
 
206
  prompt = header + convo
207
- # delegate to LangChain’s invoke()
208
- return self.llm.invoke(prompt)
 
209
 
210
  llm = RWKVWithTools(raw_llm, system_prompt=system_prompt)
211
  # --- END RWKV SETUP ---
 
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
 
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
 
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
+
17
  from huggingface_hub import hf_hub_download
18
+ from rwkv.model import RWKV
19
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
20
 
21
  load_dotenv()
22
 
 
168
  pth = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
169
  model_path = pth.replace(".pth", "")
170
 
171
+ raw_llm = RWKV(model=model_path, strategy='cuda fp32')
172
+ pipeline = PIPELINE(raw_llm, "rwkv_vocab_v20230424")
 
 
 
173
 
174
  class RWKVWithTools:
175
+ def __init__(self, pipeline, system_prompt: str):
176
+ self.pipeline = pipeline
177
  self.system_prompt = system_prompt
178
  self.tools = []
179
 
 
203
  )
204
 
205
  prompt = header + convo
206
+
207
+ # delegate to RWKV invoke()
208
+ return self.pipeline.generate(prompt, token_count=200)
209
 
210
  llm = RWKVWithTools(raw_llm, system_prompt=system_prompt)
211
  # --- END RWKV SETUP ---