Prashant26am committed on
Commit e5d40e3 · 1 Parent(s): 8d272fe

fix: Update Gradio to 4.44.1 and improve interface

.github/workflows/ci.yml ADDED
@@ -0,0 +1,142 @@
+ name: CI
+
+ on:
+   push:
+     branches: [ main ]
+   pull_request:
+     branches: [ main ]
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         python-version: ["3.8", "3.9", "3.10", "3.11"]
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v5
+         with:
+           python-version: ${{ matrix.python-version }}
+           cache: 'pip'
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -r requirements.txt
+           pip install -r requirements-dev.txt
+
+       - name: Run pre-commit hooks
+         run: |
+           pre-commit install
+           pre-commit run --all-files
+
+       - name: Run tests
+         run: |
+           pytest --cov=src --cov-report=xml
+
+       - name: Upload coverage to Codecov
+         uses: codecov/codecov-action@v4
+         with:
+           file: ./coverage.xml
+           fail_ci_if_error: true
+
+   lint:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.11"
+           cache: 'pip'
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -r requirements-dev.txt
+
+       - name: Run black
+         run: black --check src tests
+
+       - name: Run isort
+         run: isort --check-only src tests
+
+       - name: Run flake8
+         run: flake8 src tests
+
+       - name: Run mypy
+         run: mypy src
+
+   build:
+     needs: [test, lint]
+     runs-on: ubuntu-latest
+     if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.11"
+           cache: 'pip'
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install build twine
+
+       - name: Build package
+         run: python -m build
+
+       - name: Check package
+         run: twine check dist/*
+
+       - name: Upload artifacts
+         uses: actions/upload-artifact@v4
+         with:
+           name: dist
+           path: dist/
+
+   deploy:
+     needs: build
+     runs-on: ubuntu-latest
+     if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Download artifacts
+         uses: actions/download-artifact@v4
+         with:
+           name: dist
+           path: dist
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.11"
+           cache: 'pip'
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install twine
+
+       - name: Deploy to PyPI
+         env:
+           TWINE_USERNAME: __token__
+           TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
+         run: twine upload dist/*
+
+       - name: Deploy to Hugging Face
+         env:
+           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+         run: |
+           pip install huggingface_hub
+           huggingface-cli login --token $HF_TOKEN
+           huggingface-cli upload Prashant26am/llava-chat dist/* --repo-type space
.pre-commit-config.yaml ADDED
@@ -0,0 +1,69 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.5.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+       - id: check-yaml
+       - id: check-added-large-files
+       - id: check-ast
+       - id: check-json
+       - id: check-merge-conflict
+       - id: detect-private-key
+       - id: debug-statements
+
+   - repo: https://github.com/psf/black
+     rev: 24.1.1
+     hooks:
+       - id: black
+         language_version: python3.8
+
+   - repo: https://github.com/pycqa/isort
+     rev: 5.13.2
+     hooks:
+       - id: isort
+         args: ["--profile", "black"]
+
+   - repo: https://github.com/pycqa/flake8
+     rev: 7.0.0
+     hooks:
+       - id: flake8
+         additional_dependencies:
+           - flake8-docstrings
+           - flake8-bugbear
+           - flake8-comprehensions
+           - flake8-simplify
+           - flake8-unused-arguments
+           - flake8-variables-names
+           - pep8-naming
+
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: v1.8.0
+     hooks:
+       - id: mypy
+         additional_dependencies:
+           - types-Pillow
+           - types-requests
+           - types-setuptools
+           - types-urllib3
+
+   - repo: https://github.com/asottile/pyupgrade
+     rev: v3.15.0
+     hooks:
+       - id: pyupgrade
+         args: [--py38-plus]
+
+   - repo: https://github.com/PyCQA/bandit
+     rev: 1.7.7
+     hooks:
+       - id: bandit
+         args: ["-c", "pyproject.toml"]
+         additional_dependencies: ["bandit[toml]"]
+
+   - repo: https://github.com/pre-commit/mirrors-prettier
+     rev: v4.0.0-alpha.8
+     hooks:
+       - id: prettier
+         types_or: [javascript, jsx, ts, tsx, json, css, scss, md, yaml, yml]
+         additional_dependencies:
+           - prettier@4.0.0-alpha.8
README.md CHANGED
@@ -1,52 +1,159 @@
- ---
- title: LLaVA Chat
- emoji: 🖼️
- colorFrom: blue
- colorTo: indigo
- sdk: gradio
- sdk_version: 3.50.2
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- # LLaVA Chat
-
- A powerful multimodal AI assistant that can understand and discuss images. Upload any image and chat with LLaVA about it!
-
- ## Features
-
- - 🖼️ Upload and analyze any image
- - 💬 Natural conversation about image content
- - ⚙️ Adjustable generation parameters
- - 🎯 High-quality image understanding
- - 🚀 Fast and responsive interface
-
- ## How to Use
-
- 1. Upload an image using the image uploader
- 2. Type your question or prompt about the image
- 3. (Optional) Adjust the generation parameters:
-    - Max New Tokens: Control response length
-    - Temperature: Adjust response creativity
-    - Top P: Fine-tune response diversity
- 4. Click "Generate Response" to get LLaVA's analysis
-
- ## Example Prompts
-
- - "What can you see in this image?"
- - "Describe this scene in detail"
- - "What emotions does this image convey?"
- - "What's happening in this picture?"
- - "Can you identify any objects or people in this image?"
-
- ## Model Details
-
- This Space uses the LLaVA (Large Language and Vision Assistant) model, which combines:
- - CLIP ViT-L/14 vision encoder
- - Vicuna-7B language model
- - Advanced multimodal understanding capabilities
-
- ## License
-
- This project is licensed under the MIT License.
+ # LLaVA Implementation
+
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+ [![Gradio](https://img.shields.io/badge/Gradio-4.44.1-orange.svg)](https://gradio.app/)
+ [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Prashant26am/llava-chat)
+
+ A modern implementation of LLaVA (Large Language and Vision Assistant) with a beautiful web interface. This project combines state-of-the-art vision and language models to create an interactive AI assistant that can understand and discuss images.
+
+ ## 🌟 Features
+
+ - **Modern Web Interface**
+   - Beautiful Gradio-based UI
+   - Real-time image analysis
+   - Interactive chat experience
+   - Responsive design
+
+ - **Advanced AI Capabilities**
+   - CLIP ViT-L/14 vision encoder
+   - Vicuna-7B language model
+   - Multimodal understanding
+   - Natural conversation flow
+
+ - **Developer Friendly**
+   - Clean, modular codebase
+   - Comprehensive documentation
+   - Easy deployment options
+   - Extensible architecture
+
+ ## 📋 Project Structure
+
+ ```
+ llava_implementation/
+ ├── src/                # Source code
+ │   ├── api/            # API endpoints and FastAPI app
+ │   ├── models/         # Model implementations
+ │   ├── utils/          # Utility functions
+ │   └── configs/        # Configuration files
+ ├── tests/              # Test suite
+ ├── docs/               # Documentation
+ │   ├── api/            # API documentation
+ │   ├── examples/       # Usage examples
+ │   └── guides/         # User and developer guides
+ ├── assets/             # Static assets
+ │   ├── images/         # Example images
+ │   └── icons/          # UI icons
+ ├── scripts/            # Utility scripts
+ └── examples/           # Example images for the web interface
+ ```
+
+ ## 🚀 Quick Start
+
+ ### Prerequisites
+
+ - Python 3.8+
+ - CUDA-capable GPU (recommended)
+ - Git
+
+ ### Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/Prashant-ambati/llava-implementation.git
+ cd llava-implementation
+ ```
+
+ 2. Create and activate a virtual environment:
+ ```bash
+ python -m venv venv
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+ ```
+
+ 3. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Running Locally
+
+ 1. Start the development server:
+ ```bash
+ python src/api/app.py
+ ```
+
+ 2. Open your browser and navigate to:
+ ```
+ http://localhost:7860
+ ```
+
+ ## 🌐 Web Deployment
+
+ ### Hugging Face Spaces
+
+ The application is deployed on Hugging Face Spaces:
+ - [Live Demo](https://huggingface.co/spaces/Prashant26am/llava-chat)
+ - Automatic deployment from the main branch
+ - Free GPU resources
+ - Public API access
+
+ ### Local Deployment
+
+ For local deployment:
+ ```bash
+ # Build the application
+ python -m build
+
+ # Run with production settings
+ python src/api/app.py --production
+ ```
+
+ ## 📚 Documentation
+
+ - [API Documentation](docs/api/README.md)
+ - [User Guide](docs/guides/user_guide.md)
+ - [Developer Guide](docs/guides/developer_guide.md)
+ - [Examples](docs/examples/README.md)
+
+ ## 🛠️ Development
+
+ ### Running Tests
+
+ ```bash
+ pytest tests/
+ ```
+
+ ### Code Style
+
+ This project follows PEP 8 guidelines. To check your code:
+
+ ```bash
+ flake8 src/
+ black src/
+ ```
+
+ ### Contributing
+
+ 1. Fork the repository
+ 2. Create a feature branch
+ 3. Commit your changes
+ 4. Push to the branch
+ 5. Create a Pull Request
+
+ ## 📝 License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## 🙏 Acknowledgments
+
+ - [LLaVA Paper](https://arxiv.org/abs/2304.08485) by Microsoft Research
+ - [Gradio](https://gradio.app/) for the web interface
+ - [Hugging Face](https://huggingface.co/) for model hosting
+ - [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) for the language model
+ - [CLIP](https://openai.com/research/clip) for the vision model
+
+ ## 📞 Contact
+
+ - GitHub Issues: [Report a bug](https://github.com/Prashant-ambati/llava-implementation/issues)
+ - Email: [Your Email]
+ - Twitter: [@YourTwitter]
app.py DELETED
@@ -1,148 +0,0 @@
- from fastapi import FastAPI, UploadFile, File, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse
- import os
- import tempfile
- from typing import Optional
- from pydantic import BaseModel
- import torch
- import gradio as gr
- from models.llava import LLaVA
-
- # Initialize model globally
- model = None
-
- def initialize_model():
-     global model
-     try:
-         model = LLaVA(
-             vision_model_path="openai/clip-vit-large-patch14-336",
-             language_model_path="lmsys/vicuna-7b-v1.5",
-             device="cuda" if torch.cuda.is_available() else "cpu",
-             load_in_8bit=True
-         )
-         print(f"Model initialized on {model.device}")
-         return True
-     except Exception as e:
-         print(f"Error initializing model: {e}")
-         return False
-
- def process_image(image, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
-     if not model:
-         return "Error: Model not initialized"
-
-     try:
-         # Save the uploaded image temporarily
-         with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
-             image.save(temp_file.name)
-             temp_path = temp_file.name
-
-         # Generate response
-         response = model.generate_from_image(
-             image_path=temp_path,
-             prompt=prompt,
-             max_new_tokens=max_new_tokens,
-             temperature=temperature,
-             top_p=top_p
-         )
-
-         # Clean up temporary file
-         os.unlink(temp_path)
-         return response
-
-     except Exception as e:
-         return f"Error processing image: {str(e)}"
-
- # Create Gradio interface
- def create_interface():
-     with gr.Blocks(title="LLaVA Chat", theme=gr.themes.Soft()) as demo:
-         gr.Markdown("""
-         # LLaVA Chat
-         Upload an image and chat with LLaVA about it. This model can understand and describe images, answer questions about them, and engage in visual conversations.
-         """)
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 image_input = gr.Image(type="pil", label="Upload Image")
-                 prompt_input = gr.Textbox(
-                     label="Ask about the image",
-                     placeholder="What can you see in this image?",
-                     lines=3
-                 )
-
-                 with gr.Accordion("Advanced Settings", open=False):
-                     max_tokens = gr.Slider(
-                         minimum=32,
-                         maximum=512,
-                         value=256,
-                         step=32,
-                         label="Max New Tokens"
-                     )
-                     temperature = gr.Slider(
-                         minimum=0.1,
-                         maximum=1.0,
-                         value=0.7,
-                         step=0.1,
-                         label="Temperature"
-                     )
-                     top_p = gr.Slider(
-                         minimum=0.1,
-                         maximum=1.0,
-                         value=0.9,
-                         step=0.1,
-                         label="Top P"
-                     )
-
-                 submit_btn = gr.Button("Generate Response", variant="primary")
-
-             with gr.Column(scale=1):
-                 output = gr.Textbox(
-                     label="Model Response",
-                     lines=10,
-                     show_copy_button=True
-                 )
-
-         # Set up the submit action
-         submit_btn.click(
-             fn=process_image,
-             inputs=[image_input, prompt_input, max_tokens, temperature, top_p],
-             outputs=output
-         )
-
-         # Add examples
-         gr.Examples(
-             examples=[
-                 ["examples/cat.jpg", "What can you see in this image?"],
-                 ["examples/landscape.jpg", "Describe this scene in detail."],
-                 ["examples/food.jpg", "What kind of food is this and how would you describe it?"]
-             ],
-             inputs=[image_input, prompt_input]
-         )
-
-     return demo
-
- # Create FastAPI app
- app = FastAPI(title="LLaVA Web Interface")
-
- # Configure CORS
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Create Gradio app
- demo = create_interface()
-
- # Mount Gradio app
- app = gr.mount_gradio_app(app, demo, path="/")
-
- if __name__ == "__main__":
-     # Initialize model
-     if initialize_model():
-         import uvicorn
-         uvicorn.run(app, host="0.0.0.0", port=7860)  # Hugging Face Spaces uses port 7860
-     else:
-         print("Failed to initialize model. Exiting...")
docs/api/README.md ADDED
@@ -0,0 +1,121 @@
+ # LLaVA API Documentation
+
+ ## Overview
+
+ The LLaVA API provides a simple interface for interacting with the LLaVA model through a Gradio web interface. The API allows users to upload images and receive AI-generated responses about the image content.
+
+ ## API Endpoints
+
+ ### Web Interface
+
+ The main interface is served at the root URL (`/`) and provides the following components:
+
+ #### Input Components
+
+ 1. **Image Upload**
+    - Type: Image uploader
+    - Format: PIL Image
+    - Purpose: Upload an image for analysis
+
+ 2. **Prompt Input**
+    - Type: Text input
+    - Purpose: Enter questions or prompts about the image
+    - Default placeholder: "What can you see in this image?"
+
+ 3. **Generation Parameters**
+    - Max New Tokens (64-2048, default: 512)
+    - Temperature (0.1-1.0, default: 0.7)
+    - Top P (0.1-1.0, default: 0.9)
+
+ #### Output Components
+
+ 1. **Response**
+    - Type: Text output
+    - Purpose: Displays the model's response
+    - Features: Copy button, scrollable
+
+ ## Usage Examples
+
+ ### Basic Usage
+
+ 1. Upload an image using the image uploader
+ 2. Enter a prompt in the text input
+ 3. Click "Generate Response"
+ 4. View the response in the output box
+
+ ### Example Prompts
+
+ - "What can you see in this image?"
+ - "Describe this scene in detail"
+ - "What emotions does this image convey?"
+ - "What's happening in this picture?"
+ - "Can you identify any objects or people in this image?"
+
+ ## Error Handling
+
+ The API handles various error cases:
+
+ 1. **Invalid Images**
+    - Returns an error message if the image is invalid or corrupted
+    - Supports common image formats (JPEG, PNG, etc.)
+
+ 2. **Empty Prompts**
+    - Returns an error message if no prompt is provided
+    - Prompts should be non-empty strings
+
+ 3. **Model Errors**
+    - Returns descriptive error messages for model-related issues
+    - Includes logging for debugging
+
+ ## Configuration
+
+ The API can be configured through environment variables or the settings file:
+
+ - `API_HOST`: Server host (default: "0.0.0.0")
+ - `API_PORT`: Server port (default: 7860)
+ - `GRADIO_THEME`: Interface theme (default: "soft")
+ - `DEFAULT_MAX_NEW_TOKENS`: Default token limit (default: 512)
+ - `DEFAULT_TEMPERATURE`: Default temperature (default: 0.7)
+ - `DEFAULT_TOP_P`: Default top-p value (default: 0.9)
+
+ ## Development
+
+ ### Running Locally
+
+ ```bash
+ python src/api/app.py
+ ```
+
+ ### Running Tests
+
+ ```bash
+ pytest tests/
+ ```
+
+ ### Code Style
+
+ The project follows PEP 8 guidelines. To check your code:
+
+ ```bash
+ flake8 src/
+ black src/
+ ```
+
+ ## Security Considerations
+
+ 1. The API is designed for public use but should be deployed behind appropriate security measures
+ 2. Input validation is performed on all user inputs
+ 3. Large file uploads are handled safely
+ 4. Error messages are sanitized to prevent information leakage
+
+ ## Rate Limiting
+
+ Currently, no rate limiting is implemented. Consider implementing rate limiting for production deployments.
+
+ ## Future Improvements
+
+ 1. Add authentication
+ 2. Implement rate limiting
+ 3. Add batch processing capabilities
+ 4. Support for video input
+ 5. Real-time streaming responses
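Note on the configuration table above: it describes environment-variable overrides, while `src/configs/settings.py` in this commit hardcodes its values. A minimal sketch of how such overrides could be wired in, assuming variable names matching the table (the `env` helper itself is illustrative, not part of the commit):

```python
# Sketch: environment-variable overrides for the documented settings.
import os

def env(name, default, cast=str):
    """Read NAME from the environment, falling back to a default."""
    raw = os.environ.get(name)
    return default if raw is None else cast(raw)

API_HOST = env("API_HOST", "0.0.0.0")
API_PORT = env("API_PORT", 7860, int)
GRADIO_THEME = env("GRADIO_THEME", "soft")
DEFAULT_MAX_NEW_TOKENS = env("DEFAULT_MAX_NEW_TOKENS", 512, int)
DEFAULT_TEMPERATURE = env("DEFAULT_TEMPERATURE", 0.7, float)
DEFAULT_TOP_P = env("DEFAULT_TOP_P", 0.9, float)
```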
docs/guides/developer_guide.md ADDED
@@ -0,0 +1,362 @@
+ # LLaVA Implementation Developer Guide
+
+ ## Overview
+
+ This guide is intended for developers who want to contribute to or extend the LLaVA implementation. The project is structured as a Python package with a Gradio web interface, using modern best practices and tools.
+
+ ## Project Structure
+
+ ```
+ llava_implementation/
+ ├── src/                       # Source code
+ │   ├── api/                   # API endpoints and FastAPI app
+ │   │   ├── __init__.py
+ │   │   └── app.py             # Gradio interface
+ │   ├── models/                # Model implementations
+ │   │   ├── __init__.py
+ │   │   └── llava_model.py     # LLaVA model wrapper
+ │   ├── utils/                 # Utility functions
+ │   │   ├── __init__.py
+ │   │   └── logging.py         # Logging utilities
+ │   └── configs/               # Configuration files
+ │       ├── __init__.py
+ │       └── settings.py        # Application settings
+ ├── tests/                     # Test suite
+ │   ├── __init__.py
+ │   └── test_model.py          # Model tests
+ ├── docs/                      # Documentation
+ │   ├── api/                   # API documentation
+ │   ├── examples/              # Usage examples
+ │   └── guides/                # User and developer guides
+ ├── assets/                    # Static assets
+ │   ├── images/                # Example images
+ │   └── icons/                 # UI icons
+ ├── scripts/                   # Utility scripts
+ └── examples/                  # Example images for the web interface
+ ```
+
+ ## Development Setup
+
+ ### Prerequisites
+
+ - Python 3.8+
+ - Git
+ - CUDA-capable GPU (recommended)
+ - Virtual environment tool (venv, conda, etc.)
+
+ ### Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/Prashant-ambati/llava-implementation.git
+ cd llava-implementation
+ ```
+
+ 2. Create and activate a virtual environment:
+ ```bash
+ python -m venv venv
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+ ```
+
+ 3. Install development dependencies:
+ ```bash
+ pip install -r requirements.txt
+ pip install -r requirements-dev.txt  # Development dependencies
+ ```
+
+ ### Development Tools
+
+ 1. **Code Formatting**
+    - Black for code formatting
+    - isort for import sorting
+    - flake8 for linting
+
+ 2. **Testing**
+    - pytest for testing
+    - pytest-cov for coverage
+    - pytest-mock for mocking
+
+ 3. **Type Checking**
+    - mypy for static type checking
+    - types-* packages for type hints
+
+ ## Code Style
+
+ ### Python Style Guide
+
+ 1. Follow PEP 8 guidelines
+ 2. Use type hints
+ 3. Write docstrings (Google style)
+ 4. Keep functions focused and small
+ 5. Use meaningful variable names
+
+ ### Example
+
+ ```python
+ from typing import Optional, List
+ from PIL import Image
+
+ def process_image(
+     image: Image.Image,
+     prompt: str,
+     max_tokens: Optional[int] = None
+ ) -> List[str]:
+     """
+     Process an image with the given prompt.
+
+     Args:
+         image: Input image as PIL Image
+         prompt: Text prompt for the model
+         max_tokens: Optional maximum tokens to generate
+
+     Returns:
+         List of generated responses
+
+     Raises:
+         ValueError: If image is invalid
+         RuntimeError: If model fails to process
+     """
+     # Implementation
+ ```
+
+ ## Testing
+
+ ### Running Tests
+
+ ```bash
+ # Run all tests
+ pytest
+
+ # Run with coverage
+ pytest --cov=src
+
+ # Run specific test file
+ pytest tests/test_model.py
+
+ # Run with verbose output
+ pytest -v
+ ```
+
+ ### Writing Tests
+
+ 1. Use pytest fixtures
+ 2. Mock external dependencies
+ 3. Test edge cases
+ 4. Include both unit and integration tests
+
+ Example test:
+ ```python
+ import pytest
+ from PIL import Image
+
+ def test_process_image(model, sample_image):
+     """Test image processing functionality."""
+     prompt = "What color is this image?"
+     response = model.process_image(
+         image=sample_image,
+         prompt=prompt
+     )
+     assert isinstance(response, str)
+     assert len(response) > 0
+ ```
+
+ ## Model Development
+
+ ### Adding New Models
+
+ 1. Create a new model class in `src/models/`
+ 2. Implement required methods
+ 3. Add tests
+ 4. Update documentation
+
+ Example:
+ ```python
+ class NewModel:
+     """New model implementation."""
+
+     def __init__(self, config: dict):
+         """Initialize the model."""
+         self.config = config
+         self.model = self._load_model()
+
+     def process(self, *args, **kwargs):
+         """Process inputs and generate output."""
+         pass
+ ```
+
+ ### Model Configuration
+
+ 1. Add configuration in `src/configs/settings.py`
+ 2. Use environment variables for secrets
+ 3. Document all parameters
+
+ ## API Development
+
+ ### Adding New Endpoints
+
+ 1. Create new endpoint in `src/api/app.py`
+ 2. Add input validation
+ 3. Implement error handling
+ 4. Add tests
+ 5. Update documentation
+
+ ### Error Handling
+
+ 1. Use custom exceptions
+ 2. Implement proper logging
+ 3. Return appropriate status codes
+ 4. Include error messages
+
+ Example:
+ ```python
+ class ModelError(Exception):
+     """Base exception for model errors."""
+     pass
+
+ def process_request(request):
+     try:
+         result = model.process(request)
+         return result
+     except ModelError as e:
+         logger.error(f"Model error: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+ ```
+
+ ## Deployment
+
+ ### Local Deployment
+
+ 1. Build the package:
+ ```bash
+ python -m build
+ ```
+
+ 2. Run the server:
+ ```bash
+ python src/api/app.py
+ ```
+
+ ### Hugging Face Spaces
+
+ 1. Update `README.md` with Space metadata
+ 2. Ensure all dependencies are in `requirements.txt`
+ 3. Test the Space locally
+ 4. Push changes to the Space
+
+ ### Production Deployment
+
+ 1. Set up proper logging
+ 2. Configure security measures
+ 3. Implement rate limiting
+ 4. Set up monitoring
+ 5. Use environment variables
+
+ ## Contributing
+
+ ### Workflow
+
+ 1. Fork the repository
+ 2. Create a feature branch
+ 3. Make changes
+ 4. Run tests
+ 5. Update documentation
+ 6. Create a pull request
+
+ ### Pull Request Process
+
+ 1. Update documentation
+ 2. Add tests
+ 3. Ensure CI passes
+ 4. Get code review
+ 5. Address feedback
+ 6. Merge when approved
+
+ ## Performance Optimization
+
+ ### Model Optimization
+
+ 1. Use model quantization
+ 2. Implement caching
+ 3. Batch processing
+ 4. GPU optimization
+
+ ### API Optimization
+
+ 1. Response compression
+ 2. Request validation
+ 3. Connection pooling
+ 4. Caching strategies
+
+ ## Security
+
+ ### Best Practices
+
+ 1. Input validation
+ 2. Error handling
+ 3. Rate limiting
+ 4. Secure configuration
+ 5. Regular updates
+
+ ### Security Checklist
+
+ - [ ] Validate all inputs
+ - [ ] Sanitize outputs
+ - [ ] Use secure dependencies
+ - [ ] Implement rate limiting
+ - [ ] Set up monitoring
+ - [ ] Regular security audits
+
+ ## Monitoring and Logging
+
+ ### Logging
+
+ 1. Use structured logging
+ 2. Include context
+ 3. Set appropriate levels
+ 4. Rotate logs
+
+ ### Monitoring
+
+ 1. Track key metrics
+ 2. Set up alerts
+ 3. Monitor resources
+ 4. Track errors
+
+ ## Future Development
+
+ ### Planned Features
+
+ 1. Video support
+ 2. Batch processing
+ 3. Model fine-tuning
+ 4. API authentication
+ 5. Advanced caching
+
+ ### Contributing Ideas
+
+ 1. Open issues
+ 2. Discuss in PRs
+ 3. Join discussions
+ 4. Share use cases
+
+ ## Resources
+
+ ### Documentation
+
+ - [Python Documentation](https://docs.python.org/)
+ - [Gradio Documentation](https://gradio.app/docs/)
+ - [Hugging Face Docs](https://huggingface.co/docs)
+ - [Pytest Documentation](https://docs.pytest.org/)
+
+ ### Tools
+
+ - [Black](https://black.readthedocs.io/)
+ - [isort](https://pycqa.github.io/isort/)
+ - [flake8](https://flake8.pycqa.org/)
+ - [mypy](https://mypy.readthedocs.io/)
+
+ ### Community
+
+ - [GitHub Issues](https://github.com/Prashant-ambati/llava-implementation/issues)
+ - [Hugging Face Forums](https://discuss.huggingface.co/)
+ - [Stack Overflow](https://stackoverflow.com/)
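The guide's "Performance Optimization" section lists caching among the strategies. One possible shape for response memoization, sketched under the assumption that the model is callable as `model(image=..., prompt=...)` like the `LLaVAModel` wrapper; the `CachedModel` class is illustrative, not part of this commit:

```python
# Sketch: memoizing repeated (image, prompt) requests.
import hashlib
from functools import lru_cache

from PIL import Image

def _image_key(image: Image.Image) -> str:
    """Stable hash of an image's pixel data."""
    return hashlib.sha256(image.tobytes()).hexdigest()

class CachedModel:
    def __init__(self, model, maxsize: int = 128):
        self._model = model
        # lru_cache needs hashable arguments, so we cache on (image_key, prompt)
        # and keep the PIL image in a side table for the actual call.
        self._images = {}
        self._cached = lru_cache(maxsize=maxsize)(self._generate)

    def _generate(self, image_key: str, prompt: str) -> str:
        return self._model(image=self._images[image_key], prompt=prompt)

    def __call__(self, image: Image.Image, prompt: str) -> str:
        key = _image_key(image)
        self._images[key] = image
        return self._cached(key, prompt)
```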
docs/guides/user_guide.md ADDED
@@ -0,0 +1,164 @@
+ # LLaVA Chat User Guide
+
+ ## Introduction
+
+ Welcome to LLaVA Chat! This guide will help you get started with using our AI-powered image understanding and chat interface. LLaVA (Large Language and Vision Assistant) combines advanced vision and language models to provide detailed analysis and natural conversations about images.
+
+ ## Getting Started
+
+ ### Accessing the Interface
+
+ 1. Visit our [Hugging Face Space](https://huggingface.co/spaces/Prashant26am/llava-chat)
+ 2. Wait for the interface to load (this may take a few moments as the model initializes)
+ 3. You're ready to start chatting with images!
+
+ ### Basic Usage
+
+ 1. **Upload an Image**
+    - Click the image upload area or drag and drop an image
+    - Supported formats: JPEG, PNG, GIF
+    - Maximum file size: 10MB
+
+ 2. **Enter Your Prompt**
+    - Type your question or prompt in the text box
+    - Be specific about what you want to know
+    - You can ask multiple questions about the same image
+
+ 3. **Adjust Parameters** (Optional)
+    - Click "Generation Parameters" to expand
+    - Modify settings to control the response:
+      - Max New Tokens: Longer responses (64-2048)
+      - Temperature: More creative responses (0.1-1.0)
+      - Top P: More diverse responses (0.1-1.0)
+
+ 4. **Generate Response**
+    - Click the "Generate Response" button
+    - Wait for the model to process (usually a few seconds)
+    - Read the response in the output box
+    - Use the copy button to save the response
+
+ ## Best Practices
+
+ ### Writing Effective Prompts
+
+ 1. **Be Specific**
+    - Instead of "What's in this image?", try "What objects can you identify in this image?"
+    - Instead of "Describe this", try "Describe the scene, focusing on the main subject"
+
+ 2. **Ask Follow-up Questions**
+    - "What emotions does this image convey?"
+    - "Can you identify any specific details about [object]?"
+    - "How would you describe the composition of this image?"
+
+ 3. **Use Natural Language**
+    - Write as if you're talking to a person
+    - Feel free to ask for clarification or more details
+    - You can have a conversation about the image
+
+ ### Example Prompts
+
+ 1. **General Analysis**
+    - "What can you see in this image?"
+    - "Describe this scene in detail"
+    - "What's the main subject of this image?"
+
+ 2. **Specific Details**
+    - "What colors are prominent in this image?"
+    - "Can you identify any text or signs in the image?"
+    - "What time of day does this image appear to be taken?"
+
+ 3. **Emotional Response**
+    - "What mood or atmosphere does this image convey?"
+    - "How does this image make you feel?"
+    - "What emotions might this image evoke in viewers?"
+
+ 4. **Technical Analysis**
+    - "What's the composition of this image?"
+    - "How would you describe the lighting in this image?"
+    - "What camera angle or perspective is used?"
+
+ ## Troubleshooting
+
+ ### Common Issues
+
+ 1. **Image Not Loading**
+    - Check the file format (JPEG, PNG, GIF)
+    - Ensure the file size is under 10MB
+    - Try refreshing the page
+
+ 2. **Slow Response**
+    - Reduce the image size
+    - Simplify your prompt
+    - Check your internet connection
+
+ 3. **Unexpected Responses**
+    - Try rephrasing your prompt
+    - Adjust the generation parameters
+    - Be more specific in your question
+
+ ### Getting Help
+
+ If you encounter any issues:
+ 1. Check this guide for solutions
+ 2. Visit our [GitHub repository](https://github.com/Prashant-ambati/llava-implementation)
+ 3. Open an issue on GitHub
+ 4. Contact us through Hugging Face
+
+ ## Advanced Usage
+
+ ### Parameter Tuning
+
+ 1. **Max New Tokens**
+    - Lower values (64-256): Short, concise responses
+    - Medium values (256-512): Balanced responses
+    - Higher values (512+): Detailed, comprehensive responses
+
+ 2. **Temperature**
+    - Lower values (0.1-0.3): More focused, deterministic responses
+    - Medium values (0.4-0.7): Balanced creativity
+    - Higher values (0.8-1.0): More creative, diverse responses
+
+ 3. **Top P**
+    - Lower values (0.1-0.3): More focused word choice
+    - Medium values (0.4-0.7): Balanced diversity
+    - Higher values (0.8-1.0): More diverse word choice
+
+ ### Tips for Better Results
+
+ 1. **Image Quality**
+    - Use clear, well-lit images
+    - Ensure the subject is clearly visible
+    - Avoid heavily edited or filtered images
+
+ 2. **Prompt Engineering**
+    - Start with simple questions
+    - Build up to more complex queries
+    - Use follow-up questions for details
+
+ 3. **Response Management**
+    - Copy important responses
+    - Save interesting conversations
+    - Compare responses with different parameters
+
+ ## Privacy and Ethics
+
+ 1. **Image Privacy**
+    - Don't upload sensitive or private images
+    - Be mindful of copyright
+    - Respect others' privacy
+
+ 2. **Responsible Use**
+    - Use the tool ethically
+    - Don't use it for harmful purposes
+    - Respect content guidelines
+
+ ## Future Updates
+
+ We're constantly improving LLaVA Chat. Planned features include:
+ 1. Support for video input
+ 2. Batch image processing
+ 3. More advanced parameter controls
+ 4. Additional model options
+ 5. Enhanced response formatting
+
+ Stay tuned for updates!
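For scripted use, the three tuning profiles described under "Parameter Tuning" can be captured as presets. A sketch, assuming a `process_image(image, prompt, **params)` call shaped like the one in `src/api/app.py`; the preset names are illustrative:

```python
# Sketch: the "focused / balanced / creative" profiles from the guide
# expressed as reusable generation presets.
PRESETS = {
    "concise":  {"max_new_tokens": 128,  "temperature": 0.2, "top_p": 0.3},
    "balanced": {"max_new_tokens": 512,  "temperature": 0.7, "top_p": 0.9},
    "detailed": {"max_new_tokens": 1024, "temperature": 0.9, "top_p": 0.95},
}

# e.g. response = process_image(image, prompt, **PRESETS["balanced"])
```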
examples/api_client.py ADDED
@@ -0,0 +1,127 @@
+ """
+ Example API client for the LLaVA model.
+ """
+
+ import argparse
+ import json
+ from pathlib import Path
+ from typing import Dict, Any, Optional
+
+ import requests
+ from PIL import Image
+ import base64
+ from io import BytesIO
+
+ def encode_image(image_path: str) -> str:
+     """
+     Encode an image to base64 string.
+
+     Args:
+         image_path: Path to the image file
+
+     Returns:
+         str: Base64 encoded image
+     """
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode('utf-8')
+
+ def process_image(
+     api_url: str,
+     image_path: str,
+     prompt: str,
+     max_new_tokens: Optional[int] = None,
+     temperature: Optional[float] = None,
+     top_p: Optional[float] = None
+ ) -> Dict[str, Any]:
+     """
+     Process an image using the LLaVA API.
+
+     Args:
+         api_url: URL of the API endpoint
+         image_path: Path to the input image
+         prompt: Text prompt for the model
+         max_new_tokens: Optional maximum tokens to generate
+         temperature: Optional sampling temperature
+         top_p: Optional top-p sampling parameter
+
+     Returns:
+         Dict containing the API response
+     """
+     # Prepare the request payload
+     payload = {
+         "image": encode_image(image_path),
+         "prompt": prompt
+     }
+
+     # Add optional parameters if provided
+     if max_new_tokens is not None:
+         payload["max_new_tokens"] = max_new_tokens
+     if temperature is not None:
+         payload["temperature"] = temperature
+     if top_p is not None:
+         payload["top_p"] = top_p
+
+     try:
+         # Send the request
+         response = requests.post(api_url, json=payload)
+         response.raise_for_status()
+         return response.json()
+
+     except requests.exceptions.RequestException as e:
+         print(f"Error making request: {e}")
+         if hasattr(e.response, 'text'):
+             print(f"Response: {e.response.text}")
+         raise
+
+ def save_response(response: Dict[str, Any], output_path: Optional[str] = None):
+     """
+     Save or print the API response.
+
+     Args:
+         response: API response dictionary
+         output_path: Optional path to save the response
+     """
+     if output_path:
+         with open(output_path, 'w') as f:
+             json.dump(response, f, indent=2)
+         print(f"Saved response to {output_path}")
+     else:
+         print("\nAPI Response:")
+         print("-" * 50)
+         print(json.dumps(response, indent=2))
+         print("-" * 50)
+
+ def main():
+     """Main function to process images using the API."""
+     parser = argparse.ArgumentParser(description="Process images using LLaVA API")
+     parser.add_argument("image_path", type=str, help="Path to the input image")
+     parser.add_argument("prompt", type=str, help="Text prompt for the model")
+     parser.add_argument("--api-url", type=str, default="http://localhost:7860/api/process",
+                         help="URL of the API endpoint")
+     parser.add_argument("--max-tokens", type=int, help="Maximum tokens to generate")
+     parser.add_argument("--temperature", type=float, help="Sampling temperature")
+     parser.add_argument("--top-p", type=float, help="Top-p sampling parameter")
+     parser.add_argument("--output", type=str, help="Path to save the response")
+
+     args = parser.parse_args()
+
+     try:
+         # Process image
+         response = process_image(
+             api_url=args.api_url,
+             image_path=args.image_path,
+             prompt=args.prompt,
+             max_new_tokens=args.max_tokens,
+             temperature=args.temperature,
+             top_p=args.top_p
+         )
+
+         # Save or print response
+         save_response(response, args.output)
+
+     except Exception as e:
+         print(f"Error: {str(e)}")
+         raise
+
+ if __name__ == "__main__":
+     main()
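The client above can also be driven from Python rather than the CLI. A usage sketch, assuming a server that actually exposes the `/api/process` endpoint this script targets and that the `examples/` directory is on `sys.path`:

```python
# Sketch: programmatic use of the example client.
from api_client import process_image, save_response  # examples/ on sys.path

response = process_image(
    api_url="http://localhost:7860/api/process",  # endpoint assumed to exist
    image_path="examples/cat.jpg",                # illustrative path
    prompt="What can you see in this image?",
    max_new_tokens=256,
)
save_response(response)  # prints the JSON payload to stdout
```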
examples/llava_demo.ipynb ADDED
@@ -0,0 +1 @@
+
examples/process_image.py ADDED
@@ -0,0 +1,103 @@
+ """
+ Example script for processing images with the LLaVA model.
+ """
+
+ import argparse
+ from pathlib import Path
+ from PIL import Image
+
+ from src.models.llava_model import LLaVAModel
+ from src.configs.settings import DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
+ from src.utils.logging import setup_logging, get_logger
+
+ # Set up logging
+ setup_logging()
+ logger = get_logger(__name__)
+
+ def process_image(
+     image_path: str,
+     prompt: str,
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+     temperature: float = DEFAULT_TEMPERATURE,
+     top_p: float = DEFAULT_TOP_P
+ ) -> str:
+     """
+     Process an image with the LLaVA model.
+
+     Args:
+         image_path: Path to the input image
+         prompt: Text prompt for the model
+         max_new_tokens: Maximum number of tokens to generate
+         temperature: Sampling temperature
+         top_p: Top-p sampling parameter
+
+     Returns:
+         str: Model response
+     """
+     try:
+         # Load image
+         image = Image.open(image_path)
+         logger.info(f"Loaded image from {image_path}")
+
+         # Initialize model
+         model = LLaVAModel()
+         logger.info("Model initialized")
+
+         # Generate response
+         response = model(
+             image=image,
+             prompt=prompt,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p
+         )
+         logger.info("Generated response")
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error processing image: {str(e)}")
+         raise
+
+ def main():
+     """Main function to process images from command line."""
+     parser = argparse.ArgumentParser(description="Process images with LLaVA model")
+     parser.add_argument("image_path", type=str, help="Path to the input image")
+     parser.add_argument("prompt", type=str, help="Text prompt for the model")
+     parser.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS,
+                         help="Maximum number of tokens to generate")
+     parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE,
+                         help="Sampling temperature")
+     parser.add_argument("--top-p", type=float, default=DEFAULT_TOP_P,
+                         help="Top-p sampling parameter")
+     parser.add_argument("--output", type=str, help="Path to save the response")
+
+     args = parser.parse_args()
+
+     try:
+         # Process image
+         response = process_image(
+             image_path=args.image_path,
+             prompt=args.prompt,
+             max_new_tokens=args.max_tokens,
+             temperature=args.temperature,
+             top_p=args.top_p
+         )
+
+         # Print or save response
+         if args.output:
+             output_path = Path(args.output)
+             output_path.write_text(response)
+             logger.info(f"Saved response to {output_path}")
+         else:
+             print("\nModel Response:")
+             print("-" * 50)
+             print(response)
+             print("-" * 50)
+
+     except Exception as e:
+         logger.error(f"Error: {str(e)}")
+         raise
+
+ if __name__ == "__main__":
+     main()
pyproject.toml ADDED
@@ -0,0 +1,181 @@
+ [build-system]
+ requires = ["setuptools>=61.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "llava-implementation"
+ version = "0.1.0"
+ description = "A modern implementation of LLaVA with a beautiful web interface"
+ readme = "README.md"
+ requires-python = ">=3.8"
+ license = {text = "MIT"}
+ authors = [
+     {name = "Prashant Ambati", email = "your.email@example.com"}
+ ]
+ classifiers = [
+     "Development Status :: 4 - Beta",
+     "Intended Audience :: Developers",
+     "Intended Audience :: Science/Research",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.8",
+     "Programming Language :: Python :: 3.9",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     "Topic :: Software Development :: Libraries :: Python Modules",
+ ]
+ dependencies = [
+     "torch>=2.0.0",
+     "torchvision>=0.15.0",
+     "transformers>=4.36.0",
+     "accelerate>=0.25.0",
+     "pillow>=10.0.0",
+     "numpy>=1.24.0",
+     "tqdm>=4.65.0",
+     "matplotlib>=3.7.0",
+     "opencv-python>=4.8.0",
+     "einops>=0.7.0",
+     "timm>=0.9.0",
+     "sentencepiece>=0.1.99",
+     "peft>=0.7.0",
+     "bitsandbytes>=0.41.0",
+     "safetensors>=0.4.0",
+     "gradio==4.44.1",
+     "fastapi>=0.109.0",
+     "uvicorn>=0.27.0",
+     "python-multipart>=0.0.6",
+     "pydantic>=2.5.0",
+     "python-jose>=3.3.0",
+     "passlib>=1.7.4",
+     "bcrypt>=4.0.1",
+     "aiofiles>=23.2.0",
+     "httpx>=0.26.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.1.0",
+     "pytest-mock>=3.12.0",
+     "pytest-asyncio>=0.23.5",
+     "pytest-xdist>=3.5.0",
+     "black>=24.1.1",
+     "isort>=5.13.2",
+     "flake8>=7.0.0",
+     "mypy>=1.8.0",
+     "types-Pillow>=10.2.0.20240106",
+     "types-requests>=2.31.0.20240125",
+     "sphinx>=7.2.6",
+     "sphinx-rtd-theme>=2.0.0",
+     "sphinx-autodoc-typehints>=2.0.1",
+     "sphinx-copybutton>=0.5.2",
+     "sphinx-tabs>=3.4.4",
+     "pre-commit>=3.6.0",
+     "ipython>=8.21.0",
+     "jupyter>=1.0.0",
+     "notebook>=7.0.7",
+     "ipykernel>=6.29.0",
+     "build>=1.0.3",
+     "twine>=4.0.2",
+     "wheel>=0.42.0",
+     "memory-profiler>=0.61.0",
+     "line-profiler>=4.1.2",
+     "debugpy>=1.8.0",
+ ]
+
+ [project.urls]
+ Homepage = "https://github.com/Prashant-ambati/llava-implementation"
+ Documentation = "https://github.com/Prashant-ambati/llava-implementation#readme"
+ Repository = "https://github.com/Prashant-ambati/llava-implementation.git"
+ Issues = "https://github.com/Prashant-ambati/llava-implementation/issues"
+ "Bug Tracker" = "https://github.com/Prashant-ambati/llava-implementation/issues"
+
+ [tool.setuptools]
+ packages = ["src"]
+
+ [tool.black]
+ line-length = 88
+ target-version = ["py38"]
+ include = '\.pyi?$'
+
+ [tool.isort]
+ profile = "black"
+ multi_line_output = 3
+ include_trailing_comma = true
+ force_grid_wrap = 0
+ use_parentheses = true
+ ensure_newline_before_comments = true
+ line_length = 88
+
+ [tool.mypy]
+ python_version = "3.8"
+ warn_return_any = true
+ warn_unused_configs = true
+ disallow_untyped_defs = true
+ disallow_incomplete_defs = true
+ check_untyped_defs = true
+ disallow_untyped_decorators = true
+ no_implicit_optional = true
+ warn_redundant_casts = true
+ warn_unused_ignores = true
+ warn_no_return = true
+ warn_unreachable = true
+ strict_optional = true
+
+ [tool.pytest.ini_options]
+ minversion = "6.0"
+ addopts = "-ra -q --cov=src"
+ testpaths = [
+     "tests",
+ ]
+ python_files = ["test_*.py"]
+ python_classes = ["Test*"]
+ python_functions = ["test_*"]
+
+ [tool.coverage.run]
+ source = ["src"]
+ branch = true
+
+ [tool.coverage.report]
+ exclude_lines = [
+     "pragma: no cover",
+     "def __repr__",
+     "if self.debug:",
+     "raise NotImplementedError",
+     "if __name__ == .__main__.:",
+     "pass",
+     "raise ImportError",
+ ]
+ show_missing = true
+ fail_under = 80
+
+ [tool.bandit]
+ exclude_dirs = ["tests", "docs"]
+ skips = ["B101"]
+
+ [tool.ruff]
+ line-length = 88
+ target-version = "py38"
+ select = [
+     "E",   # pycodestyle errors
+     "W",   # pycodestyle warnings
+     "F",   # pyflakes
+     "I",   # isort
+     "B",   # flake8-bugbear
+     "C4",  # flake8-comprehensions
+     "UP",  # pyupgrade
+     "N",   # pep8-naming
+     "PL",  # pylint
+     "RUF", # ruff-specific rules
+ ]
+ ignore = [
+     "E501", # line length violations
+     "B008", # do not perform function calls in argument defaults
+ ]
+
+ [tool.ruff.isort]
+ known-first-party = ["src"]
+
+ [tool.ruff.mccabe]
+ max-complexity = 10
requirements-dev.txt ADDED
@@ -0,0 +1,38 @@
+ # Testing
+ pytest==8.0.0
+ pytest-cov==4.1.0
+ pytest-mock==3.12.0
+ pytest-asyncio==0.23.5
+ pytest-xdist==3.5.0
+
+ # Code Quality
+ black==24.1.1
+ isort==5.13.2
+ flake8==7.0.0
+ mypy==1.8.0
+ types-Pillow==10.2.0.20240106
+ types-requests==2.31.0.20240125
+
+ # Documentation
+ sphinx==7.2.6
+ sphinx-rtd-theme==2.0.0
+ sphinx-autodoc-typehints==2.0.1
+ sphinx-copybutton==0.5.2
+ sphinx-tabs==3.4.4
+
+ # Development Tools
+ pre-commit==3.6.0
+ ipython==8.21.0
+ jupyter==1.0.0
+ notebook==7.0.7
+ ipykernel==6.29.0
+
+ # Build Tools
+ build==1.0.3
+ twine==4.0.2
+ wheel==0.42.0
+
+ # Monitoring and Debugging
+ memory-profiler==0.61.0
+ line-profiler==4.1.2
+ debugpy==1.8.0
requirements.txt CHANGED
@@ -1,26 +1,25 @@
  torch>=2.0.0
  torchvision>=0.15.0
- transformers>=4.30.0
- accelerate>=0.20.0
- pillow>=9.0.0
+ transformers>=4.36.0
+ accelerate>=0.25.0
+ pillow>=10.0.0
  numpy>=1.24.0
  tqdm>=4.65.0
  matplotlib>=3.7.0
- opencv-python>=4.7.0
- einops>=0.6.0
+ opencv-python>=4.8.0
+ einops>=0.7.0
  timm>=0.9.0
  sentencepiece>=0.1.99
- gradio>=3.35.0
- peft>=0.4.0
- bitsandbytes>=0.40.0
- safetensors>=0.3.1
- fastapi==0.104.1
- uvicorn==0.24.0
- python-multipart==0.0.6
- pydantic==2.5.2
- python-jose==3.3.0
- passlib==1.7.4
- bcrypt==4.0.1
- aiofiles==23.2.1
- python-dotenv==1.0.0
- httpx==0.25.2
+ peft>=0.7.0
+ bitsandbytes>=0.41.0
+ safetensors>=0.4.0
+ gradio==4.44.1
+ fastapi>=0.109.0
+ uvicorn>=0.27.0
+ python-multipart>=0.0.6
+ pydantic>=2.5.0
+ python-jose>=3.3.0
+ passlib>=1.7.4
+ bcrypt>=4.0.1
+ aiofiles>=23.2.0
+ httpx>=0.26.0
src/__init__.py ADDED
File without changes
src/api/__init__.py ADDED
File without changes
src/api/app.py ADDED
@@ -0,0 +1,159 @@
+ """
+ Gradio interface for the LLaVA model.
+ """
+
+ import gradio as gr
+ from PIL import Image
+
+ from ..configs.settings import (
+     GRADIO_THEME,
+     GRADIO_TITLE,
+     GRADIO_DESCRIPTION,
+     DEFAULT_MAX_NEW_TOKENS,
+     DEFAULT_TEMPERATURE,
+     DEFAULT_TOP_P,
+     API_HOST,
+     API_PORT,
+     API_WORKERS,
+     API_RELOAD
+ )
+ from ..models.llava_model import LLaVAModel
+ from ..utils.logging import setup_logging, get_logger
+
+ # Set up logging
+ setup_logging()
+ logger = get_logger(__name__)
+
+ # Initialize model
+ model = LLaVAModel()
+
+ def process_image(
+     image: Image.Image,
+     prompt: str,
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+     temperature: float = DEFAULT_TEMPERATURE,
+     top_p: float = DEFAULT_TOP_P
+ ) -> str:
+     """
+     Process an image with the LLaVA model.
+
+     Args:
+         image: Input image
+         prompt: Text prompt
+         max_new_tokens: Maximum number of tokens to generate
+         temperature: Sampling temperature
+         top_p: Top-p sampling parameter
+
+     Returns:
+         str: Model response
+     """
+     try:
+         logger.info(f"Processing image with prompt: {prompt[:100]}...")
+         response = model(
+             image=image,
+             prompt=prompt,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p
+         )
+         logger.info("Successfully generated response")
+         return response
+     except Exception as e:
+         logger.error(f"Error processing image: {str(e)}")
+         return f"Error: {str(e)}"
+
+ def create_interface() -> gr.Blocks:
+     """Create and return the Gradio interface."""
+     with gr.Blocks(theme=GRADIO_THEME) as interface:
+         gr.Markdown(f"""# {GRADIO_TITLE}
+
+ {GRADIO_DESCRIPTION}
+
+ ## Example Prompts
+
+ Try these prompts to get started:
+ - "What can you see in this image?"
+ - "Describe this scene in detail"
+ - "What emotions does this image convey?"
+ - "What's happening in this picture?"
+ - "Can you identify any objects or people in this image?"
+
+ ## Usage Instructions
+
+ 1. Upload an image using the image uploader
+ 2. Enter your prompt in the text box
+ 3. (Optional) Adjust the generation parameters
+ 4. Click "Generate Response" to get LLaVA's analysis
+ """)
+
+         with gr.Row():
+             with gr.Column():
+                 # Input components
+                 image_input = gr.Image(type="pil", label="Upload Image")
+                 prompt_input = gr.Textbox(
+                     label="Prompt",
+                     placeholder="What can you see in this image?",
+                     lines=3
+                 )
+
+                 with gr.Accordion("Generation Parameters", open=False):
+                     max_tokens = gr.Slider(
+                         minimum=64,
+                         maximum=2048,
+                         value=DEFAULT_MAX_NEW_TOKENS,
+                         step=64,
+                         label="Max New Tokens"
+                     )
+                     temperature = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         value=DEFAULT_TEMPERATURE,
+                         step=0.1,
+                         label="Temperature"
+                     )
+                     top_p = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         value=DEFAULT_TOP_P,
+                         step=0.1,
+                         label="Top P"
+                     )
+
+                 generate_btn = gr.Button("Generate Response", variant="primary")
+
+             with gr.Column():
+                 # Output component
+                 output = gr.Textbox(
+                     label="Response",
+                     lines=10,
+                     show_copy_button=True
+                 )
+
+         # Set up event handlers
+         generate_btn.click(
+             fn=process_image,
+             inputs=[
+                 image_input,
+                 prompt_input,
+                 max_tokens,
+                 temperature,
+                 top_p
+             ],
+             outputs=output
+         )
+
+     return interface
+
+ def main():
+     """Run the Gradio interface."""
+     interface = create_interface()
+     interface.launch(
+         server_name=API_HOST,
+         server_port=API_PORT,
+         share=True,
+         show_error=True,
+         show_api=False
+     )
+
+ if __name__ == "__main__":
+     main()
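This module launches Gradio directly, whereas the deleted `app.py` wrapped the interface in FastAPI with CORS. If that wrapper is still wanted, Gradio 4.x can mount a Blocks app onto FastAPI; a sketch (module paths assume running from the repository root, and note that importing `src.api.app` loads the model at import time):

```python
# Sketch: restoring the FastAPI wrapper around the new Blocks interface.
from fastapi import FastAPI
import gradio as gr

from src.api.app import create_interface  # loads the model on import

app = FastAPI(title="LLaVA Web Interface")
app = gr.mount_gradio_app(app, create_interface(), path="/")

# run with: uvicorn this_module:app --host 0.0.0.0 --port 7860
```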
src/configs/__init__.py ADDED
File without changes
src/configs/settings.py ADDED
@@ -0,0 +1,46 @@
+ """
+ Configuration settings for the LLaVA implementation.
+ """
+
+ import os
+ from pathlib import Path
+
+ # Project paths
+ PROJECT_ROOT = Path(__file__).parent.parent.parent
+ SRC_DIR = PROJECT_ROOT / "src"
+ ASSETS_DIR = PROJECT_ROOT / "assets"
+ EXAMPLES_DIR = PROJECT_ROOT / "examples"
+
+ # Model settings
+ MODEL_NAME = "liuhaotian/llava-v1.5-7b"
+ MODEL_REVISION = "main"
+ DEVICE = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+
+ # Generation settings
+ DEFAULT_MAX_NEW_TOKENS = 512
+ DEFAULT_TEMPERATURE = 0.7
+ DEFAULT_TOP_P = 0.9
+
+ # API settings
+ API_HOST = "0.0.0.0"
+ API_PORT = 7860
+ API_WORKERS = 1
+ API_RELOAD = True
+
+ # Gradio settings
+ GRADIO_THEME = "soft"
+ GRADIO_TITLE = "LLaVA Chat"
+ GRADIO_DESCRIPTION = """
+ A powerful multimodal AI assistant that can understand and discuss images.
+ Upload any image and chat with LLaVA about it!
+ """
+
+ # Logging settings
+ LOG_LEVEL = "INFO"
+ LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ LOG_DIR = PROJECT_ROOT / "logs"
+ LOG_FILE = LOG_DIR / "app.log"
+
+ # Create necessary directories
+ for directory in [ASSETS_DIR, EXAMPLES_DIR, LOG_DIR]:
+     directory.mkdir(parents=True, exist_ok=True)
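`DEVICE` here is inferred from `CUDA_VISIBLE_DEVICES`, which is often unset even on machines that have a GPU. Asking torch directly is a more reliable check; a sketch of the alternative (not what this commit ships):

```python
# Sketch: device detection via torch rather than an environment variable.
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
```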
src/models/__init__.py ADDED
File without changes
src/models/llava_model.py ADDED
@@ -0,0 +1,88 @@
+ """
+ LLaVA model implementation.
+ """
+
+ import torch
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from PIL import Image
+
+ from ..configs.settings import MODEL_NAME, MODEL_REVISION, DEVICE
+ from ..utils.logging import get_logger
+
+ logger = get_logger(__name__)
+
+ class LLaVAModel:
+     """LLaVA model wrapper class."""
+
+     def __init__(self):
+         """Initialize the LLaVA model and processor."""
+         logger.info(f"Initializing LLaVA model from {MODEL_NAME}")
+         self.processor = AutoProcessor.from_pretrained(
+             MODEL_NAME,
+             revision=MODEL_REVISION,
+             trust_remote_code=True
+         )
+         self.model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             revision=MODEL_REVISION,
+             torch_dtype=torch.float16,
+             device_map="auto",
+             trust_remote_code=True
+         )
+         logger.info("Model initialization complete")
+
+     def generate_response(
+         self,
+         image: Image.Image,
+         prompt: str,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.9
+     ) -> str:
+         """
+         Generate a response for the given image and prompt.
+
+         Args:
+             image: Input image as PIL Image
+             prompt: Text prompt for the model
+             max_new_tokens: Maximum number of tokens to generate
+             temperature: Sampling temperature
+             top_p: Top-p sampling parameter
+
+         Returns:
+             str: Generated response
+         """
+         try:
+             # Prepare inputs
+             inputs = self.processor(
+                 prompt,
+                 image,
+                 return_tensors="pt"
+             ).to(DEVICE)
+
+             # Generate response
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     top_p=top_p,
+                     do_sample=True
+                 )
+
+             # Decode and return response
+             response = self.processor.decode(
+                 outputs[0],
+                 skip_special_tokens=True
+             )
+
+             logger.debug(f"Generated response: {response[:100]}...")
+             return response
+
+         except Exception as e:
+             logger.error(f"Error generating response: {str(e)}")
+             raise
+
+     def __call__(self, *args, **kwargs):
+         """Convenience method to call generate_response."""
+         return self.generate_response(*args, **kwargs)
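A usage sketch for the wrapper above; the image path is illustrative, and the first call downloads several gigabytes of model weights:

```python
# Sketch: intended usage of the LLaVAModel wrapper.
from PIL import Image

from src.models.llava_model import LLaVAModel

model = LLaVAModel()                       # loads processor + weights
image = Image.open("examples/cat.jpg")     # illustrative path
print(model(image=image, prompt="What can you see in this image?"))
```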
main.py → src/models/main.py RENAMED
File without changes
src/requirements.txt ADDED
@@ -0,0 +1,26 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ transformers>=4.30.0
+ accelerate>=0.20.0
+ pillow>=9.0.0
+ numpy>=1.24.0
+ tqdm>=4.65.0
+ matplotlib>=3.7.0
+ opencv-python>=4.7.0
+ einops>=0.6.0
+ timm>=0.9.0
+ sentencepiece>=0.1.99
+ gradio>=3.35.0
+ peft>=0.4.0
+ bitsandbytes>=0.40.0
+ safetensors>=0.3.1
+ fastapi==0.104.1
+ uvicorn==0.24.0
+ python-multipart==0.0.6
+ pydantic==2.5.2
+ python-jose==3.3.0
+ passlib==1.7.4
+ bcrypt==4.0.1
+ aiofiles==23.2.1
+ python-dotenv==1.0.0
+ httpx==0.25.2
src/utils/__init__.py ADDED
File without changes
src/utils/logging.py ADDED
@@ -0,0 +1,51 @@
+ """
+ Logging utilities for the LLaVA implementation.
+ """
+
+ import logging
+ import sys
+ from typing import Optional
+
+ from ..configs.settings import LOG_LEVEL, LOG_FORMAT, LOG_FILE
+
+ def setup_logging(name: Optional[str] = None) -> logging.Logger:
+     """
+     Set up logging configuration for the application.
+
+     Args:
+         name: Optional name for the logger. If None, configures the root logger.
+
+     Returns:
+         logging.Logger: Configured logger instance.
+     """
+     # Create logger
+     logger = logging.getLogger(name)
+     logger.setLevel(LOG_LEVEL)
+
+     # Create formatters
+     formatter = logging.Formatter(LOG_FORMAT)
+
+     # Create handlers
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setFormatter(formatter)
+
+     file_handler = logging.FileHandler(LOG_FILE)
+     file_handler.setFormatter(formatter)
+
+     # Add handlers to logger
+     logger.addHandler(console_handler)
+     logger.addHandler(file_handler)
+
+     return logger
+
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
+     """
+     Get a logger instance with the specified name.
+
+     Args:
+         name: Optional name for the logger. If None, returns the root logger.
+
+     Returns:
+         logging.Logger: Logger instance.
+     """
+     return logging.getLogger(name)
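Typical use of these helpers, as the rest of the package does it: configure once at startup (the default `name=None` configures the root logger), then fetch named loggers that propagate to the configured handlers:

```python
# Sketch: configure logging once, then use named loggers anywhere.
from src.utils.logging import setup_logging, get_logger

setup_logging()                  # root logger: console + logs/app.log
logger = get_logger(__name__)    # named logger inherits root handlers
logger.info("logging configured")
```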
tests/test_model.py ADDED
@@ -0,0 +1,67 @@
+ """
+ Tests for the LLaVA model implementation.
+ """
+
+ import pytest
+ from PIL import Image
+ import torch
+
+ from src.models.llava_model import LLaVAModel
+ from src.configs.settings import DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
+
+ @pytest.fixture
+ def model():
+     """Fixture to provide a model instance."""
+     return LLaVAModel()
+
+ @pytest.fixture
+ def sample_image():
+     """Fixture to provide a sample image."""
+     # Create a simple test image
+     return Image.new('RGB', (224, 224), color='red')
+
+ def test_model_initialization(model):
+     """Test that the model initializes correctly."""
+     assert model is not None
+     assert model.processor is not None
+     assert model.model is not None
+
+ def test_model_device(model):
+     """Test that the model is on the correct device."""
+     assert next(model.model.parameters()).device.type in ['cuda', 'cpu']
+
+ def test_generate_response(model, sample_image):
+     """Test that the model can generate responses."""
+     prompt = "What color is this image?"
+     response = model.generate_response(
+         image=sample_image,
+         prompt=prompt,
+         max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+         temperature=DEFAULT_TEMPERATURE,
+         top_p=DEFAULT_TOP_P
+     )
+
+     assert isinstance(response, str)
+     assert len(response) > 0
+
+ def test_generate_response_with_invalid_image(model):
+     """Test that the model handles invalid images correctly."""
+     with pytest.raises(Exception):
+         model.generate_response(
+             image=None,
+             prompt="What color is this image?",
+             max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+             temperature=DEFAULT_TEMPERATURE,
+             top_p=DEFAULT_TOP_P
+         )
+
+ def test_generate_response_with_empty_prompt(model, sample_image):
+     """Test that the model handles empty prompts correctly."""
+     with pytest.raises(Exception):
+         model.generate_response(
+             image=sample_image,
+             prompt="",
+             max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+             temperature=DEFAULT_TEMPERATURE,
+             top_p=DEFAULT_TOP_P
+         )