Tinkabellaa committed
Commit b3f8d37 · verified · 1 Parent(s): 1963b4d

Create app.py

Files changed (1)
  app.py +65 -0
app.py ADDED
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "OpenAssistant/oasst-sft-1-pythia-12b"

# Load model & tokenizer (half precision on GPU, full precision on CPU)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.eval()

STYLE_PROMPT = """
You are a professional ghostwriter. Your writing is:
- Thick with rich description
- Heavy and raw in tone
- Mature, triggering, and relatable
- Explicit within acceptable content guidelines
- Italicized for internal monologue or emotional emphasis
- Structured and catchy
- Explicit and authentic
- Sexual, with a mixture of SFW and NSFW
- Realistic and slice-of-life
"""

def stream_generate(prompt, temperature, max_tokens):
    input_text = f"{STYLE_PROMPT}\n\nUser prompt:\n{prompt.strip()}"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    output_ids = input_ids
    past_key_values = None

    for _ in range(int(max_tokens)):
        with torch.no_grad():
            # First step: run the whole prompt to build the KV cache.
            # Later steps: feed only the newest token and reuse the cache.
            step_input = output_ids if past_key_values is None else output_ids[:, -1:]
            outputs = model(input_ids=step_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values

            # Sample from the temperature-scaled distribution
            # (argmax would make the temperature slider a no-op).
            next_token_logits = outputs.logits[:, -1, :] / float(temperature)
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            output_ids = torch.cat([output_ids, next_token], dim=-1)

        # Decode only the tokens generated after the prompt; string-replacing
        # the prompt out of the full decode is unreliable.
        generated = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
        yield generated

        # Stop on end-of-sequence or a blank line.
        if next_token.item() == tokenizer.eos_token_id or tokenizer.decode(next_token[0]) == "\n\n":
            break

with gr.Blocks(title="🧠 HuggingChat Stream Writer") as demo:
    gr.Markdown("## ✍️ Real-Time HuggingChat-Style Generator")
    gr.Markdown("*Watch your story unfold word by word...*")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Describe a rainy night and inner conflict...")
        temperature = gr.Slider(0.5, 1.5, value=0.9, step=0.1, label="Temperature")
        max_tokens = gr.Slider(50, 800, value=300, step=10, label="Max Tokens")

    with gr.Row():
        output = gr.Textbox(label="Generated Output (streaming)", lines=15)

    gr.Button("Generate").click(fn=stream_generate, inputs=[prompt, temperature, max_tokens], outputs=output)

if __name__ == "__main__":
    demo.queue()   # queuing is required for generator (streaming) outputs on older Gradio versions
    demo.launch()
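
For a quick local smoke test of the generator outside the UI (a sketch, not part of this commit): with the __main__ guard above, importing app.py does not launch the interface, but it does load the 12B checkpoint, which needs substantial RAM or VRAM; the prompt below is illustrative.

# Sketch of a local smoke test; assumes app.py is on the import path and
# that there is enough memory to load the 12B checkpoint at import time.
from app import stream_generate

final = ""
for partial in stream_generate(
    "Describe a rainy night and inner conflict.",  # illustrative prompt
    temperature=0.9,
    max_tokens=50,
):
    final = partial  # each yield is the full completion so far

print(final)

Recent transformers releases also ship TextIteratorStreamer, which pairs a background model.generate call with an iterator and would replace the manual KV-cache loop. A minimal sketch of that variant, reusing model, tokenizer, and STYLE_PROMPT from app.py:

# Alternative streaming loop built on transformers' TextIteratorStreamer
# (available since transformers 4.28); drop-in replacement for stream_generate.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate_builtin(prompt, temperature, max_tokens):
    input_text = f"{STYLE_PROMPT}\n\nUser prompt:\n{prompt.strip()}"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until done, so it runs in a worker thread
    # while this generator yields chunks as they arrive.
    Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            streamer=streamer,
        ),
    ).start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text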