# Spaces:
# Sleeping
# Sleeping
import json
import logging
import os
import threading

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS

from api_endpoint import VideoGenerationAPI, download_models
# --- Logging -----------------------------------------------------------------
# Configure the root logger at INFO so the initialisation and error messages
# emitted by the handlers below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Flask application -------------------------------------------------------
app = Flask(__name__)
CORS(app=app)  # allow cross-origin requests on every route

# --- Shared model wrapper ----------------------------------------------------
# Populated lazily by get_api_instance(); left as None until that succeeds.
api_instance = None
# Serialises first-time initialisation so that concurrent requests (Flask's
# development server handles requests in threads by default) do not download
# and load the models more than once.
_api_init_lock = threading.Lock()


def get_api_instance():
    """Return the shared VideoGenerationAPI, creating it on first use.

    The first successful call runs ``download_models()`` and then constructs
    ``VideoGenerationAPI()``; the result is cached in the module-level
    ``api_instance`` and returned by every later call.

    If initialisation raises, the exception is logged (with traceback) and
    ``None`` is returned. Nothing is cached in that case, so the next call
    attempts initialisation again.

    Returns:
        The cached VideoGenerationAPI instance, or None if initialisation failed.
    """
    global api_instance

    # Fast path: already initialised, no need to take the lock.
    if api_instance is not None:
        return api_instance

    with _api_init_lock:
        # Re-check under the lock: another thread may have finished
        # initialising while this one was waiting.
        if api_instance is None:
            logger.info("Initializing Video Generation API...")
            try:
                # Model weights must be present before the API object is built.
                download_models()
                api_instance = VideoGenerationAPI()
                logger.info("API instance created successfully")
            except Exception as e:
                logger.exception("Failed to initialize API: %s", e)
                api_instance = None
    return api_instance
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the video-generation API is available.

    Calls get_api_instance(), so the first health check may block while
    models are downloaded and loaded.

    Returns:
        JSON (always HTTP 200) with:
            status: "healthy" if the API instance exists, else "unhealthy".
            device: ``api.device``, or "unknown" when there is no instance.
            models_loaded: whether ``api.pipe`` is set (False without an instance).
    """
    api = get_api_instance()
    # Use the same `is None` test for every field; the original mixed an
    # identity check (status) with truthiness checks (device, models_loaded).
    if api is None:
        return jsonify({
            "status": "unhealthy",
            "device": "unknown",
            "models_loaded": False,
        })
    return jsonify({
        "status": "healthy",
        "device": api.device,
        "models_loaded": api.pipe is not None,
    })
@app.route('/generate', methods=['POST'])
def generate_video():
    """Generate a video from the text prompt in the JSON request body.

    Only ``prompt`` is required. Every other field falls back to the default
    shown and is passed through, unvalidated, to ``api.generate_video``::

        {
            "prompt": "<non-empty string>",
            "negative_prompt": "",
            "seed": -1,
            "cfg_scale": 7.0,
            "clip_length": 64,
            "motion_scale": 0.5,
            "fps": 15.0,
            "enhance_prompt_flag": true,
            "num_inference_steps": 50
        }

    Returns:
        * 200 with the dict returned by ``api.generate_video`` on success.
        * 400 if the body is not a JSON object, lacks ``prompt``, or
          ``prompt`` is not a non-empty string.
        * 500 if the API could not be initialised, if the result contains an
          ``"error"`` key (the result is forwarded as-is), or if an
          unexpected exception is raised.
        Errors produced by this handler have the shape
        ``{"error": <message>, "video_path": None}``.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "video_path": None
        }), 500

    try:
        # silent=True: a missing, non-JSON, or malformed body yields None and
        # is rejected with 400 below. Without it, get_json() raises and the
        # generic handler reported a client error as a 500.
        data = request.get_json(silent=True)
        # Require a JSON object; a list or string body would otherwise fail
        # on data['prompt'] and surface as a 500.
        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "error": "Missing 'prompt' in request body",
                "video_path": None
            }), 400

        prompt = data['prompt']
        if not isinstance(prompt, str) or len(prompt.strip()) == 0:
            return jsonify({
                "error": "Prompt must be a non-empty string",
                "video_path": None
            }), 400

        # Optional parameters with their defaults.
        negative_prompt = data.get('negative_prompt', '')
        seed = data.get('seed', -1)
        cfg_scale = data.get('cfg_scale', 7.0)
        clip_length = data.get('clip_length', 64)
        motion_scale = data.get('motion_scale', 0.5)
        fps = data.get('fps', 15.0)
        enhance_prompt_flag = data.get('enhance_prompt_flag', True)
        num_inference_steps = data.get('num_inference_steps', 50)

        result = api.generate_video(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            cfg_scale=cfg_scale,
            clip_length=clip_length,
            motion_scale=motion_scale,
            fps=fps,
            enhance_prompt_flag=enhance_prompt_flag,
            num_inference_steps=num_inference_steps
        )

        # NOTE(review): assumes api.generate_video returns a dict that carries
        # an "error" key only on failure -- confirm against api_endpoint.
        if "error" in result:
            return jsonify(result), 500
        return jsonify(result)
    except Exception as e:
        logger.exception("Error in generate_video endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "video_path": None
        }), 500
