# Wan-AI-Wan2.1-T2V-1.3B / flask_api.py
# Uploaded by wynai ("Upload 4 files", commit 1bc0a1c, verified)
import json
import logging
import os
import threading

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS

from api_endpoint import VideoGenerationAPI, download_models
# Logging: INFO and above are emitted through the root logger's handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The Flask application; CORS is enabled for every route.
app = Flask(__name__)
CORS(app)

# Shared VideoGenerationAPI, created lazily by get_api_instance(); stays None
# until an initialisation attempt succeeds.
api_instance: "VideoGenerationAPI | None" = None
# Guards first-time initialisation. app.run() serves requests on several
# threads by default. Without this lock, two concurrent first requests would
# each run download_models() and build their own VideoGenerationAPI.
_api_init_lock = threading.Lock()


def get_api_instance():
    """Return the shared VideoGenerationAPI, creating it on first use.

    The first successful call runs ``download_models()`` and then constructs
    ``VideoGenerationAPI()``; later calls return the cached instance without
    locking. If initialisation raises, the error is logged with its traceback,
    ``None`` is returned, and the next call tries again from scratch.

    Returns:
        The global ``VideoGenerationAPI`` instance, or ``None`` if it could not
        be initialised.
    """
    global api_instance
    # Fast path: already initialised.
    if api_instance is not None:
        return api_instance
    with _api_init_lock:
        # Re-check under the lock: another thread may have completed
        # initialisation while this one was waiting.
        if api_instance is None:
            logger.info("Initializing Video Generation API...")
            try:
                # Download models first
                download_models()
                api_instance = VideoGenerationAPI()
                logger.info("API instance created successfully")
            except Exception as e:
                # Boundary for the whole init sequence: keep serving (callers
                # report "API not initialized") but keep the traceback.
                logger.exception("Failed to initialize API: %s", e)
                api_instance = None
    return api_instance
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the video generation backend is available.

    Goes through get_api_instance(), so the first call (or any call after a
    failed initialisation) can trigger model download and set-up. The response
    is always HTTP 200; the JSON body has ``status`` ("healthy"/"unhealthy"),
    ``device`` and ``models_loaded``.
    """
    api = get_api_instance()
    status = "unhealthy" if api is None else "healthy"
    if api:
        device, models_loaded = api.device, api.pipe is not None
    else:
        device, models_loaded = "unknown", False
    return jsonify({
        "status": status,
        "device": device,
        "models_loaded": models_loaded,
    })
@app.route('/generate', methods=['POST'])
def generate_video():
    """Generate a video from a text prompt supplied as a JSON object.

    Request body (JSON object):
        prompt (str, required): non-empty text description.
        negative_prompt, seed, cfg_scale, clip_length, motion_scale, fps,
        enhance_prompt_flag, num_inference_steps (optional): forwarded as-is
            to ``VideoGenerationAPI.generate_video``; defaults are '', -1,
            7.0, 64, 0.5, 15.0, True and 50 respectively.

    Returns:
        200 with the result of ``api.generate_video`` as JSON on success.
        400 if the body is not a JSON object, has no ``prompt``, or ``prompt``
            is not a non-empty string.
        500 if the API could not be initialised, ``"error" in result`` holds,
            or any exception is raised while handling the request.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "video_path": None
        }), 500
    try:
        # silent=True yields None for a missing/malformed body or a non-JSON
        # Content-Type. Without it, Flask may raise an HTTP error that the
        # generic handler below would report as a 500 instead of a 400.
        data = request.get_json(silent=True)
        # Only a JSON object carries named fields: a list or string could pass
        # the `in` test and then fail on data['prompt'].
        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "error": "Missing 'prompt' in request body",
                "video_path": None
            }), 400
        prompt = data['prompt']
        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({
                "error": "Prompt must be a non-empty string",
                "video_path": None
            }), 400
        # Optional parameters fall back to the documented defaults.
        result = api.generate_video(
            prompt=prompt,
            negative_prompt=data.get('negative_prompt', ''),
            seed=data.get('seed', -1),
            cfg_scale=data.get('cfg_scale', 7.0),
            clip_length=data.get('clip_length', 64),
            motion_scale=data.get('motion_scale', 0.5),
            fps=data.get('fps', 15.0),
            enhance_prompt_flag=data.get('enhance_prompt_flag', True),
            num_inference_steps=data.get('num_inference_steps', 50),
        )
        if "error" in result:
            return jsonify(result), 500
        return jsonify(result)
    except Exception as e:
        # Request boundary: log with traceback, return the message to the client.
        logger.exception("Error in generate_video endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "video_path": None
        }), 500
