import os
import json
import requests
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm.contrib.concurrent import thread_map
from utils import *
DATASET_REPO_URL = "https://huggingface.co/datasets/huggingface-projects/drlc-leaderboard-data"
DATASET_REPO_ID = "huggingface-projects/drlc-leaderboard-data"
HF_TOKEN = os.environ.get("HF_TOKEN")

block = gr.Blocks()
api = HfApi(token=HF_TOKEN)
# Define RL environments
rl_envs = [
    {"rl_env_beautiful": "LunarLander-v2", "rl_env": "LunarLander-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "CartPole-v1", "rl_env": "CartPole-v1", "video_link": "https://huggingface.co/sb3/ppo-CartPole-v1/resolve/main/replay.mp4", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-4x4-no_slippery", "rl_env": "FrozenLake-v1-4x4-no_slippery", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-8x8-no_slippery", "rl_env": "FrozenLake-v1-8x8-no_slippery", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-4x4", "rl_env": "FrozenLake-v1-4x4", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-8x8", "rl_env": "FrozenLake-v1-8x8", "video_link": "", "global": None},
    {"rl_env_beautiful": "Taxi-v3", "rl_env": "Taxi-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "CarRacing-v0", "rl_env": "CarRacing-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "CarRacing-v2", "rl_env": "CarRacing-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "MountainCar-v0", "rl_env": "MountainCar-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "SpaceInvadersNoFrameskip-v4", "rl_env": "SpaceInvadersNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "PongNoFrameskip-v4", "rl_env": "PongNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "BreakoutNoFrameskip-v4", "rl_env": "BreakoutNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "QbertNoFrameskip-v4", "rl_env": "QbertNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "BipedalWalker-v3", "rl_env": "BipedalWalker-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "Walker2DBulletEnv-v0", "rl_env": "Walker2DBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "AntBulletEnv-v0", "rl_env": "AntBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "HalfCheetahBulletEnv-v0", "rl_env": "HalfCheetahBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "PandaReachDense-v2", "rl_env": "PandaReachDense-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "PandaReachDense-v3", "rl_env": "PandaReachDense-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "Pixelcopter-PLE-v0", "rl_env": "Pixelcopter-PLE-v0", "video_link": "", "global": None},
]
# -------------------- Utility Functions --------------------
def restart():
    """Restart the Hugging Face Space."""
    print("RESTARTING SPACE...")
    api.restart_space(repo_id="huggingface-projects/Deep-Reinforcement-Learning-Leaderboard")
def download_leaderboard_dataset():
    """Download a snapshot of the leaderboard dataset and return its local path."""
    print("Downloading leaderboard dataset...")
    return snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
def get_metadata(model_id):
    """Fetch metadata for a given model from Hugging Face."""
    try:
        readme_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
        return metadata_load(readme_path)
    except requests.exceptions.HTTPError:
        # 404: the model has no README.md
        return None
def parse_metrics_accuracy(meta):
    """Extract the first reported metric value from the model-index metadata."""
    if "model-index" not in meta:
        return None
    result = meta["model-index"][0]["results"]
    metrics = result[0]["metrics"]
    return metrics[0]["value"]
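# Illustrative sketch (not part of the original file): the model-index layout this
# parser assumes, following the standard model card metadata spec. Names and values
# below are hypothetical.
#
# model-index:
# - name: ppo-LunarLander-v2
#   results:
#   - task: {type: reinforcement-learning}
#     dataset: {name: LunarLander-v2, type: LunarLander-v2}
#     metrics:
#     - type: mean_reward
#       value: "250.50 +/- 15.20"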
def parse_rewards(accuracy):
    """Extract mean and std rewards from a "mean +/- std" metric string."""
    default_std = -1000
    default_reward = -1000
    if accuracy is not None:
        parsed = str(accuracy).split("+/-")
        try:
            mean_reward = float(parsed[0].strip()) if parsed[0].strip() else default_reward
            std_reward = float(parsed[1].strip()) if len(parsed) > 1 else 0
        except ValueError:
            # Malformed metric value (e.g. non-numeric text): fall back to defaults
            mean_reward, std_reward = default_reward, default_std
    else:
        mean_reward, std_reward = default_reward, default_std
    return mean_reward, std_reward
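# For example (values hypothetical):
#   parse_rewards("250.50 +/- 15.20") -> (250.5, 15.2)
#   parse_rewards(None)               -> (-1000, -1000)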
def get_model_ids(rl_env):
    """Retrieve models matching the given RL environment."""
    return [x.modelId for x in api.list_models(filter=rl_env)]
def update_leaderboard_dataset_parallel(rl_env, path):
    """Parallelized update of the leaderboard dataset for a given RL environment."""
    model_ids = get_model_ids(rl_env)

    def process_model(model_id):
        meta = get_metadata(model_id)
        if not meta:
            return None
        user_id = model_id.split("/")[0]
        row = {
            "User": user_id,
            "Model": model_id,
            "Results": None,
            "Mean Reward": None,
            "Std Reward": None,
        }
        accuracy = parse_metrics_accuracy(meta)
        mean_reward, std_reward = parse_rewards(accuracy)
        row["Results"] = mean_reward - std_reward  # ranking score: lower bound of the reward
        row["Mean Reward"] = mean_reward
        row["Std Reward"] = std_reward
        return row

    data = list(thread_map(process_model, model_ids, desc="Processing models"))
    data = [row for row in data if row is not None]
    ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
    ranked_dataframe.to_csv(os.path.join(path, f"{rl_env}.csv"), index=False)
    return ranked_dataframe
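# Each call writes <path>/<rl_env>.csv with the columns:
# Ranking, User, Model, Results, Mean Reward, Std Reward
# (Ranking is added by rank_dataframe below.)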
def rank_dataframe(dataframe):
    """Sort models by results and assign a ranking."""
    dataframe = dataframe.sort_values(by=["Results", "User", "Model"], ascending=False)
    dataframe.insert(0, "Ranking", range(1, len(dataframe) + 1))
    return dataframe
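# For example (rows hypothetical): with Results [300, 250, 250] the assigned
# Ranking is [1, 2, 3]; ties are broken by User and then Model, all sorted in
# descending order because ascending=False applies to every sort key.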
def run_update_dataset():
    """Scheduled job: rebuild every environment's leaderboard and push the dataset back to the Hub."""
    path_ = download_leaderboard_dataset()
    for env in rl_envs:
        update_leaderboard_dataset_parallel(env["rl_env"], path_)
    print("Uploading updated dataset...")
    api.upload_folder(
        folder_path=path_,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset",
    )
def filter_data(rl_env, path, user_id):
    """Filter dataset for a specific user ID."""
    data_df = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
    return data_df[data_df["User"] == user_id]
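# Usage sketch (user id hypothetical):
#   filter_data("LunarLander-v2", path_, "some-user")  # rows where User == "some-user"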
# -------------------- Gradio UI --------------------
print("Initializing dataset...")
path_ = download_leaderboard_dataset()

with block:
    gr.Markdown("""
# Deep Reinforcement Learning Course Leaderboard

This leaderboard displays trained agents from the [Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction?fw=pt).

**Models are ranked using `mean_reward - std_reward`.**

If you can't find your model, please wait for the next update (every 2 hours).
""")
    grpath = gr.State(path_)  # Store dataset path as a state variable

    for env in rl_envs:
        with gr.TabItem(env["rl_env_beautiful"]):
            gr.Markdown(f"## {env['rl_env_beautiful']}")
            user_id = gr.Textbox(label="Your user ID")
            search_btn = gr.Button("Search")
            reset_btn = gr.Button("Clear Search")
            env_state = gr.State(env["rl_env"])  # Store environment name as a state variable
            gr_dataframe = gr.Dataframe(
                value=pd.read_csv(os.path.join(path_, f"{env['rl_env']}.csv")),
                headers=["Ranking", "User", "Model", "Results", "Mean Reward", "Std Reward"],
                datatype=["number", "markdown", "markdown", "number", "number", "number"],
                # row_count=(100, 'fixed')
                row_count=(100, "dynamic"),  # allow all rows to be displayed dynamically
            )
            # The environment name and dataset path are passed to the callback as state inputs
            search_btn.click(fn=filter_data, inputs=[env_state, grpath, user_id], outputs=gr_dataframe)
            # Bind env["rl_env"] as a default argument so each tab's reset button
            # reloads its own CSV (avoids the late-binding closure pitfall)
            reset_btn.click(fn=lambda rl_env=env["rl_env"]: pd.read_csv(os.path.join(path_, f"{rl_env}.csv")), inputs=[], outputs=gr_dataframe)
# -------------------- Scheduler --------------------
scheduler = BackgroundScheduler()
scheduler.add_job(run_update_dataset, "interval", hours=2)  # Update dataset every 2 hours
scheduler.add_job(restart, "interval", hours=3)  # Restart the Space every 3 hours
scheduler.start()

block.launch()