budivoy commited on
Commit
322605d
·
verified ·
1 Parent(s): 913bb08

Add RWKV support for Agent

Browse files
Files changed (1) hide show
  1. agents.py +35 -1
agents.py CHANGED
@@ -13,6 +13,15 @@ from langchain_community.document_loaders import ArxivLoader
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
 
 
 
 
 
 
 
 
 
16
 
17
  load_dotenv()
18
 
@@ -144,7 +153,7 @@ tools = [
144
  ]
145
 
146
  # Build graph function
147
- def build_graph(provider: str = "groq"):
148
  """Build the graph"""
149
  # Load environment variables from .env file
150
  if provider == "google":
@@ -161,6 +170,31 @@ def build_graph(provider: str = "groq"):
161
  temperature=0,
162
  ),
163
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  else:
165
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
166
  # Bind tools to LLM
 
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ # RWKV setup flags — must come before importing RWKV
19
+ os.environ["RWKV_JIT_ON"] = "1"
20
+ os.environ["RWKV_V7_ON"] = "1" # enable RWKV-7
21
+ os.environ["RWKV_CUDA_ON"] = "0" # set to "1"
22
+
23
+ from rwkv.model import RWKV
24
+ from rwkv.utils import PIPELINE
25
 
26
  load_dotenv()
27
 
 
153
  ]
154
 
155
  # Build graph function
156
+ def build_graph(provider: str = "rwkv"):
157
  """Build the graph"""
158
  # Load environment variables from .env file
159
  if provider == "google":
 
170
  temperature=0,
171
  ),
172
  )
173
+ elif provider == "rwkv":
174
+ # --- BEGIN RWKV SETUP ---
175
+ title = "rwkv7-g1"
176
+ pth = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
177
+ # 2) Load RWKV (drop .pth extension for RWKV loader)
178
+ rwkv_model = RWKV(model=pth.replace(".pth", ""), strategy="cpu fp32")
179
+ # 3) Build the tokenization + generation pipeline
180
+ rwkv_pipe = PIPELINE(rwkv_model, "rwkv_vocab_v20230424")
181
+ # 4) Wrap into a Chat-style interface
182
+ class ChatRWKV:
183
+ def __init__(self, pipe):
184
+ self.pipe = pipe
185
+ def invoke(self, messages):
186
+ prompt = "\n".join(m.content for m in messages)
187
+ return self.pipe(
188
+ prompt,
189
+ temperature=0.0,
190
+ top_p=0.95,
191
+ max_tokens=256,
192
+ )
193
+ def bind_tools(self, tools):
194
+ return self
195
+
196
+ llm = ChatRWKV(rwkv_pipe)
197
+ # --- END RWKV SETUP ---
198
  else:
199
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
200
  # Bind tools to LLM