Bluestrikeai committed on
Commit
6b83aec
·
verified ·
1 Parent(s): 21838a7

Create app.py

Files changed (1)
  1. app.py +411 -0
app.py ADDED
@@ -0,0 +1,411 @@
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModel
import time

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load model and tokenizer. Note: the PyTorch checkpoint linked in the demo
# below is used here; the originally referenced mlx-community 4-bit conversion
# is an MLX checkpoint and cannot be loaded through transformers onto CUDA.
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True,
                                  torch_dtype=torch.bfloat16).to(device)

# Constants
MASK_TOKEN = "[MASK]"
MASK_ID = 126336  # The token ID of [MASK] in LLaDA

def parse_constraints(constraints_text):
    """Parse constraints in the format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                constraints[pos] = word
        except ValueError:
            continue

    return constraints

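# Illustrative example (not executed): parse_constraints("0:Once, 5:upon, 10:time")
# returns {0: 'Once', 5: 'upon', 10: 'time'}; parts without a ':' or with a
# non-integer position are silently skipped.
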
def format_chat_history(history):
    """
    Format chat history for the LLaDA model.

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted list of message dicts for the model
    """
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Skip if None (for the latest user message)
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages

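# Illustrative example (not executed):
# format_chat_history([["Hi", "Hello!"], ["How are you?", None]])
# -> [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'Hello!'},
#     {'role': 'user', 'content': 'How are you?'}]  (the trailing None is dropped)
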
@spaces.GPU
def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32, constraints=None):
    """
    Generate text with the LLaDA model, recording each denoising step for visualization.

    Args:
        messages: List of message dictionaries with 'role' and 'content'
        gen_length: Number of response tokens to generate
        steps: Number of denoising steps
        constraints: Optional dict mapping word positions to words

    Returns:
        List of visualization states showing the progression, and the final text
    """

    # Process constraints
    if constraints is None:
        constraints = {}

    # Convert any string constraints to token IDs
    processed_constraints = {}
    for pos, word in constraints.items():
        tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        for i, token_id in enumerate(tokens):
            processed_constraints[pos + i] = token_id

    # Prepare the prompt using the chat template
    chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(chat_input)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    # For generation
    prompt_length = input_ids.shape[1]

    # Initialize the sequence with masks for the response part
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
    x[:, :prompt_length] = input_ids.clone()

    # Initialize visualization states for just the response part
    visualization_states = []

    # Add initial state (all masked) - only for the response part
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)

    # Apply constraints to the initial state
    for pos, token_id in processed_constraints.items():
        absolute_pos = prompt_length + pos
        if absolute_pos < x.shape[1]:
            x[:, absolute_pos] = token_id

    # Calculate timesteps
    timesteps = torch.linspace(1.0, 0.0, steps + 1)[:-1]

    # Keep track of already revealed tokens
    revealed_tokens = torch.zeros(1, gen_length, dtype=torch.bool).to(device)

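    # On the linear schedule below, each step unmasks roughly
    # gen_length / steps tokens (e.g. 64 / 32 = 2 per step), keeping the
    # lowest-confidence predictions masked for later steps.
    # (Worked numbers are illustrative, not from the original file.)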
    for step, t in enumerate(timesteps):
        # Timestep we are moving to (t decreases toward 0)
        s = t - 1.0 / steps if step < steps - 1 else 0

        # Get all mask positions in the current sequence
        mask_indices = (x == MASK_ID)

        # Skip if no masks
        if not mask_indices.any():
            break

        # Get logits from the model
        logits = model(x).logits

        # Get the top predictions
        x0 = torch.argmax(logits, dim=-1)

        # Get probabilities for visualization
        probs = torch.softmax(logits, dim=-1)
        top_probs = torch.max(probs, dim=-1)[0]

        # Apply the predictions where we have masks
        x_old = x.clone()
        x = torch.where(mask_indices, x0, x)

        # Calculate how many tokens should remain masked at the next step
        total_len = gen_length
        current_t_value = float(t)
        next_t_value = float(s)

        # Linear schedule: t=1 -> all masked, t=0 -> none masked
        current_masks_expected = int(current_t_value * total_len)
        next_masks_expected = int(next_t_value * total_len)

        # How many to unmask in this step
        tokens_to_unmask = current_masks_expected - next_masks_expected

        if tokens_to_unmask > 0 and mask_indices.any():
            # Get confidence scores for currently masked tokens
            confidence_scores = top_probs[mask_indices]

            # Sort confidence scores
            sorted_indices = torch.argsort(confidence_scores, descending=True)

            # Select which tokens to keep masked (the lowest-confidence ones)
            indices_to_remask = sorted_indices[tokens_to_unmask:]

            # Get the actual indices in the sequence
            mask_positions = torch.where(mask_indices)[1]
            positions_to_remask = mask_positions[indices_to_remask]

            # Remask these positions
            x[:, positions_to_remask] = MASK_ID

        # Ensure constraints are maintained
        for pos, token_id in processed_constraints.items():
            absolute_pos = prompt_length + pos
            if absolute_pos < x.shape[1]:
                x[:, absolute_pos] = token_id

        # Create visualization state ONLY for the response part
        current_state = []

        # Update which tokens are newly revealed in this step
        for i in range(gen_length):
            pos = prompt_length + i  # Absolute position in the sequence

            if x[0, pos] == MASK_ID:
                # Still masked
                current_state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks

            elif x_old[0, pos] == MASK_ID:
                # Newly revealed in this step
                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                confidence = float(top_probs[0, pos].cpu())

                # Color based on confidence: red (low) to green (high)
                if confidence < 0.3:
                    color = "#FF6666"  # Light red
                elif confidence < 0.7:
                    color = "#FFAA33"  # Orange
                else:
                    color = "#66CC66"  # Light green

                current_state.append((token, color))
                revealed_tokens[0, i] = True

            else:
                # Previously revealed
                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                current_state.append((token, "#6699CC"))  # Light blue

        visualization_states.append(current_state)

    # Extract and clean the final text (just the assistant's response);
    # a single decode suffices
    response_tokens = x[0, prompt_length:]
    final_text = tokenizer.decode(response_tokens,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True)

    return visualization_states, final_text

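# Illustrative standalone call (not executed; assumes the model loaded above):
# states, text = generate_response_with_visualization(
#     model, tokenizer, device,
#     [{"role": "user", "content": "Tell me a story."}],
#     gen_length=64, steps=32, constraints={0: "Once"})
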
css = '''
.category-legend{display:none}
button{height: 60px}
'''

def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Large Language Diffusion Model demo")
        gr.Markdown("[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)")

        # STATE MANAGEMENT - IMPORTANT
        # We use a dedicated state to track the full conversation history
        chat_history = gr.State([])

        # UI COMPONENTS
        # Chatbot for displaying messages
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False
                        )
                        send_btn = gr.Button("Send")

                constraints_input = gr.Textbox(
                    label="Word Constraints",
                    info="Place specific words at specific positions using the 'position:word' format. Example: to force the 1st word to be 'Once', the 6th 'upon', and the 11th 'time', use: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Once, 5:upon, 10:time",
                    value=""
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,
                )

        # Visualization and response components
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=128, value=64, step=8,
                    label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=64, value=32, step=4,
                    label="Denoising Steps"
                )

            visualization_delay = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.1, step=0.1, visible=False,
                label="Visualization Delay (seconds)"
            )

        # Current response text box
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False
        )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # HELPER FUNCTIONS
        def add_message(history, message, response):
            """Add a message pair to the history and return the updated history"""
            history = history.copy()
            history.append([message, response])
            return history

        def user_message_submitted(message, history, gen_length, steps, constraints, delay):
            """Process a submitted user message"""
            # Skip empty messages
            if not message.strip():
                # Return current state unchanged
                history_for_display = history.copy()
                return history, history_for_display, "", [], ""

            # Add user message to history
            history = add_message(history, message, None)

            # Format for display - temporarily show the user message with an empty response
            history_for_display = history.copy()

            # Clear the input
            message_out = ""

            # Return immediately to update the UI with the user message
            return history, history_for_display, message_out, [], ""

        def bot_response(history, gen_length, steps, constraints, delay):
            """Generate the bot response for the latest message"""
            if not history:
                # This function is a generator, so the empty case must be
                # yielded rather than returned
                yield history, [], ""
                return

            # Get the last user message
            last_user_message = history[-1][0]

            try:
                # Format all messages except the last one (which has no response yet)
                messages = format_chat_history(history[:-1])

                # Add the last user message
                messages.append({"role": "user", "content": last_user_message})

                # Parse constraints
                parsed_constraints = parse_constraints(constraints)

                # Generate response with visualization
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    constraints=parsed_constraints
                )

                # Update history with the assistant's response
                history[-1][1] = response_text

                # Return the initial state immediately
                yield history, vis_states[0], response_text

                # Then animate through the visualization states
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text

            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                # Show the error in the visualization
                error_vis = [(error_msg, "red")]

                # Don't update history with the error
                yield history, error_vis, error_msg

        def clear_conversation():
            """Clear the conversation history"""
            return [], [], "", []

        # EVENT HANDLERS

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, output_vis]
        )

        # User message submission flow (2-step process)
        # Step 1: Add the user message to history and update the UI
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        # Also connect the send button
        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        # Step 2: Generate the bot response
        # This happens after the user message is displayed
        msg_submit.then(
            fn=bot_response,
            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chatbot_ui, output_vis, current_response]
        )

        send_click.then(
            fn=bot_response,
            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chatbot_ui, output_vis, current_response]
        )

    return demo

# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    demo.queue().launch(share=True)