@app.route('/video/<filename>', methods=['GET'])
def serve_video(filename):
    """Serve a generated video file from the ``outputs`` directory.

    ``outputs`` is resolved relative to the current working directory.

    Args:
        filename: Name of the file inside ``outputs``, taken from the URL.

    Returns:
        * The file with mimetype ``video/mp4`` when it is a regular file
          inside ``outputs``.
        * 404 ``{"error": "Video file not found"}`` when it does not exist,
          is not a regular file, or resolves to a location outside
          ``outputs`` (e.g. via ``..`` or a symlink).
        * 500 ``{"error": <message>}`` on unexpected errors.
    """
    try:
        output_dir = os.path.realpath("outputs")
        video_path = os.path.realpath(os.path.join(output_dir, filename))

        # Path-traversal guard: the resolved path must stay inside outputs/.
        # isfile() also rejects directories, which send_file cannot serve.
        inside_outputs = os.path.commonpath([output_dir, video_path]) == output_dir
        if not inside_outputs or not os.path.isfile(video_path):
            return jsonify({"error": "Video file not found"}), 404

        return send_file(video_path, mimetype='video/mp4')
    except Exception as e:
        logger.exception("Error serving video: %s", e)
        return jsonify({"error": str(e)}), 500
@app.route('/enhance_prompt', methods=['POST'])
def enhance_prompt():
    """Enhance a text prompt using the API's LLM-based ``enhance_prompt``.

    Expects a JSON object body with a non-empty string ``prompt``.

    Returns:
        * 200 ``{"original_prompt": <prompt>, "enhanced_prompt": <result>}``.
        * 400 if the body is not a JSON object, lacks ``prompt``, or
          ``prompt`` is not a non-empty string.
        * 500 if the API could not be initialised or an unexpected exception
          is raised.
        Errors have the shape ``{"error": <message>, "enhanced_prompt": None}``.
    """
    api = get_api_instance()
    if api is None:
        return jsonify({
            "error": "API not initialized",
            "enhanced_prompt": None
        }), 500

    try:
        # silent=True: a missing or malformed JSON body becomes None and is
        # rejected with 400 below, instead of raising into the 500 handler.
        data = request.get_json(silent=True)
        # Require a JSON object so non-dict bodies cannot reach data['prompt'].
        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "error": "Missing 'prompt' in request body",
                "enhanced_prompt": None
            }), 400

        prompt = data['prompt']
        if not isinstance(prompt, str) or len(prompt.strip()) == 0:
            return jsonify({
                "error": "Prompt must be a non-empty string",
                "enhanced_prompt": None
            }), 400

        enhanced_prompt = api.enhance_prompt(prompt)
        return jsonify({
            "original_prompt": prompt,
            "enhanced_prompt": enhanced_prompt
        })
    except Exception as e:
        logger.exception("Error in enhance_prompt endpoint: %s", e)
        return jsonify({
            "error": str(e),
            "enhanced_prompt": None
        }), 500
@app.route('/', methods=['GET'])
def index():
    """Return static JSON documentation for the API.

    The response lists the available endpoints and an example ``/generate``
    request body whose values match the defaults used by ``generate_video``.
    """
    return jsonify({
        "name": "Self-Forcing Video Generation API",
        "version": "1.0.0",
        "description": "Generate high-quality videos from text descriptions using the Self-Forcing model",
        "endpoints": {
            "GET /": "API documentation",
            "GET /health": "Health check",
            "POST /generate": "Generate video from text prompt",
            "POST /enhance_prompt": "Enhance text prompt using LLM",
            "GET /video/<filename>": "Serve generated video files"
        },
        "example_request": {
            "url": "/generate",
            "method": "POST",
            "body": {
                "prompt": "A cat playing with a ball in a sunny garden",
                "negative_prompt": "blurry, low quality, distorted",
                "seed": -1,
                "cfg_scale": 7.0,
                "clip_length": 64,
                "motion_scale": 0.5,
                "fps": 15.0,
                "enhance_prompt_flag": True,
                "num_inference_steps": 50
            }
        }
    })
if __name__ == '__main__':
    # Command-line entry point: parse host/port/debug and start Flask's
    # built-in development server.
    import argparse

    parser = argparse.ArgumentParser(description="Self-Forcing Video Generation Flask API")
    parser.add_argument('--port', type=int, default=5000, help="Port to run the API on")
    parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the API to")
    parser.add_argument('--debug', action='store_true', help="Run in debug mode")
    args = parser.parse_args()

    # Debug mode enables the Werkzeug interactive debugger, which can execute
    # arbitrary Python code from the browser. The default host binds all
    # interfaces, so warn loudly when both are combined.
    if args.debug and args.host not in ('127.0.0.1', 'localhost', '::1'):
        logger.warning(
            "Debug mode is enabled while listening on %s; the Werkzeug debugger "
            "can execute arbitrary code and must not be exposed publicly.",
            args.host,
        )

    logger.info(f"Starting Flask API on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=args.debug)