@app.route('/video/<filename>', methods=['GET'])
def serve_video(filename):
    """Serve a generated video from the ``outputs`` directory as ``video/mp4``.

    ``outputs`` is resolved against the current working directory. Only a
    regular file located directly inside it is served. Everything else gets a
    JSON 404: missing files, directories, and names that normalise to a path
    outside the directory (``..``, or ``..\\x`` on Windows).

    Args:
        filename: The file name taken from the URL; untrusted user input.
    """
    try:
        # Build an absolute path ourselves. Flask's send_file() resolves a
        # relative path against app.root_path, which need not be the CWD that
        # os.path checks use, so the check and the send could disagree.
        outputs_dir = os.path.abspath("outputs")
        video_path = os.path.abspath(os.path.join(outputs_dir, filename))
        # Path-traversal guard plus "is a real file" check (exists() would
        # also accept directories, which send_file cannot serve).
        if os.path.dirname(video_path) != outputs_dir or not os.path.isfile(video_path):
            return jsonify({"error": "Video file not found"}), 404
        return send_file(video_path, mimetype='video/mp4')
    except Exception as e:
        logger.exception("Error serving video: %s", e)
        return jsonify({"error": str(e)}), 500
@app.route('/enhance_prompt', methods=['POST'])
def enhance_prompt():
    """Rewrite a text prompt via ``VideoGenerationAPI.enhance_prompt``.

    Request body (JSON object):
        prompt (str, required): non-empty text to enhance.

    Returns:
        200 with ``{"original_prompt", "enhanced_prompt"}`` on success.
        400 if the body is not a JSON object, has no ``prompt``, or ``prompt``
            is not a non-empty string.
        500 if the API could not be initialised or any exception is raised
            while handling the request.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "enhanced_prompt": None
        }), 500
    try:
        # silent=True yields None for a missing/malformed body or a non-JSON
        # Content-Type. Without it, Flask may raise an HTTP error that the
        # generic handler below would report as a 500 instead of a 400.
        data = request.get_json(silent=True)
        # Only a JSON object carries named fields: a list or string could pass
        # the `in` test and then fail on data['prompt'].
        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "error": "Missing 'prompt' in request body",
                "enhanced_prompt": None
            }), 400
        prompt = data['prompt']
        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({
                "error": "Prompt must be a non-empty string",
                "enhanced_prompt": None
            }), 400
        enhanced_prompt = api.enhance_prompt(prompt)
        return jsonify({
            "original_prompt": prompt,
            "enhanced_prompt": enhanced_prompt
        })
    except Exception as e:
        # Request boundary: log with traceback, return the message to the client.
        logger.exception("Error in enhance_prompt endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "enhanced_prompt": None
        }), 500
@app.route('/', methods=['GET'])
def index():
    """Return a JSON self-description of the service.

    The payload lists the service name, version, a description, every
    endpoint, and an example ``POST /generate`` body. That example uses the
    same defaults the /generate handler applies.
    """
    endpoints = {
        "GET /": "API documentation",
        "GET /health": "Health check",
        "POST /generate": "Generate video from text prompt",
        "POST /enhance_prompt": "Enhance text prompt using LLM",
        "GET /video/<filename>": "Serve generated video files",
    }
    sample_body = dict(
        prompt="A cat playing with a ball in a sunny garden",
        negative_prompt="blurry, low quality, distorted",
        seed=-1,
        cfg_scale=7.0,
        clip_length=64,
        motion_scale=0.5,
        fps=15.0,
        enhance_prompt_flag=True,
        num_inference_steps=50,
    )
    doc = {
        "name": "Self-Forcing Video Generation API",
        "version": "1.0.0",
        "description": "Generate high-quality videos from text descriptions using the Self-Forcing model",
        "endpoints": endpoints,
        "example_request": {"url": "/generate", "method": "POST", "body": sample_body},
    }
    return jsonify(doc)
if __name__ == '__main__':
    # Command-line entry point: parse host/port/debug and start Flask's
    # built-in development server.
    import argparse
    parser = argparse.ArgumentParser(description="Self-Forcing Video Generation Flask API")
    parser.add_argument('--port', type=int, default=5000, help="Port to run the API on")
    parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the API to")
    parser.add_argument('--debug', action='store_true', help="Run in debug mode")
    args = parser.parse_args()
    if args.debug and args.host not in ('127.0.0.1', 'localhost', '::1'):
        # Debug mode enables Werkzeug's interactive debugger, which can run
        # arbitrary Python from the browser (PIN-protected only). Binding it to
        # a non-loopback address (the default is 0.0.0.0) exposes that to
        # anyone who can reach the port.
        logger.warning(
            "Debug mode is enabled on non-loopback host %s; the interactive "
            "debugger is reachable from the network.", args.host
        )
    logger.info(f"Starting Flask API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)