Commit e5d40e3 (parent: 8d272fe)
fix: Update Gradio to 4.44.1 and improve interface
Files changed:
- .github/workflows/ci.yml +142 -0
- .pre-commit-config.yaml +69 -0
- README.md +159 -52
- app.py +0 -148
- docs/api/README.md +121 -0
- docs/guides/developer_guide.md +362 -0
- docs/guides/user_guide.md +164 -0
- examples/api_client.py +127 -0
- examples/llava_demo.ipynb +1 -0
- examples/process_image.py +103 -0
- pyproject.toml +181 -0
- requirements-dev.txt +38 -0
- requirements.txt +18 -19
- src/__init__.py +0 -0
- src/api/__init__.py +0 -0
- src/api/app.py +159 -0
- src/configs/__init__.py +0 -0
- src/configs/settings.py +46 -0
- src/models/__init__.py +0 -0
- src/models/llava_model.py +88 -0
- main.py → src/models/main.py +0 -0
- src/requirements.txt +26 -0
- src/utils/__init__.py +0 -0
- src/utils/logging.py +51 -0
- tests/test_model.py +67 -0
.github/workflows/ci.yml
ADDED
@@ -0,0 +1,142 @@
```yaml
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-dev.txt

      - name: Run pre-commit hooks
        run: |
          pre-commit install
          pre-commit run --all-files

      - name: Run tests
        run: |
          pytest --cov=src --cov-report=xml

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v4
        with:
          file: ./coverage.xml
          fail_ci_if_error: true

  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements-dev.txt

      - name: Run black
        run: black --check src tests

      - name: Run isort
        run: isort --check-only src tests

      - name: Run flake8
        run: flake8 src tests

      - name: Run mypy
        run: mypy src

  build:
    needs: [test, lint]
    runs-on: ubuntu-latest
    if: github.event_name == 'push' && github.ref == 'refs/heads/main'

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build twine

      - name: Build package
        run: python -m build

      - name: Check package
        run: twine check dist/*

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: dist
          path: dist/

  deploy:
    needs: build
    runs-on: ubuntu-latest
    if: github.event_name == 'push' && github.ref == 'refs/heads/main'

    steps:
      - uses: actions/checkout@v4

      - name: Download artifacts
        uses: actions/download-artifact@v4
        with:
          name: dist
          path: dist

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install twine

      - name: Deploy to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: twine upload dist/*

      - name: Deploy to Hugging Face
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          pip install huggingface_hub
          huggingface-cli login --token $HF_TOKEN
          huggingface-cli upload Prashant26am/llava-chat dist/* --repo-type space
```
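The lint job's four checks map one-to-one onto commands you can run locally before pushing; these are copied from the workflow above:

```bash
# Same checks the CI lint job runs
black --check src tests
isort --check-only src tests
flake8 src tests
mypy src
```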
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,69 @@
```yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-ast
      - id: check-json
      - id: check-merge-conflict
      - id: detect-private-key
      - id: debug-statements

  - repo: https://github.com/psf/black
    rev: 24.1.1
    hooks:
      - id: black
        language_version: python3.8

  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        args: ["--profile", "black"]

  - repo: https://github.com/pycqa/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        additional_dependencies:
          - flake8-docstrings
          - flake8-bugbear
          - flake8-comprehensions
          - flake8-simplify
          - flake8-unused-arguments
          - flake8-variables-names
          - pep8-naming

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.8.0
    hooks:
      - id: mypy
        additional_dependencies:
          - types-Pillow
          - types-requests
          - types-setuptools
          - types-urllib3

  - repo: https://github.com/asottile/pyupgrade
    rev: v3.15.0
    hooks:
      - id: pyupgrade
        args: [--py38-plus]

  - repo: https://github.com/PyCQA/bandit
    rev: 1.7.7
    hooks:
      - id: bandit
        args: ["-c", "pyproject.toml"]
        additional_dependencies: ["bandit[toml]"]

  - repo: https://github.com/pre-commit/mirrors-prettier
    rev: v4.0.0-alpha.8
    hooks:
      - id: prettier
        types_or: [javascript, jsx, ts, tsx, json, css, scss, md, yaml, yml]
        additional_dependencies:
          - prettier@4.0.0-alpha.8
```
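These hooks can be exercised locally with the same two commands the CI test job runs:

```bash
pre-commit install          # register the git hook for future commits
pre-commit run --all-files  # one-off pass over the whole repo
```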
README.md
CHANGED
@@ -1,52 +1,159 @@
# LLaVA Implementation

[MIT License](https://opensource.org/licenses/MIT) · [Python](https://www.python.org/downloads/) · [Gradio](https://gradio.app/) · [Hugging Face Space](https://huggingface.co/spaces/Prashant26am/llava-chat)

A modern implementation of LLaVA (Large Language and Vision Assistant) with a beautiful web interface. This project combines state-of-the-art vision and language models to create an interactive AI assistant that can understand and discuss images.

## 🌟 Features

- **Modern Web Interface**
  - Beautiful Gradio-based UI
  - Real-time image analysis
  - Interactive chat experience
  - Responsive design

- **Advanced AI Capabilities**
  - CLIP ViT-L/14 vision encoder
  - Vicuna-7B language model
  - Multimodal understanding
  - Natural conversation flow

- **Developer Friendly**
  - Clean, modular codebase
  - Comprehensive documentation
  - Easy deployment options
  - Extensible architecture

## 📋 Project Structure

```
llava_implementation/
├── src/                  # Source code
│   ├── api/              # API endpoints and FastAPI app
│   ├── models/           # Model implementations
│   ├── utils/            # Utility functions
│   └── configs/          # Configuration files
├── tests/                # Test suite
├── docs/                 # Documentation
│   ├── api/              # API documentation
│   ├── examples/         # Usage examples
│   └── guides/           # User and developer guides
├── assets/               # Static assets
│   ├── images/           # Example images
│   └── icons/            # UI icons
├── scripts/              # Utility scripts
└── examples/             # Example images for the web interface
```

## 🚀 Quick Start

### Prerequisites

- Python 3.8+
- CUDA-capable GPU (recommended)
- Git

### Installation

1. Clone the repository:
```bash
git clone https://github.com/Prashant-ambati/llava-implementation.git
cd llava-implementation
```

2. Create and activate a virtual environment:
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

3. Install dependencies:
```bash
pip install -r requirements.txt
```

### Running Locally

1. Start the development server:
```bash
python src/api/app.py
```

2. Open your browser and navigate to:
```
http://localhost:7860
```

## 🌐 Web Deployment

### Hugging Face Spaces

The application is deployed on Hugging Face Spaces:
- [Live Demo](https://huggingface.co/spaces/Prashant26am/llava-chat)
- Automatic deployment from main branch
- Free GPU resources
- Public API access

### Local Deployment

For local deployment:
```bash
# Build the application
python -m build

# Run with production settings
python src/api/app.py --production
```

## 📚 Documentation

- [API Documentation](docs/api/README.md)
- [User Guide](docs/guides/user_guide.md)
- [Developer Guide](docs/guides/developer_guide.md)
- [Examples](docs/examples/README.md)

## 🛠️ Development

### Running Tests

```bash
pytest tests/
```

### Code Style

This project follows PEP 8 guidelines. To check your code:

```bash
flake8 src/
black --check src/
```

### Contributing

1. Fork the repository
2. Create a feature branch
3. Commit your changes
4. Push to the branch
5. Create a Pull Request

## 📝 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

- [LLaVA Paper](https://arxiv.org/abs/2304.08485) by Microsoft Research
- [Gradio](https://gradio.app/) for the web interface
- [Hugging Face](https://huggingface.co/) for model hosting
- [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) for the language model
- [CLIP](https://openai.com/research/clip) for the vision model

## 📞 Contact

- GitHub Issues: [Report a bug](https://github.com/Prashant-ambati/llava-implementation/issues)
- Email: [Your Email]
- Twitter: [@YourTwitter]
app.py
DELETED
@@ -1,148 +0,0 @@
```python
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import os
import tempfile
from typing import Optional
from pydantic import BaseModel
import torch
import gradio as gr
from models.llava import LLaVA

# Initialize model globally
model = None

def initialize_model():
    global model
    try:
        model = LLaVA(
            vision_model_path="openai/clip-vit-large-patch14-336",
            language_model_path="lmsys/vicuna-7b-v1.5",
            device="cuda" if torch.cuda.is_available() else "cpu",
            load_in_8bit=True
        )
        print(f"Model initialized on {model.device}")
        return True
    except Exception as e:
        print(f"Error initializing model: {e}")
        return False

def process_image(image, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
    if not model:
        return "Error: Model not initialized"

    try:
        # Save the uploaded image temporarily
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            image.save(temp_file.name)
            temp_path = temp_file.name

        # Generate response
        response = model.generate_from_image(
            image_path=temp_path,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p
        )

        # Clean up temporary file
        os.unlink(temp_path)
        return response

    except Exception as e:
        return f"Error processing image: {str(e)}"

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="LLaVA Chat", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # LLaVA Chat
        Upload an image and chat with LLaVA about it. This model can understand and describe images, answer questions about them, and engage in visual conversations.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Image")
                prompt_input = gr.Textbox(
                    label="Ask about the image",
                    placeholder="What can you see in this image?",
                    lines=3
                )

                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=32,
                        maximum=512,
                        value=256,
                        step=32,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top P"
                    )

                submit_btn = gr.Button("Generate Response", variant="primary")

            with gr.Column(scale=1):
                output = gr.Textbox(
                    label="Model Response",
                    lines=10,
                    show_copy_button=True
                )

        # Set up the submit action
        submit_btn.click(
            fn=process_image,
            inputs=[image_input, prompt_input, max_tokens, temperature, top_p],
            outputs=output
        )

        # Add examples
        gr.Examples(
            examples=[
                ["examples/cat.jpg", "What can you see in this image?"],
                ["examples/landscape.jpg", "Describe this scene in detail."],
                ["examples/food.jpg", "What kind of food is this and how would you describe it?"]
            ],
            inputs=[image_input, prompt_input]
        )

        return demo

# Create FastAPI app
app = FastAPI(title="LLaVA Web Interface")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create Gradio app
demo = create_interface()

# Mount Gradio app
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Initialize model
    if initialize_model():
        import uvicorn
        uvicorn.run(app, host="0.0.0.0", port=7860)  # Hugging Face Spaces uses port 7860
    else:
        print("Failed to initialize model. Exiting...")
```
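One detail worth flagging in the deleted handler above: if `generate_from_image` raises, the `os.unlink(temp_path)` call is skipped and the temporary file leaks. A minimal sketch of the same flow with cleanup guaranteed via `try`/`finally` (a hypothetical refactor, not the code this commit adds in `src/api/app.py`):

```python
import os
import tempfile

def process_image(image, prompt, **gen_kwargs):
    """Run one image/prompt pair through the global model, always removing the temp file."""
    if not model:
        return "Error: Model not initialized"
    # delete=False so the model can reopen the file by path after this block closes it
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        image.save(temp_file.name)
        temp_path = temp_file.name
    try:
        return model.generate_from_image(image_path=temp_path, prompt=prompt, **gen_kwargs)
    except Exception as e:
        return f"Error processing image: {e}"
    finally:
        # Runs whether generation succeeded or raised, so the temp file never leaks
        os.unlink(temp_path)
```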
docs/api/README.md
ADDED
@@ -0,0 +1,121 @@
# LLaVA API Documentation

## Overview

The LLaVA API provides a simple interface for interacting with the LLaVA model through a Gradio web interface. The API allows users to upload images and receive AI-generated responses about the image content.

## API Endpoints

### Web Interface

The main interface is served at the root URL (`/`) and provides the following components:

#### Input Components

1. **Image Upload**
   - Type: Image uploader
   - Format: PIL Image
   - Purpose: Upload an image for analysis

2. **Prompt Input**
   - Type: Text input
   - Purpose: Enter questions or prompts about the image
   - Default placeholder: "What can you see in this image?"

3. **Generation Parameters**
   - Max New Tokens (64-2048, default: 512)
   - Temperature (0.1-1.0, default: 0.7)
   - Top P (0.1-1.0, default: 0.9)

#### Output Components

1. **Response**
   - Type: Text output
   - Purpose: Displays the model's response
   - Features: Copy button, scrollable

## Usage Examples

### Basic Usage

1. Upload an image using the image uploader
2. Enter a prompt in the text input
3. Click "Generate Response"
4. View the response in the output box

### Example Prompts

- "What can you see in this image?"
- "Describe this scene in detail"
- "What emotions does this image convey?"
- "What's happening in this picture?"
- "Can you identify any objects or people in this image?"

## Error Handling

The API handles various error cases:

1. **Invalid Images**
   - Returns an error message if the image is invalid or corrupted
   - Supports common image formats (JPEG, PNG, etc.)

2. **Empty Prompts**
   - Returns an error message if no prompt is provided
   - Prompts should be non-empty strings

3. **Model Errors**
   - Returns descriptive error messages for model-related issues
   - Includes logging for debugging

## Configuration

The API can be configured through environment variables or the settings file:

- `API_HOST`: Server host (default: "0.0.0.0")
- `API_PORT`: Server port (default: 7860)
- `GRADIO_THEME`: Interface theme (default: "soft")
- `DEFAULT_MAX_NEW_TOKENS`: Default token limit (default: 512)
- `DEFAULT_TEMPERATURE`: Default temperature (default: 0.7)
- `DEFAULT_TOP_P`: Default top-p value (default: 0.9)

## Development

### Running Locally

```bash
python src/api/app.py
```

### Running Tests

```bash
pytest tests/
```

### Code Style

The project follows PEP 8 guidelines. To check your code:

```bash
flake8 src/
black --check src/
```

## Security Considerations

1. The API is designed for public use but should be deployed behind appropriate security measures
2. Input validation is performed on all user inputs
3. Large file uploads are handled safely
4. Error messages are sanitized to prevent information leakage

## Rate Limiting

Currently, no rate limiting is implemented. Consider implementing rate limiting for production deployments.

## Future Improvements

1. Add authentication
2. Implement rate limiting
3. Add batch processing capabilities
4. Support for video input
5. Real-time streaming responses
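The configuration section above lists the environment variables and their defaults. A minimal sketch of how `src/configs/settings.py` might expose them (the names and defaults come straight from the docs above; the implementation itself is an assumption, since that file's diff is not shown in this view):

```python
import os

# Names and defaults mirror the Configuration list in docs/api/README.md
API_HOST = os.environ.get("API_HOST", "0.0.0.0")
API_PORT = int(os.environ.get("API_PORT", "7860"))
GRADIO_THEME = os.environ.get("GRADIO_THEME", "soft")
DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("DEFAULT_MAX_NEW_TOKENS", "512"))
DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", "0.7"))
DEFAULT_TOP_P = float(os.environ.get("DEFAULT_TOP_P", "0.9"))
```

These are the same constants that `examples/process_image.py` (further down) imports from `src.configs.settings`.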
docs/guides/developer_guide.md
ADDED
@@ -0,0 +1,362 @@
# LLaVA Implementation Developer Guide

## Overview

This guide is intended for developers who want to contribute to or extend the LLaVA implementation. The project is structured as a Python package with a Gradio web interface, using modern best practices and tools.

## Project Structure

```
llava_implementation/
├── src/                      # Source code
│   ├── api/                  # API endpoints and FastAPI app
│   │   ├── __init__.py
│   │   └── app.py            # Gradio interface
│   ├── models/               # Model implementations
│   │   ├── __init__.py
│   │   └── llava_model.py    # LLaVA model wrapper
│   ├── utils/                # Utility functions
│   │   ├── __init__.py
│   │   └── logging.py        # Logging utilities
│   └── configs/              # Configuration files
│       ├── __init__.py
│       └── settings.py       # Application settings
├── tests/                    # Test suite
│   ├── __init__.py
│   └── test_model.py         # Model tests
├── docs/                     # Documentation
│   ├── api/                  # API documentation
│   ├── examples/             # Usage examples
│   └── guides/               # User and developer guides
├── assets/                   # Static assets
│   ├── images/               # Example images
│   └── icons/                # UI icons
├── scripts/                  # Utility scripts
└── examples/                 # Example images for the web interface
```

## Development Setup

### Prerequisites

- Python 3.8+
- Git
- CUDA-capable GPU (recommended)
- Virtual environment tool (venv, conda, etc.)

### Installation

1. Clone the repository:
```bash
git clone https://github.com/Prashant-ambati/llava-implementation.git
cd llava-implementation
```

2. Create and activate a virtual environment:
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

3. Install development dependencies:
```bash
pip install -r requirements.txt
pip install -r requirements-dev.txt  # Development dependencies
```

### Development Tools

1. **Code Formatting**
   - Black for code formatting
   - isort for import sorting
   - flake8 for linting

2. **Testing**
   - pytest for testing
   - pytest-cov for coverage
   - pytest-mock for mocking

3. **Type Checking**
   - mypy for static type checking
   - types-* packages for type hints

## Code Style

### Python Style Guide

1. Follow PEP 8 guidelines
2. Use type hints
3. Write docstrings (Google style)
4. Keep functions focused and small
5. Use meaningful variable names

### Example

```python
from typing import Optional, List
from PIL import Image

def process_image(
    image: Image.Image,
    prompt: str,
    max_tokens: Optional[int] = None
) -> List[str]:
    """
    Process an image with the given prompt.

    Args:
        image: Input image as PIL Image
        prompt: Text prompt for the model
        max_tokens: Optional maximum tokens to generate

    Returns:
        List of generated responses

    Raises:
        ValueError: If image is invalid
        RuntimeError: If model fails to process
    """
    # Implementation
    ...
```

## Testing

### Running Tests

```bash
# Run all tests
pytest

# Run with coverage
pytest --cov=src

# Run specific test file
pytest tests/test_model.py

# Run with verbose output
pytest -v
```

### Writing Tests

1. Use pytest fixtures
2. Mock external dependencies
3. Test edge cases
4. Include both unit and integration tests

Example test:
```python
import pytest
from PIL import Image

def test_process_image(model, sample_image):
    """Test image processing functionality."""
    prompt = "What color is this image?"
    response = model.process_image(
        image=sample_image,
        prompt=prompt
    )
    assert isinstance(response, str)
    assert len(response) > 0
```

## Model Development

### Adding New Models

1. Create a new model class in `src/models/`
2. Implement required methods
3. Add tests
4. Update documentation

Example:
```python
class NewModel:
    """New model implementation."""

    def __init__(self, config: dict):
        """Initialize the model."""
        self.config = config
        self.model = self._load_model()

    def process(self, *args, **kwargs):
        """Process inputs and generate output."""
        pass
```

### Model Configuration

1. Add configuration in `src/configs/settings.py`
2. Use environment variables for secrets
3. Document all parameters

## API Development

### Adding New Endpoints

1. Create new endpoint in `src/api/app.py`
2. Add input validation
3. Implement error handling
4. Add tests
5. Update documentation

### Error Handling

1. Use custom exceptions
2. Implement proper logging
3. Return appropriate status codes
4. Include error messages

Example:
```python
from fastapi import HTTPException

class ModelError(Exception):
    """Base exception for model errors."""
    pass

def process_request(request):
    try:
        result = model.process(request)
        return result
    except ModelError as e:
        logger.error(f"Model error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
```

## Deployment

### Local Deployment

1. Build the package:
```bash
python -m build
```

2. Run the server:
```bash
python src/api/app.py
```

### Hugging Face Spaces

1. Update `README.md` with Space metadata
2. Ensure all dependencies are in `requirements.txt`
3. Test the Space locally
4. Push changes to the Space

### Production Deployment

1. Set up proper logging
2. Configure security measures
3. Implement rate limiting
4. Set up monitoring
5. Use environment variables

## Contributing

### Workflow

1. Fork the repository
2. Create a feature branch
3. Make changes
4. Run tests
5. Update documentation
6. Create a pull request

### Pull Request Process

1. Update documentation
2. Add tests
3. Ensure CI passes
4. Get code review
5. Address feedback
6. Merge when approved

## Performance Optimization

### Model Optimization

1. Use model quantization
2. Implement caching
3. Batch processing
4. GPU optimization

### API Optimization

1. Response compression
2. Request validation
3. Connection pooling
4. Caching strategies

## Security

### Best Practices

1. Input validation
2. Error handling
3. Rate limiting
4. Secure configuration
5. Regular updates

### Security Checklist

- [ ] Validate all inputs
- [ ] Sanitize outputs
- [ ] Use secure dependencies
- [ ] Implement rate limiting
- [ ] Set up monitoring
- [ ] Regular security audits

## Monitoring and Logging

### Logging

1. Use structured logging
2. Include context
3. Set appropriate levels
4. Rotate logs

### Monitoring

1. Track key metrics
2. Set up alerts
3. Monitor resources
4. Track errors

## Future Development

### Planned Features

1. Video support
2. Batch processing
3. Model fine-tuning
4. API authentication
5. Advanced caching

### Contributing Ideas

1. Open issues
2. Discuss in PRs
3. Join discussions
4. Share use cases

## Resources

### Documentation

- [Python Documentation](https://docs.python.org/)
- [Gradio Documentation](https://gradio.app/docs/)
- [Hugging Face Docs](https://huggingface.co/docs)
- [Pytest Documentation](https://docs.pytest.org/)

### Tools

- [Black](https://black.readthedocs.io/)
- [isort](https://pycqa.github.io/isort/)
- [flake8](https://flake8.pycqa.org/)
- [mypy](https://mypy.readthedocs.io/)

### Community

- [GitHub Issues](https://github.com/Prashant-ambati/llava-implementation/issues)
- [Hugging Face Forums](https://discuss.huggingface.co/)
- [Stack Overflow](https://stackoverflow.com/)
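The example test in the guide above leans on two fixtures, `model` and `sample_image`, that the guide never defines. A hypothetical `tests/conftest.py` that would make it run as written (the real `tests/test_model.py` in this commit may define these differently; `mocker` comes from pytest-mock, which is in the dev dependencies):

```python
import pytest
from PIL import Image

@pytest.fixture
def sample_image():
    # A solid-color 64x64 image is enough for a smoke test
    return Image.new("RGB", (64, 64), color="red")

@pytest.fixture
def model(mocker):
    # Stub out the heavy LLaVA model so the test runs without a GPU
    fake = mocker.Mock()
    fake.process_image.return_value = "The image is a solid red square."
    return fake
```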
docs/guides/user_guide.md
ADDED
@@ -0,0 +1,164 @@
# LLaVA Chat User Guide

## Introduction

Welcome to LLaVA Chat! This guide will help you get started with using our AI-powered image understanding and chat interface. LLaVA (Large Language and Vision Assistant) combines advanced vision and language models to provide detailed analysis and natural conversations about images.

## Getting Started

### Accessing the Interface

1. Visit our [Hugging Face Space](https://huggingface.co/spaces/Prashant26am/llava-chat)
2. Wait for the interface to load (this may take a few moments as the model initializes)
3. You're ready to start chatting with images!

### Basic Usage

1. **Upload an Image**
   - Click the image upload area or drag and drop an image
   - Supported formats: JPEG, PNG, GIF
   - Maximum file size: 10MB

2. **Enter Your Prompt**
   - Type your question or prompt in the text box
   - Be specific about what you want to know
   - You can ask multiple questions about the same image

3. **Adjust Parameters** (Optional)
   - Click "Generation Parameters" to expand
   - Modify settings to control the response:
     - Max New Tokens: Longer responses (64-2048)
     - Temperature: More creative responses (0.1-1.0)
     - Top P: More diverse responses (0.1-1.0)

4. **Generate Response**
   - Click the "Generate Response" button
   - Wait for the model to process (usually a few seconds)
   - Read the response in the output box
   - Use the copy button to save the response

## Best Practices

### Writing Effective Prompts

1. **Be Specific**
   - Instead of "What's in this image?", try "What objects can you identify in this image?"
   - Instead of "Describe this", try "Describe the scene, focusing on the main subject"

2. **Ask Follow-up Questions**
   - "What emotions does this image convey?"
   - "Can you identify any specific details about [object]?"
   - "How would you describe the composition of this image?"

3. **Use Natural Language**
   - Write as if you're talking to a person
   - Feel free to ask for clarification or more details
   - You can have a conversation about the image

### Example Prompts

1. **General Analysis**
   - "What can you see in this image?"
   - "Describe this scene in detail"
   - "What's the main subject of this image?"

2. **Specific Details**
   - "What colors are prominent in this image?"
   - "Can you identify any text or signs in the image?"
   - "What time of day does this image appear to be taken?"

3. **Emotional Response**
   - "What mood or atmosphere does this image convey?"
   - "How does this image make you feel?"
   - "What emotions might this image evoke in viewers?"

4. **Technical Analysis**
   - "What's the composition of this image?"
   - "How would you describe the lighting in this image?"
   - "What camera angle or perspective is used?"

## Troubleshooting

### Common Issues

1. **Image Not Loading**
   - Check file format (JPEG, PNG, GIF)
   - Ensure file size is under 10MB
   - Try refreshing the page

2. **Slow Response**
   - Reduce image size
   - Simplify your prompt
   - Check your internet connection

3. **Unexpected Responses**
   - Try rephrasing your prompt
   - Adjust generation parameters
   - Be more specific in your question

### Getting Help

If you encounter any issues:
1. Check this guide for solutions
2. Visit our [GitHub repository](https://github.com/Prashant-ambati/llava-implementation)
3. Open an issue on GitHub
4. Contact us through Hugging Face

## Advanced Usage

### Parameter Tuning

1. **Max New Tokens**
   - Lower values (64-256): Short, concise responses
   - Medium values (256-512): Balanced responses
   - Higher values (512+): Detailed, comprehensive responses

2. **Temperature**
   - Lower values (0.1-0.3): More focused, deterministic responses
   - Medium values (0.4-0.7): Balanced creativity
   - Higher values (0.8-1.0): More creative, diverse responses

3. **Top P**
   - Lower values (0.1-0.3): More focused word choice
   - Medium values (0.4-0.7): Balanced diversity
   - Higher values (0.8-1.0): More diverse word choice

### Tips for Better Results

1. **Image Quality**
   - Use clear, well-lit images
   - Ensure the subject is clearly visible
   - Avoid heavily edited or filtered images

2. **Prompt Engineering**
   - Start with simple questions
   - Build up to more complex queries
   - Use follow-up questions for details

3. **Response Management**
   - Copy important responses
   - Save interesting conversations
   - Compare responses with different parameters

## Privacy and Ethics

1. **Image Privacy**
   - Don't upload sensitive or private images
   - Be mindful of copyright
   - Respect others' privacy

2. **Responsible Use**
   - Use the tool ethically
   - Don't use for harmful purposes
   - Respect content guidelines

## Future Updates

We're constantly improving LLaVA Chat. Planned features include:
1. Support for video input
2. Batch image processing
3. More advanced parameter controls
4. Additional model options
5. Enhanced response formatting

Stay tuned for updates!
examples/api_client.py
ADDED
@@ -0,0 +1,127 @@
```python
"""
Example API client for the LLaVA model.
"""

import argparse
import json
from pathlib import Path
from typing import Dict, Any, Optional

import requests
from PIL import Image
import base64
from io import BytesIO

def encode_image(image_path: str) -> str:
    """
    Encode an image to base64 string.

    Args:
        image_path: Path to the image file

    Returns:
        str: Base64 encoded image
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def process_image(
    api_url: str,
    image_path: str,
    prompt: str,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None
) -> Dict[str, Any]:
    """
    Process an image using the LLaVA API.

    Args:
        api_url: URL of the API endpoint
        image_path: Path to the input image
        prompt: Text prompt for the model
        max_new_tokens: Optional maximum tokens to generate
        temperature: Optional sampling temperature
        top_p: Optional top-p sampling parameter

    Returns:
        Dict containing the API response
    """
    # Prepare the request payload
    payload = {
        "image": encode_image(image_path),
        "prompt": prompt
    }

    # Add optional parameters if provided
    if max_new_tokens is not None:
        payload["max_new_tokens"] = max_new_tokens
    if temperature is not None:
        payload["temperature"] = temperature
    if top_p is not None:
        payload["top_p"] = top_p

    try:
        # Send the request
        response = requests.post(api_url, json=payload)
        response.raise_for_status()
        return response.json()

    except requests.exceptions.RequestException as e:
        print(f"Error making request: {e}")
        if hasattr(e.response, 'text'):
            print(f"Response: {e.response.text}")
        raise

def save_response(response: Dict[str, Any], output_path: Optional[str] = None):
    """
    Save or print the API response.

    Args:
        response: API response dictionary
        output_path: Optional path to save the response
    """
    if output_path:
        with open(output_path, 'w') as f:
            json.dump(response, f, indent=2)
        print(f"Saved response to {output_path}")
    else:
        print("\nAPI Response:")
        print("-" * 50)
        print(json.dumps(response, indent=2))
        print("-" * 50)

def main():
    """Main function to process images using the API."""
    parser = argparse.ArgumentParser(description="Process images using LLaVA API")
    parser.add_argument("image_path", type=str, help="Path to the input image")
    parser.add_argument("prompt", type=str, help="Text prompt for the model")
    parser.add_argument("--api-url", type=str, default="http://localhost:7860/api/process",
                        help="URL of the API endpoint")
    parser.add_argument("--max-tokens", type=int, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, help="Sampling temperature")
    parser.add_argument("--top-p", type=float, help="Top-p sampling parameter")
    parser.add_argument("--output", type=str, help="Path to save the response")

    args = parser.parse_args()

    try:
        # Process image
        response = process_image(
            api_url=args.api_url,
            image_path=args.image_path,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )

        # Save or print response
        save_response(response, args.output)

    except Exception as e:
        print(f"Error: {str(e)}")
        raise

if __name__ == "__main__":
    main()
```
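An example invocation, assuming the API server is running locally and exposes the default endpoint the script targets:

```bash
python examples/api_client.py photo.jpg "What can you see in this image?" \
    --max-tokens 256 --temperature 0.7 --output response.json
```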
examples/llava_demo.ipynb
ADDED
@@ -0,0 +1 @@
(notebook JSON not rendered in this view)
examples/process_image.py
ADDED
@@ -0,0 +1,103 @@
```python
"""
Example script for processing images with the LLaVA model.
"""

import argparse
from pathlib import Path
from PIL import Image

from src.models.llava_model import LLaVAModel
from src.configs.settings import DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
from src.utils.logging import setup_logging, get_logger

# Set up logging
setup_logging()
logger = get_logger(__name__)

def process_image(
    image_path: str,
    prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P
) -> str:
    """
    Process an image with the LLaVA model.

    Args:
        image_path: Path to the input image
        prompt: Text prompt for the model
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Top-p sampling parameter

    Returns:
        str: Model response
    """
    try:
        # Load image
        image = Image.open(image_path)
        logger.info(f"Loaded image from {image_path}")

        # Initialize model
        model = LLaVAModel()
        logger.info("Model initialized")

        # Generate response
        response = model(
            image=image,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p
        )
        logger.info("Generated response")

        return response

    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise

def main():
    """Main function to process images from command line."""
    parser = argparse.ArgumentParser(description="Process images with LLaVA model")
    parser.add_argument("image_path", type=str, help="Path to the input image")
    parser.add_argument("prompt", type=str, help="Text prompt for the model")
    parser.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS,
                        help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE,
                        help="Sampling temperature")
    parser.add_argument("--top-p", type=float, default=DEFAULT_TOP_P,
                        help="Top-p sampling parameter")
    parser.add_argument("--output", type=str, help="Path to save the response")

    args = parser.parse_args()

    try:
        # Process image
        response = process_image(
            image_path=args.image_path,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )

        # Print or save response
        if args.output:
            output_path = Path(args.output)
            output_path.write_text(response)
            logger.info(f"Saved response to {output_path}")
        else:
            print("\nModel Response:")
            print("-" * 50)
            print(response)
            print("-" * 50)

    except Exception as e:
        logger.error(f"Error: {str(e)}")
        raise

if __name__ == "__main__":
    main()
```
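A matching invocation for this local (non-API) script, using one of the example images the web interface ships with:

```bash
# Run from the repository root so the `src` package imports resolve
python examples/process_image.py examples/cat.jpg "Describe this scene in detail." --output response.txt
```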
pyproject.toml
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "llava-implementation"
|
7 |
+
version = "0.1.0"
|
8 |
+
description = "A modern implementation of LLaVA with a beautiful web interface"
|
9 |
+
readme = "README.md"
|
10 |
+
requires-python = ">=3.8"
|
11 |
+
license = {text = "MIT"}
|
12 |
+
authors = [
|
13 |
+
{name = "Prashant Ambati", email = "your.email@example.com"}
|
14 |
+
]
|
15 |
+
classifiers = [
|
16 |
+
"Development Status :: 4 - Beta",
|
17 |
+
"Intended Audience :: Developers",
|
18 |
+
"Intended Audience :: Science/Research",
|
19 |
+
"License :: OSI Approved :: MIT License",
|
20 |
+
"Programming Language :: Python :: 3",
|
21 |
+
"Programming Language :: Python :: 3.8",
|
22 |
+
"Programming Language :: Python :: 3.9",
|
23 |
+
"Programming Language :: Python :: 3.10",
|
24 |
+
"Programming Language :: Python :: 3.11",
|
25 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
26 |
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
27 |
+
]
|
28 |
+
dependencies = [
|
29 |
+
"torch>=2.0.0",
|
30 |
+
"torchvision>=0.15.0",
|
31 |
+
"transformers>=4.36.0",
|
32 |
+
"accelerate>=0.25.0",
|
33 |
+
"pillow>=10.0.0",
|
34 |
+
"numpy>=1.24.0",
|
35 |
+
"tqdm>=4.65.0",
|
36 |
+
"matplotlib>=3.7.0",
|
37 |
+
"opencv-python>=4.8.0",
|
38 |
+
"einops>=0.7.0",
|
39 |
+
"timm>=0.9.0",
|
40 |
+
"sentencepiece>=0.1.99",
|
41 |
+
"peft>=0.7.0",
|
42 |
+
"bitsandbytes>=0.41.0",
|
43 |
+
"safetensors>=0.4.0",
|
44 |
+
"gradio==4.44.1",
|
45 |
+
"fastapi>=0.109.0",
|
46 |
+
"uvicorn>=0.27.0",
|
47 |
+
"python-multipart>=0.0.6",
|
48 |
+
"pydantic>=2.5.0",
|
49 |
+
"python-jose>=3.3.0",
|
50 |
+
"passlib>=1.7.4",
|
51 |
+
"bcrypt>=4.0.1",
|
52 |
+
"aiofiles>=23.2.0",
|
53 |
+
"httpx>=0.26.0",
|
54 |
+
]
|
55 |
+
|
56 |
+
[project.optional-dependencies]
|
57 |
+
dev = [
|
58 |
+
"pytest>=8.0.0",
|
59 |
+
"pytest-cov>=4.1.0",
|
60 |
+
"pytest-mock>=3.12.0",
|
61 |
+
"pytest-asyncio>=0.23.5",
|
62 |
+
"pytest-xdist>=3.5.0",
|
63 |
+
"black>=24.1.1",
|
64 |
+
"isort>=5.13.2",
|
65 |
+
"flake8>=7.0.0",
|
66 |
+
"mypy>=1.8.0",
|
67 |
+
"types-Pillow>=10.2.0.20240106",
|
68 |
+
"types-requests>=2.31.0.20240125",
|
69 |
+
"sphinx>=7.2.6",
|
70 |
+
"sphinx-rtd-theme>=2.0.0",
|
71 |
+
"sphinx-autodoc-typehints>=2.0.1",
|
72 |
+
"sphinx-copybutton>=0.5.2",
|
73 |
+
"sphinx-tabs>=3.4.4",
|
74 |
+
"pre-commit>=3.6.0",
|
75 |
+
"ipython>=8.21.0",
|
76 |
+
"jupyter>=1.0.0",
|
77 |
+
"notebook>=7.0.7",
|
78 |
+
"ipykernel>=6.29.0",
|
79 |
+
"build>=1.0.3",
|
80 |
+
"twine>=4.0.2",
|
81 |
+
"wheel>=0.42.0",
|
82 |
+
"memory-profiler>=0.61.0",
|
83 |
+
"line-profiler>=4.1.2",
|
84 |
+
"debugpy>=1.8.0",
|
85 |
+
]
|
86 |
+
|
87 |
+
[project.urls]
|
88 |
+
Homepage = "https://github.com/Prashant-ambati/llava-implementation"
|
89 |
+
Documentation = "https://github.com/Prashant-ambati/llava-implementation#readme"
|
90 |
+
Repository = "https://github.com/Prashant-ambati/llava-implementation.git"
|
91 |
+
Issues = "https://github.com/Prashant-ambati/llava-implementation/issues"
|
92 |
+
"Bug Tracker" = "https://github.com/Prashant-ambati/llava-implementation/issues"
|
93 |
+
|
94 |
+
[tool.setuptools]
|
95 |
+
packages = ["src"]
|
96 |
+
|
97 |
+
[tool.black]
|
98 |
+
line-length = 88
|
99 |
+
target-version = ["py38"]
|
100 |
+
include = '\.pyi?$'
|
101 |
+
|
102 |
+
[tool.isort]
|
103 |
+
profile = "black"
|
104 |
+
multi_line_output = 3
|
105 |
+
include_trailing_comma = true
|
106 |
+
force_grid_wrap = 0
|
107 |
+
use_parentheses = true
|
108 |
+
ensure_newline_before_comments = true
|
109 |
+
line_length = 88
|
110 |
+
|
111 |
+
[tool.mypy]
|
112 |
+
python_version = "3.8"
|
113 |
+
warn_return_any = true
|
114 |
+
warn_unused_configs = true
|
115 |
+
disallow_untyped_defs = true
|
116 |
+
disallow_incomplete_defs = true
|
117 |
+
check_untyped_defs = true
|
118 |
+
disallow_untyped_decorators = true
|
119 |
+
no_implicit_optional = true
|
120 |
+
warn_redundant_casts = true
|
121 |
+
warn_unused_ignores = true
|
122 |
+
warn_no_return = true
|
123 |
+
warn_unreachable = true
|
124 |
+
strict_optional = true
|
125 |
+
|
126 |
+
[tool.pytest.ini_options]
|
127 |
+
minversion = "6.0"
|
128 |
+
addopts = "-ra -q --cov=src"
|
129 |
+
testpaths = [
|
130 |
+
"tests",
|
131 |
+
]
|
132 |
+
python_files = ["test_*.py"]
|
133 |
+
python_classes = ["Test*"]
|
134 |
+
python_functions = ["test_*"]
|
135 |
+
|
136 |
+
[tool.coverage.run]
|
137 |
+
source = ["src"]
|
138 |
+
branch = true
|
139 |
+
|
140 |
+
[tool.coverage.report]
|
141 |
+
exclude_lines = [
|
142 |
+
"pragma: no cover",
|
143 |
+
"def __repr__",
|
144 |
+
"if self.debug:",
|
145 |
+
"raise NotImplementedError",
|
146 |
+
"if __name__ == .__main__.:",
|
147 |
+
"pass",
|
148 |
+
"raise ImportError",
|
149 |
+
]
|
150 |
+
show_missing = true
|
151 |
+
fail_under = 80
|
152 |
+
|
153 |
+
[tool.bandit]
|
154 |
+
exclude_dirs = ["tests", "docs"]
|
155 |
+
skips = ["B101"]
|
156 |
+
|
157 |
+
[tool.ruff]
|
158 |
+
line-length = 88
|
159 |
+
target-version = "py38"
|
160 |
+
select = [
|
161 |
+
"E", # pycodestyle errors
|
162 |
+
"W", # pycodestyle warnings
|
163 |
+
"F", # pyflakes
|
164 |
+
"I", # isort
|
165 |
+
"B", # flake8-bugbear
|
166 |
+
"C4", # flake8-comprehensions
|
167 |
+
"UP", # pyupgrade
|
168 |
+
"N", # pep8-naming
|
169 |
+
"PL", # pylint
|
170 |
+
"RUF", # ruff-specific rules
|
171 |
+
]
|
172 |
+
ignore = [
|
173 |
+
"E501", # line length violations
|
174 |
+
"B008", # do not perform function calls in argument defaults
|
175 |
+
]
|
176 |
+
|
177 |
+
[tool.ruff.isort]
|
178 |
+
known-first-party = ["src"]
|
179 |
+
|
180 |
+
[tool.ruff.mccabe]
|
181 |
+
max-complexity = 10
|
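The [tool.mypy] table above turns on strict optional checking (no_implicit_optional, strict_optional), which rejects the common `arg: str = None` pattern. A minimal sketch of what these settings enforce (illustrative only, not part of the commit):

    from typing import Optional

    # Under no_implicit_optional, a None default requires an explicit Optional:
    # def greet(name: str = None) -> str:      # rejected by mypy
    def greet(name: Optional[str] = None) -> str:
        """Return a personalized greeting, or a generic one when name is None."""
        return f"Hello, {name}!" if name is not None else "Hello!"

This is the same rule that makes the `Optional[str]` annotations in src/utils/logging.py below necessary.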
requirements-dev.txt
ADDED
@@ -0,0 +1,38 @@
+# Testing
+pytest==8.0.0
+pytest-cov==4.1.0
+pytest-mock==3.12.0
+pytest-asyncio==0.23.5
+pytest-xdist==3.5.0
+
+# Code Quality
+black==24.1.1
+isort==5.13.2
+flake8==7.0.0
+mypy==1.8.0
+types-Pillow==10.2.0.20240106
+types-requests==2.31.0.20240125
+
+# Documentation
+sphinx==7.2.6
+sphinx-rtd-theme==2.0.0
+sphinx-autodoc-typehints==2.0.1
+sphinx-copybutton==0.5.2
+sphinx-tabs==3.4.4
+
+# Development Tools
+pre-commit==3.6.0
+ipython==8.21.0
+jupyter==1.0.0
+notebook==7.0.7
+ipykernel==6.29.0
+
+# Build Tools
+build==1.0.3
+twine==4.0.2
+wheel==0.42.0
+
+# Monitoring and Debugging
+memory-profiler==0.61.0
+line-profiler==4.1.2
+debugpy==1.8.0
requirements.txt
CHANGED
@@ -1,26 +1,25 @@
 torch>=2.0.0
 torchvision>=0.15.0
-transformers>=4.
-accelerate>=0.
-pillow>=
+transformers>=4.36.0
+accelerate>=0.25.0
+pillow>=10.0.0
 numpy>=1.24.0
 tqdm>=4.65.0
 matplotlib>=3.7.0
-opencv-python>=4.
-einops>=0.
+opencv-python>=4.8.0
+einops>=0.7.0
 timm>=0.9.0
 sentencepiece>=0.1.99
-
-
-
-
-fastapi
-uvicorn
-python-multipart
-pydantic
-python-jose
-passlib
-bcrypt
-aiofiles
-
-httpx==0.25.2
+peft>=0.7.0
+bitsandbytes>=0.41.0
+safetensors>=0.4.0
+gradio==4.44.1
+fastapi>=0.109.0
+uvicorn>=0.27.0
+python-multipart>=0.0.6
+pydantic>=2.5.0
+python-jose>=3.3.0
+passlib>=1.7.4
+bcrypt>=4.0.1
+aiofiles>=23.2.0
+httpx>=0.26.0

(The removed lines' version specifiers were truncated by the diff viewer and are kept as displayed.)
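Since the headline of this commit is the exact gradio==4.44.1 pin, a quick runtime check (a sketch, not part of the commit) can confirm that pip actually resolved that version:

    import gradio as gr

    # The root requirements pin Gradio exactly; fail fast if a different
    # version (e.g. from src/requirements.txt's looser gradio>=3.35.0) won out.
    assert gr.__version__ == "4.44.1", f"expected Gradio 4.44.1, got {gr.__version__}"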
src/__init__.py
ADDED
File without changes

src/api/__init__.py
ADDED
File without changes
src/api/app.py
ADDED
@@ -0,0 +1,159 @@
+"""
+Gradio interface for the LLaVA model.
+"""
+
+import gradio as gr
+from PIL import Image
+
+from ..configs.settings import (
+    GRADIO_THEME,
+    GRADIO_TITLE,
+    GRADIO_DESCRIPTION,
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_P,
+    API_HOST,
+    API_PORT,
+    API_WORKERS,
+    API_RELOAD
+)
+from ..models.llava_model import LLaVAModel
+from ..utils.logging import setup_logging, get_logger
+
+# Set up logging
+setup_logging()
+logger = get_logger(__name__)
+
+# Initialize model
+model = LLaVAModel()
+
+def process_image(
+    image: Image.Image,
+    prompt: str,
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = DEFAULT_TEMPERATURE,
+    top_p: float = DEFAULT_TOP_P
+) -> str:
+    """
+    Process an image with the LLaVA model.
+
+    Args:
+        image: Input image
+        prompt: Text prompt
+        max_new_tokens: Maximum number of tokens to generate
+        temperature: Sampling temperature
+        top_p: Top-p sampling parameter
+
+    Returns:
+        str: Model response
+    """
+    try:
+        logger.info(f"Processing image with prompt: {prompt[:100]}...")
+        response = model(
+            image=image,
+            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p
+        )
+        logger.info("Successfully generated response")
+        return response
+    except Exception as e:
+        logger.error(f"Error processing image: {str(e)}")
+        return f"Error: {str(e)}"
+
+def create_interface() -> gr.Blocks:  # returns gr.Blocks, not gr.Interface
+    """Create and return the Gradio interface."""
+    with gr.Blocks(theme=GRADIO_THEME) as interface:
+        gr.Markdown(f"""# {GRADIO_TITLE}
+
+{GRADIO_DESCRIPTION}
+
+## Example Prompts
+
+Try these prompts to get started:
+- "What can you see in this image?"
+- "Describe this scene in detail"
+- "What emotions does this image convey?"
+- "What's happening in this picture?"
+- "Can you identify any objects or people in this image?"
+
+## Usage Instructions
+
+1. Upload an image using the image uploader
+2. Enter your prompt in the text box
+3. (Optional) Adjust the generation parameters
+4. Click "Generate Response" to get LLaVA's analysis
+""")
+
+        with gr.Row():
+            with gr.Column():
+                # Input components
+                image_input = gr.Image(type="pil", label="Upload Image")
+                prompt_input = gr.Textbox(
+                    label="Prompt",
+                    placeholder="What can you see in this image?",
+                    lines=3
+                )
+
+                with gr.Accordion("Generation Parameters", open=False):
+                    max_tokens = gr.Slider(
+                        minimum=64,
+                        maximum=2048,
+                        value=DEFAULT_MAX_NEW_TOKENS,
+                        step=64,
+                        label="Max New Tokens"
+                    )
+                    temperature = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        value=DEFAULT_TEMPERATURE,
+                        step=0.1,
+                        label="Temperature"
+                    )
+                    top_p = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        value=DEFAULT_TOP_P,
+                        step=0.1,
+                        label="Top P"
+                    )
+
+                generate_btn = gr.Button("Generate Response", variant="primary")
+
+            with gr.Column():
+                # Output component
+                output = gr.Textbox(
+                    label="Response",
+                    lines=10,
+                    show_copy_button=True
+                )
+
+        # Set up event handlers
+        generate_btn.click(
+            fn=process_image,
+            inputs=[
+                image_input,
+                prompt_input,
+                max_tokens,
+                temperature,
+                top_p
+            ],
+            outputs=output
+        )
+
+    return interface
+
+def main():
+    """Run the Gradio interface."""
+    interface = create_interface()
+    interface.launch(
+        server_name=API_HOST,
+        server_port=API_PORT,
+        share=True,
+        show_error=True,
+        show_api=False
+    )
+
+if __name__ == "__main__":
+    main()
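The handler above can be exercised without launching a server. A sketch (importing src.api.app instantiates LLaVAModel at module level, so this assumes a machine that can load the checkpoint and a repo root on PYTHONPATH):

    from PIL import Image

    from src.api.app import process_image

    # A solid-color image keeps the example self-contained.
    img = Image.new("RGB", (224, 224), color="blue")
    print(process_image(img, "What color is this image?", max_new_tokens=64))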
src/configs/__init__.py
ADDED
File without changes

src/configs/settings.py
ADDED
@@ -0,0 +1,46 @@
+"""
+Configuration settings for the LLaVA implementation.
+"""
+
+import os
+from pathlib import Path
+
+# Project paths
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+SRC_DIR = PROJECT_ROOT / "src"
+ASSETS_DIR = PROJECT_ROOT / "assets"
+EXAMPLES_DIR = PROJECT_ROOT / "examples"
+
+# Model settings
+MODEL_NAME = "liuhaotian/llava-v1.5-7b"
+MODEL_REVISION = "main"
+DEVICE = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+
+# Generation settings
+DEFAULT_MAX_NEW_TOKENS = 512
+DEFAULT_TEMPERATURE = 0.7
+DEFAULT_TOP_P = 0.9
+
+# API settings
+API_HOST = "0.0.0.0"
+API_PORT = 7860
+API_WORKERS = 1
+API_RELOAD = True
+
+# Gradio settings
+GRADIO_THEME = "soft"
+GRADIO_TITLE = "LLaVA Chat"
+GRADIO_DESCRIPTION = """
+A powerful multimodal AI assistant that can understand and discuss images.
+Upload any image and chat with LLaVA about it!
+"""
+
+# Logging settings
+LOG_LEVEL = "INFO"
+LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+LOG_DIR = PROJECT_ROOT / "logs"
+LOG_FILE = LOG_DIR / "app.log"
+
+# Create necessary directories
+for directory in [ASSETS_DIR, EXAMPLES_DIR, LOG_DIR]:
+    directory.mkdir(parents=True, exist_ok=True)
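Note that DEVICE is computed once at import time from CUDA_VISIBLE_DEVICES, so the variable has to be set before the settings module is first imported. A short sketch:

    import os

    # Must run before the first import of src.configs.settings.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    from src.configs.settings import DEVICE

    print(DEVICE)  # "cuda" here; "cpu" when CUDA_VISIBLE_DEVICES is unset or empty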
src/models/__init__.py
ADDED
File without changes

src/models/llava_model.py
ADDED
@@ -0,0 +1,88 @@
+"""
+LLaVA model implementation.
+"""
+
+import torch
+from transformers import AutoProcessor, AutoModelForCausalLM
+from PIL import Image
+
+from ..configs.settings import MODEL_NAME, MODEL_REVISION, DEVICE
+from ..utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+class LLaVAModel:
+    """LLaVA model wrapper class."""
+
+    def __init__(self):
+        """Initialize the LLaVA model and processor."""
+        logger.info(f"Initializing LLaVA model from {MODEL_NAME}")
+        self.processor = AutoProcessor.from_pretrained(
+            MODEL_NAME,
+            revision=MODEL_REVISION,
+            trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            MODEL_NAME,
+            revision=MODEL_REVISION,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+        logger.info("Model initialization complete")
+
+    def generate_response(
+        self,
+        image: Image.Image,
+        prompt: str,
+        max_new_tokens: int = 512,
+        temperature: float = 0.7,
+        top_p: float = 0.9
+    ) -> str:
+        """
+        Generate a response for the given image and prompt.
+
+        Args:
+            image: Input image as PIL Image
+            prompt: Text prompt for the model
+            max_new_tokens: Maximum number of tokens to generate
+            temperature: Sampling temperature
+            top_p: Top-p sampling parameter
+
+        Returns:
+            str: Generated response
+        """
+        try:
+            # Prepare inputs
+            inputs = self.processor(
+                prompt,
+                image,
+                return_tensors="pt"
+            ).to(DEVICE)
+
+            # Generate response
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=True
+                )
+
+            # Decode and return response
+            response = self.processor.decode(
+                outputs[0],
+                skip_special_tokens=True
+            )
+
+            logger.debug(f"Generated response: {response[:100]}...")
+            return response
+
+        except Exception as e:
+            logger.error(f"Error generating response: {str(e)}")
+            raise
+
+    def __call__(self, *args, **kwargs):
+        """Convenience method to call generate_response."""
+        return self.generate_response(*args, **kwargs)
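A usage sketch for the wrapper (the image path is a hypothetical example; the liuhaotian/llava-v1.5-7b checkpoint is a multi-gigabyte download and realistically needs a GPU):

    from PIL import Image

    from src.models.llava_model import LLaVAModel

    model = LLaVAModel()                    # loads processor and weights once
    image = Image.open("examples/cat.jpg")  # hypothetical local image
    print(model(image, "Describe this image in one sentence.", max_new_tokens=128))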
main.py → src/models/main.py
RENAMED
File without changes
src/requirements.txt
ADDED
@@ -0,0 +1,26 @@
+torch>=2.0.0
+torchvision>=0.15.0
+transformers>=4.30.0
+accelerate>=0.20.0
+pillow>=9.0.0
+numpy>=1.24.0
+tqdm>=4.65.0
+matplotlib>=3.7.0
+opencv-python>=4.7.0
+einops>=0.6.0
+timm>=0.9.0
+sentencepiece>=0.1.99
+gradio>=3.35.0
+peft>=0.4.0
+bitsandbytes>=0.40.0
+safetensors>=0.3.1
+fastapi==0.104.1
+uvicorn==0.24.0
+python-multipart==0.0.6
+pydantic==2.5.2
+python-jose==3.3.0
+passlib==1.7.4
+bcrypt==4.0.1
+aiofiles==23.2.1
+python-dotenv==1.0.0
+httpx==0.25.2
src/utils/__init__.py
ADDED
File without changes

src/utils/logging.py
ADDED
@@ -0,0 +1,51 @@
+"""
+Logging utilities for the LLaVA implementation.
+"""
+
+import logging
+import sys
+from typing import Optional
+
+from ..configs.settings import LOG_LEVEL, LOG_FORMAT, LOG_FILE
+
+def setup_logging(name: Optional[str] = None) -> logging.Logger:
+    """
+    Set up logging configuration for the application.
+
+    Args:
+        name: Optional name for the logger. If None, configures the root logger.
+
+    Returns:
+        logging.Logger: Configured logger instance.
+    """
+    # Create logger
+    logger = logging.getLogger(name)
+    logger.setLevel(LOG_LEVEL)
+
+    # Create formatters
+    formatter = logging.Formatter(LOG_FORMAT)
+
+    # Create handlers
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(formatter)
+
+    file_handler = logging.FileHandler(LOG_FILE)
+    file_handler.setFormatter(formatter)
+
+    # Add handlers to logger
+    logger.addHandler(console_handler)
+    logger.addHandler(file_handler)
+
+    return logger
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Get a logger instance with the specified name.
+
+    Args:
+        name: Optional name for the logger. If None, returns the root logger.
+
+    Returns:
+        logging.Logger: Logger instance.
+    """
+    return logging.getLogger(name)
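The intended call pattern, mirroring what src/api/app.py does at import time:

    from src.utils.logging import setup_logging, get_logger

    setup_logging()                # configure the root logger once at startup
    logger = get_logger(__name__)  # module logger; propagates to root handlers
    logger.info("Application started")  # written to stdout and logs/app.log

One caveat grounded in the code above: setup_logging attaches fresh handlers on every call, so invoking it from several modules duplicates log lines; call it exactly once at startup.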
tests/test_model.py
ADDED
@@ -0,0 +1,67 @@
+"""
+Tests for the LLaVA model implementation.
+"""
+
+import pytest
+from PIL import Image
+import torch
+
+from src.models.llava_model import LLaVAModel
+from src.configs.settings import DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
+
+@pytest.fixture
+def model():
+    """Fixture to provide a model instance."""
+    return LLaVAModel()
+
+@pytest.fixture
+def sample_image():
+    """Fixture to provide a sample image."""
+    # Create a simple test image
+    return Image.new('RGB', (224, 224), color='red')
+
+def test_model_initialization(model):
+    """Test that the model initializes correctly."""
+    assert model is not None
+    assert model.processor is not None
+    assert model.model is not None
+
+def test_model_device(model):
+    """Test that the model is on the correct device."""
+    assert next(model.model.parameters()).device.type in ['cuda', 'cpu']
+
+def test_generate_response(model, sample_image):
+    """Test that the model can generate responses."""
+    prompt = "What color is this image?"
+    response = model.generate_response(
+        image=sample_image,
+        prompt=prompt,
+        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+        temperature=DEFAULT_TEMPERATURE,
+        top_p=DEFAULT_TOP_P
+    )
+
+    assert isinstance(response, str)
+    assert len(response) > 0
+
+def test_generate_response_with_invalid_image(model):
+    """Test that the model handles invalid images correctly."""
+    with pytest.raises(Exception):
+        model.generate_response(
+            image=None,
+            prompt="What color is this image?",
+            max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+            temperature=DEFAULT_TEMPERATURE,
+            top_p=DEFAULT_TOP_P
+        )
+
+def test_generate_response_with_empty_prompt(model, sample_image):
+    """Test that the model handles empty prompts correctly."""
+    with pytest.raises(Exception):
+        model.generate_response(
+            image=sample_image,
+            prompt="",
+            max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+            temperature=DEFAULT_TEMPERATURE,
+            top_p=DEFAULT_TOP_P
+        )
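These tests construct the real model, so they are slow and download-heavy. To run just this module programmatically (equivalent to `pytest -q tests/test_model.py`, with the coverage options from pyproject.toml still applied via addopts), a sketch:

    import pytest

    # pytest.main returns an exit code: 0 on success, non-zero otherwise.
    raise SystemExit(pytest.main(["-q", "tests/test_model.py"]))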