Add RWKV support for Agent
Browse files
agents.py
CHANGED
@@ -13,6 +13,15 @@ from langchain_community.document_loaders import ArxivLoader
|
|
13 |
from langchain_community.vectorstores import SupabaseVectorStore
|
14 |
from langchain_core.messages import SystemMessage, HumanMessage
|
15 |
from langchain_core.tools import tool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
load_dotenv()
|
18 |
|
@@ -144,7 +153,7 @@ tools = [
|
|
144 |
]
|
145 |
|
146 |
# Build graph function
|
147 |
-
def build_graph(provider: str = "groq"):
|
148 |
"""Build the graph"""
|
149 |
# Load environment variables from .env file
|
150 |
if provider == "google":
|
@@ -161,6 +170,31 @@ def build_graph(provider: str = "groq"):
|
|
161 |
temperature=0,
|
162 |
),
|
163 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
else:
|
165 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|
166 |
# Bind tools to LLM
|
|
|
13 |
from langchain_community.vectorstores import SupabaseVectorStore
|
14 |
from langchain_core.messages import SystemMessage, HumanMessage
|
15 |
from langchain_core.tools import tool
|
16 |
+
from huggingface_hub import hf_hub_download
|
17 |
+
|
18 |
+
# RWKV setup flags — must come before importing RWKV
|
19 |
+
os.environ["RWKV_JIT_ON"] = "1"
|
20 |
+
os.environ["RWKV_V7_ON"] = "1" # enable RWKV-7
|
21 |
+
os.environ["RWKV_CUDA_ON"] = "0" # set to "1" to compile and use the custom CUDA kernel (needs a CUDA toolchain)
|
22 |
+
|
23 |
+
from rwkv.model import RWKV
|
24 |
+
from rwkv.utils import PIPELINE
|
25 |
|
26 |
load_dotenv()
|
27 |
|
|
|
153 |
]
|
154 |
|
155 |
# Build graph function
|
156 |
+
def build_graph(provider: str = "rwkv"):
|
157 |
"""Build the graph"""
|
158 |
# Load environment variables from .env file
|
159 |
if provider == "google":
|
|
|
170 |
temperature=0,
|
171 |
),
|
172 |
)
|
173 |
+
elif provider == "rwkv":
|
174 |
+
# --- BEGIN RWKV SETUP ---
|
175 |
+
title = "rwkv7-g1"
|
176 |
+
pth = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title}.pth")
|
177 |
+
# 2) Load RWKV (drop .pth extension for RWKV loader)
|
178 |
+
rwkv_model = RWKV(model=pth.replace(".pth", ""), strategy="cpu fp32")
|
179 |
+
# 3) Build the tokenization + generation pipeline
|
180 |
+
rwkv_pipe = PIPELINE(rwkv_model, "rwkv_vocab_v20230424")
|
181 |
+
# 4) Wrap into a Chat-style interface
|
182 |
+
class ChatRWKV:
    """Minimal chat-style adapter around an RWKV ``PIPELINE``.

    Provides ``invoke`` and ``bind_tools`` so the object can stand in where
    a chat model is expected.

    Args:
        pipe: Object exposing ``generate(ctx, token_count=..., args=...)``,
            i.e. an ``rwkv.utils.PIPELINE`` instance.
        gen_args: Sampling arguments forwarded to ``pipe.generate`` as
            ``args``. When ``None``, ``PIPELINE_ARGS(temperature=0.0,
            top_p=0.95)`` is built, matching the values previously
            hard-coded in ``invoke``.
        max_tokens: Upper bound on the number of tokens generated per call.
    """

    def __init__(self, pipe, gen_args=None, max_tokens=256):
        if gen_args is None:
            # Imported lazily so that a caller supplying its own sampling
            # arguments does not need this name at construction time.
            from rwkv.utils import PIPELINE_ARGS

            gen_args = PIPELINE_ARGS(temperature=0.0, top_p=0.95)
        self.pipe = pipe
        self.gen_args = gen_args
        self.max_tokens = max_tokens

    def invoke(self, messages):
        """Generate a completion for *messages* and return it as a string.

        Args:
            messages: Either a plain prompt string, or an iterable of
                objects with a string ``content`` attribute; their contents
                are joined with newlines to form the prompt.

        Returns:
            Whatever ``pipe.generate`` returns for the prompt (the generated
            text for an RWKV ``PIPELINE``).
        """
        if isinstance(messages, str):
            # A bare string would otherwise be iterated character by
            # character and fail on ``.content``.
            prompt = messages
        else:
            prompt = "\n".join(m.content for m in messages)
        # PIPELINE is not callable; generation goes through ``generate``,
        # which takes the token budget as ``token_count`` and the sampling
        # settings bundled in ``args``.
        return self.pipe.generate(
            prompt,
            token_count=self.max_tokens,
            args=self.gen_args,
        )

    def bind_tools(self, tools):
        """Return ``self`` unchanged; *tools* are accepted but not used.

        NOTE(review): this wrapper does no tool calling, so any tools bound
        here are silently ignored — confirm that is acceptable for callers.
        """
        return self
|
195 |
+
|
196 |
+
llm = ChatRWKV(rwkv_pipe)
|
197 |
+
# --- END RWKV SETUP ---
|
198 |
else:
|
199 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|
200 |
# Bind tools to LLM
|