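# Gradio chat demo for the llm-jp-3 "instruct3" models on Hugging Face Spaces:
# the 150M/440M/980M checkpoints are loaded in 4-bit and replies are streamed
# token by token into a Chatbot UI.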
try:
    import flash_attn
except ImportError:
    import os
    import subprocess

    print("Installing flash-attn...")
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        # Skip compiling the CUDA kernels at install time (a common workaround for
        # Spaces build limits); merge into the current env so pip stays on PATH.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )
    import flash_attn
    print("flash-attn installed.")
import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from threading import Thread

import gradio as gr
from dotenv import load_dotenv
import spaces

load_dotenv()
HF_API_KEY = os.getenv("HF_API_KEY")
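# Note: HF_API_KEY is read from .env but never passed to from_pretrained below;
# the llm-jp-3 checkpoints are public, so it would only matter for gated models.
# Likewise, `spaces` is imported for Spaces GPU support; on ZeroGPU hardware the
# generation function would typically also carry a @spaces.GPU decorator.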
MODEL_NAME_MAP = {
    "150m-instruct3": "llm-jp/llm-jp-3-150m-instruct3",
    "440m-instruct3": "llm-jp/llm-jp-3-440m-instruct3",
    "980m-instruct3": "llm-jp/llm-jp-3-980m-instruct3",
    # "1.8b-instruct3": "llm-jp/llm-jp-3-1.8b-instruct3",
    # "3.7b-instruct3": "llm-jp/llm-jp-3-3.7b-instruct3",
    # "13b-instruct3": "llm-jp/llm-jp-3-13b-instruct3",
}
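# All selected models are quantized to 4-bit NF4 weights (with double quantization)
# and use bfloat16 for compute, so every model in the map can stay loaded at once.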
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
MODELS = {
    key: AutoModelForCausalLM.from_pretrained(
        repo_id, quantization_config=quantization_config, device_map="auto"
    )
    for key, repo_id in MODEL_NAME_MAP.items()
}
TOKENIZERS = {
    key: AutoTokenizer.from_pretrained(repo_id)
    for key, repo_id in MODEL_NAME_MAP.items()
}

print("Compiling model...")
for key, model in MODELS.items():
    MODELS[key] = torch.compile(model)
print("Model compiled.")
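# generate() builds the chat-template prompt, runs model.generate on a background
# thread, and yields partial text from TextIteratorStreamer so the UI updates as
# tokens arrive. Each yield returns ("", updated_history): the empty string clears
# the input textbox, the history redraws the Chatbot.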
def generate(
    model_name: str,
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    # Ignore empty input: keep the history and leave the UI unchanged.
    if not message or message.strip() == "":
        yield "", history
        return

    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    model = MODELS[model_name]
    tokenizer = TOKENIZERS[model_name]
    tokenized_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=tokenized_input,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Initialize the value to return
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        new_history = history + [(message, partial_message)]
        # Clear the input textbox while streaming the updated history
        yield "", new_history
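# respond() is a thin pass-through to generate() with the same signature, so the
# send/submit events and the retry handler below share one streaming code path.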
def respond(
    model_name: str,
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    for stream in generate(
        model_name,
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield stream
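# retry() regenerates the most recent reply: it pops the last (user, assistant)
# pair off the history and re-streams a response to that user message.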
def retry(
    model_name: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    # Nothing to regenerate if the chat is still empty.
    if not history:
        yield "", history
        return

    # Remove the last exchange and regenerate a reply to its user message
    last_conversation = history[-1]
    user_message = last_conversation[0]
    history = history[:-1]
    for stream in generate(
        model_name,
        user_message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
    ):
        yield stream
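# demo() assembles the Blocks UI: a model dropdown, the chat area with retry/clear
# buttons, the input row, and an accordion holding the system prompt and the
# sampling sliders (max new tokens, temperature, top-p, top-k).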
def demo():
    with gr.Blocks() as ui:
        gr.Markdown(
            """\
# llm-jp/llm-jp-3 instruct3 モデルデモ
コレクション: https://huggingface.co/collections/llm-jp/llm-jp-3-fine-tuned-models-672c621db852a01eae939731
"""
        )
        model_name_dropdown = gr.Dropdown(
            label="モデル", choices=list(MODELS.keys()), value=list(MODELS.keys())[0]
        )
        chat_history = gr.Chatbot(value=[])
        with gr.Row():
            retry_btn = gr.Button(value="🔄 再生成", scale=1)
            clear_btn = gr.ClearButton(
                components=[chat_history], value="🗑️ 削除", scale=1,
            )
        with gr.Row():
            input_text = gr.Textbox(
                value="",
                placeholder="質問を入力してください...",
                show_label=False,
                scale=8,
            )
            start_btn = gr.Button(
                value="送信",
                variant="primary",
                scale=2,
            )
        with gr.Accordion(label="詳細設定", open=False):
            system_prompt_text = gr.Textbox(
                label="システムプロンプト",
                value="以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
            )
            max_new_tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
            )
            top_k_slider = gr.Slider(
                minimum=10, maximum=500, value=100, step=10, label="Top-k"
            )
        gr.Examples(
            examples=[
                ["情けは人の為ならずとはどういう意味ですか?"],
                ["まどマギで一番可愛いのは誰?"],
            ],
            inputs=[input_text],
            cache_examples=False,
        )
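        # Wire both the send button and Enter-in-textbox to the same streaming handler.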
        gr.on(
            triggers=[start_btn.click, input_text.submit],
            fn=respond,
            inputs=[
                model_name_dropdown,
                input_text,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
            ],
            outputs=[input_text, chat_history],
        )
        retry_btn.click(
            retry,
            inputs=[
                model_name_dropdown,
                chat_history,
                system_prompt_text,
                max_new_tokens_slider,
                temperature_slider,
                top_p_slider,
                top_k_slider,
            ],
            outputs=[input_text, chat_history],
        )

    ui.launch()
if __name__ == "__main__":
    demo()