Spaces:
Build error
Build error
Commit
·
e16e45f
1
Parent(s):
99216fc
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- LICENSE +9 -0
- README.md +256 -7
- __pycache__/edit_dataset.cpython-38.pyc +0 -0
- __pycache__/main.cpython-38.pyc +0 -0
- configs/generate.yaml +99 -0
- configs/train.yaml +113 -0
- dataset_creation/generate_img_dataset.py +315 -0
- dataset_creation/generate_txt_dataset.py +113 -0
- dataset_creation/prepare_dataset.py +29 -0
- dataset_creation/prepare_for_gpt.py +25 -0
- edit_app.py +268 -0
- edit_cli.py +128 -0
- edit_dataset.py +121 -0
- environment.yaml +38 -0
- imgs/dataset.jpg +0 -0
- imgs/edit_app.jpg +0 -0
- imgs/example.jpg +0 -0
- imgs/prompt_app.jpg +0 -0
- logs/train_default/checkpoints/epoch=001542.ckpt +3 -0
- logs/train_default/checkpoints/last.ckpt +3 -0
- logs/train_default/checkpoints/trainstep_checkpoints/epoch=000333-step=000000999.ckpt +3 -0
- logs/train_default/checkpoints/trainstep_checkpoints/epoch=000666-step=000001999.ckpt +3 -0
- logs/train_default/checkpoints/trainstep_checkpoints/epoch=000999-step=000002999.ckpt +3 -0
- logs/train_default/checkpoints/trainstep_checkpoints/epoch=001333-step=000003999.ckpt +3 -0
- logs/train_default/configs/2023-06-30T02-08-15-lightning.yaml +15 -0
- logs/train_default/configs/2023-06-30T02-08-15-project.yaml +94 -0
- logs/train_default/configs/2023-06-30T02-17-16-lightning.yaml +16 -0
- logs/train_default/configs/2023-06-30T02-17-16-project.yaml +94 -0
- logs/train_default/configs/2023-06-30T05-33-22-lightning.yaml +16 -0
- logs/train_default/configs/2023-06-30T05-33-22-project.yaml +94 -0
- logs/train_default/configs/2023-07-03T07-00-36-lightning.yaml +15 -0
- logs/train_default/configs/2023-07-03T07-00-36-project.yaml +94 -0
- logs/train_default/configs/2023-07-03T07-11-08-lightning.yaml +15 -0
- logs/train_default/configs/2023-07-03T07-11-08-project.yaml +94 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000008_after-gen.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000008_after.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000008_before-vq.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000008_before.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000008_prompt.json +8 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000009_after-gen.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000009_after.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000009_before-vq.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000009_before.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000009_prompt.json +8 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000010_after-gen.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000010_after.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000010_before-vq.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000010_before.png +0 -0
- logs/train_default/images/train/gs-002000_e-000666_b-000010_prompt.json +8 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
logs/train_default/wandb/offline-run-20230703_071204-train_default/run-train_default.wandb filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4 |
+
|
5 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6 |
+
|
7 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
8 |
+
|
9 |
+
Portions of code and models (such as pretrained checkpoints, which are fine-tuned starting from released Stable Diffusion checkpoints) are derived from the Stable Diffusion codebase (https://github.com/CompVis/stable-diffusion). Further restrictions may apply. Please consult the Stable Diffusion license `stable_diffusion/LICENSE`. Modified code is denoted as such in comments at the start of each file.
|
README.md
CHANGED
@@ -1,12 +1,261 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.35.2
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: instruct-pix2pix
|
3 |
+
app_file: edit_app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 3.35.2
|
|
|
|
|
6 |
---
|
7 |
+
# InstructPix2Pix: Learning to Follow Image Editing Instructions
|
8 |
+
### [Project Page](https://www.timothybrooks.com/instruct-pix2pix/) | [Paper](https://arxiv.org/abs/2211.09800) | [Data](http://instruct-pix2pix.eecs.berkeley.edu/)
|
9 |
+
PyTorch implementation of InstructPix2Pix, an instruction-based image editing model, based on the original [CompVis/stable_diffusion](https://github.com/CompVis/stable-diffusion) repo. <br>
|
10 |
|
11 |
+
[InstructPix2Pix: Learning to Follow Image Editing Instructions](https://www.timothybrooks.com/instruct-pix2pix/)
|
12 |
+
[Tim Brooks](https://www.timothybrooks.com/)\*,
|
13 |
+
[Aleksander Holynski](https://holynski.org/)\*,
|
14 |
+
[Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) <br>
|
15 |
+
UC Berkeley <br>
|
16 |
+
\*denotes equal contribution
|
17 |
+
|
18 |
+
<img src='https://instruct-pix2pix.timothybrooks.com/teaser.jpg'/>
|
19 |
+
|
20 |
+
## TL;DR: quickstart
|
21 |
+
|
22 |
+
Follow the instructions below to download and run InstructPix2Pix on your own images. These instructions have been tested on a GPU with >18GB VRAM. If you don't have a GPU, you may need to change the default configuration, or check out [other ways of using the model](https://github.com/timothybrooks/instruct-pix2pix#other-ways-of-using-instructpix2pix).
|
23 |
+
|
24 |
+
### Set up a conda environment, and download a pretrained model:
|
25 |
+
```
|
26 |
+
conda env create -f environment.yaml
|
27 |
+
conda activate ip2p
|
28 |
+
bash scripts/download_checkpoints.sh
|
29 |
+
```
|
30 |
+
|
31 |
+
### Edit a single image:
|
32 |
+
```
|
33 |
+
python edit_cli.py --input imgs/example.jpg --output imgs/output.jpg --edit "turn him into a cyborg"
|
34 |
+
|
35 |
+
# Optionally, you can specify parameters to tune your result:
|
36 |
+
# python edit_cli.py --steps 100 --resolution 512 --seed 1371 --cfg-text 7.5 --cfg-image 1.2 --input imgs/example.jpg --output imgs/output.jpg --edit "turn him into a cyborg"
|
37 |
+
```
|
38 |
+
|
39 |
+
### Or launch your own interactive editing Gradio app:
|
40 |
+
```
|
41 |
+
python edit_app.py
|
42 |
+
```
|
43 |
+

|
44 |
+
|
45 |
+
_(For advice on how to get the best results by tuning parameters, see the [Tips](https://github.com/timothybrooks/instruct-pix2pix#tips) section)._
|
46 |
+
|
47 |
+
## Setup
|
48 |
+
|
49 |
+
Install all dependencies with:
|
50 |
+
```
|
51 |
+
conda env create -f environment.yaml
|
52 |
+
```
|
53 |
+
|
54 |
+
Download the pretrained models by running:
|
55 |
+
```
|
56 |
+
bash scripts/download_checkpoints.sh
|
57 |
+
```
|
58 |
+
|
59 |
+
## Generated Dataset
|
60 |
+
|
61 |
+
Our image editing model is trained on a generated dataset consisting of 454,445 examples. Each example contains (1) an input image, (2) an editing instruction, and (3) an output edited image. We provide two versions of the dataset, one in which each pair of edited images is generated 100 times, and the best examples are chosen based on CLIP metrics (Section 3.1.2 in the paper) (`clip-filtered-dataset`), and one in which examples are randomly chosen (`random-sample-dataset`).
|
62 |
+
|
63 |
+
For the released version of this dataset, we've additionally filtered prompts and images for NSFW content. After NSFW filtering, the GPT-3 generated dataset contains 451,990 examples. The final image-pair datasets contain:
|
64 |
+
|
65 |
+
| | # of image editing examples | Dataset size |
|
66 |
+
|--|-----------------------|----------------------- |
|
67 |
+
| `random-sample-dataset` |451990|727GB|
|
68 |
+
| `clip-filtered-dataset` |313010|436GB|
|
69 |
+
|
70 |
+
To download one of these datasets, along with the entire NSFW-filtered text data, run the following command with the appropriate dataset name:
|
71 |
+
|
72 |
+
```
|
73 |
+
bash scripts/download_data.sh clip-filtered-dataset
|
74 |
+
```
|
75 |
+
|
76 |
+
|
77 |
+
## Training InstructPix2Pix
|
78 |
+
|
79 |
+
InstructPix2Pix is trained by fine-tuning from an initial StableDiffusion checkpoint. The first step is to download a Stable Diffusion checkpoint. For our trained models, we used the v1.5 checkpoint as the starting point. To download the same ones we used, you can run the following script:
|
80 |
+
```
|
81 |
+
bash scripts/download_pretrained_sd.sh
|
82 |
+
```
|
83 |
+
If you'd like to use a different checkpoint, point to it in the config file `configs/train.yaml`, on line 8, after `ckpt_path:`.
|
84 |
+
|
85 |
+
Next, we need to change the config to point to our downloaded (or generated) dataset. If you're using the `clip-filtered-dataset` from above, you can skip this. Otherwise, you may need to edit lines 85 and 94 of the config (`data.params.train.params.path`, `data.params.validation.params.path`).
|
86 |
+
|
87 |
+
Finally, start a training job with the following command:
|
88 |
+
|
89 |
+
```
|
90 |
+
python main.py --name default --base configs/train.yaml --train --gpus 0,1,2,3,4,5,6,7
|
91 |
+
```
|
92 |
+
|
93 |
+
|
94 |
+
## Creating your own dataset
|
95 |
+
|
96 |
+
Our generated dataset of paired images and editing instructions is made in two phases: First, we use GPT-3 to generate text triplets: (a) a caption describing an image, (b) an edit instruction, (c) a caption describing the image after the edit. Then, we turn pairs of captions (before/after the edit) into pairs of images using Stable Diffusion and Prompt-to-Prompt.
|
97 |
+
|
98 |
+
### (1) Generate a dataset of captions and instructions
|
99 |
+
|
100 |
+
We provide our generated dataset of captions and edit instructions [here](https://instruct-pix2pix.eecs.berkeley.edu/gpt-generated-prompts.jsonl). If you plan to use our captions+instructions, skip to step (2). Otherwise, if you would like to create your own text dataset, please follow steps (1.1-1.3) below. Note that generating very large datasets using GPT-3 can be expensive.
|
101 |
+
|
102 |
+
#### (1.1) Manually write a dataset of instructions and captions
|
103 |
+
|
104 |
+
The first step of the process is fine-tuning GPT-3. To do this, we made a dataset of 700 examples broadly covering of edits that we might want our model to be able to perform. Our examples are available [here](https://instruct-pix2pix.eecs.berkeley.edu/human-written-prompts.jsonl). These should be diverse and cover a wide range of possible captions and types of edits. Ideally, they should avoid duplication or significant overlap of captions and instructions. It is also important to be mindful of limitations of Stable Diffusion and Prompt-to-Prompt in writing these examples, such as inability to perform large spatial transformations (e.g., moving the camera, zooming in, swapping object locations).
|
105 |
+
|
106 |
+
Input prompts should closely match the distribution of input prompts used to generate the larger dataset. We sampled the 700 input prompts from the _LAION Improved Aesthetics 6.5+_ dataset and also use this dataset for generating examples. We found this dataset is quite noisy (many of the captions are overly long and contain irrelevant text). For this reason, we also considered MSCOCO and LAION-COCO datasets, but ultimately chose _LAION Improved Aesthetics 6.5+_ due to its diversity of content, proper nouns, and artistic mediums. If you choose to use another dataset or combination of datasets as input to GPT-3 when generating examples, we recommend you sample the input prompts from the same distribution when manually writing training examples.
|
107 |
+
|
108 |
+
#### (1.2) Finetune GPT-3
|
109 |
+
|
110 |
+
The next step is to finetune a large language model on the manually written instructions/outputs to generate edit instructions and edited caption from a new input caption. For this, we finetune GPT-3's Davinci model via the OpenAI API, although other language models could be used.
|
111 |
+
|
112 |
+
To prepare training data for GPT-3, one must first create an OpenAI developer account to access the needed APIs, and [set up the API keys on your local device](https://beta.openai.com/docs/api-reference/introduction). Also, run the `prompts/prepare_for_gpt.py` script, which forms the prompts into the correct format by concatenating instructions and captions and adding delimiters and stop sequences.
|
113 |
+
|
114 |
+
```bash
|
115 |
+
python dataset_creation/prepare_for_gpt.py --input-path data/human-written-prompts.jsonl --output-path data/human-written-prompts-for-gpt.jsonl
|
116 |
+
```
|
117 |
+
|
118 |
+
Next, finetune GPT-3 via the OpenAI CLI. We provide an example below, although please refer to OpenAI's official documentation for this, as best practices may change. We trained the Davinci model for a single epoch. You can experiment with smaller less expensive GPT-3 variants or with open source language models, although this may negatively affect performance.
|
119 |
+
|
120 |
+
```bash
|
121 |
+
openai api fine_tunes.create -t data/human-written-prompts-for-gpt.jsonl -m davinci --n_epochs 1 --suffix "instruct-pix2pix"
|
122 |
+
```
|
123 |
+
|
124 |
+
You can test out the finetuned GPT-3 model by launching the provided Gradio app:
|
125 |
+
|
126 |
+
```bash
|
127 |
+
python prompt_app.py --openai-api-key OPENAI_KEY --openai-model OPENAI_MODEL_NAME
|
128 |
+
```
|
129 |
+
|
130 |
+

|
131 |
+
|
132 |
+
#### (1.3) Generate a large dataset of captions and instructions
|
133 |
+
|
134 |
+
We now use the finetuned GPT-3 model to generate a large dataset. Our dataset cost thousands of dollars to create. See `prompts/gen_instructions_and_captions.py` for the script which generates these examples. We recommend first generating a small number of examples (by setting a low value of `--num-samples`) and gradually increasing the scale to ensure the results are working as desired before increasing scale.
|
135 |
+
|
136 |
+
```bash
|
137 |
+
python dataset_creation/generate_txt_dataset.py --openai-api-key OPENAI_KEY --openai-model OPENAI_MODEL_NAME
|
138 |
+
```
|
139 |
+
|
140 |
+
If you are generating at a very large scale (e.g., 100K+), it will be noteably faster to generate the dataset with multiple processes running in parallel. This can be accomplished by setting `--partitions=N` to a higher number and running multiple processes, setting each `--partition` to the corresponding value.
|
141 |
+
|
142 |
+
```bash
|
143 |
+
python dataset_creation/generate_txt_dataset.py --openai-api-key OPENAI_KEY --openai-model OPENAI_MODEL_NAME --partitions=10 --partition=0
|
144 |
+
```
|
145 |
+
|
146 |
+
### (2) Turn paired captions into paired images
|
147 |
+
|
148 |
+
The next step is to turn pairs of text captions into pairs of images. For this, we need to copy some pre-trained Stable Diffusion checkpoints to `stable_diffusion/models/ldm/stable-diffusion-v1/`. You may have already done this if you followed the instructions above for training with our provided data, but if not, you can do this by running:
|
149 |
+
|
150 |
+
```bash
|
151 |
+
bash scripts/download_pretrained_sd.sh
|
152 |
+
```
|
153 |
+
|
154 |
+
For our model, we used [checkpoint v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt), and the [new autoencoder](https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt), but other models may work as well. If you choose to use other models, make sure to change point to the corresponding checkpoints by passing in the `--ckpt` and `--vae-ckpt` arguments. Once all checkpoints have been downloaded, we can generate the dataset with the following command:
|
155 |
+
|
156 |
+
```
|
157 |
+
python dataset_creation/generate_img_dataset.py --out_dir data/instruct-pix2pix-dataset-000 --prompts_file path/to/generated_prompts.jsonl
|
158 |
+
```
|
159 |
+
|
160 |
+
This command operates on a single GPU (typically a V100 or A100). To parallelize over many GPUs/machines, set `--n-partitions` to the total number of parallel jobs and `--partition` to the index of each job.
|
161 |
+
|
162 |
+
```
|
163 |
+
python dataset_creation/generate_img_dataset.py --out_dir data/instruct-pix2pix-dataset-000 --prompts_file path/to/generated_prompts.jsonl --n-partitions 100 --partition 0
|
164 |
+
```
|
165 |
+
|
166 |
+
The default parameters match that of our dataset, although in practice you can use a smaller number of steps (e.g., `--steps=25`) to generate high quality data faster. By default, we generate 100 samples per prompt and use CLIP filtering to keep a max of 4 per prompt. You can experiment with fewer samples by setting `--n-samples`. The command below turns off CLIP filtering entirely and is therefore faster:
|
167 |
+
|
168 |
+
```
|
169 |
+
python dataset_creation/generate_img_dataset.py --out_dir data/instruct-pix2pix-dataset-000 --prompts_file path/to/generated_prompts.jsonl --n-samples 4 --clip-threshold 0 --clip-dir-threshold 0 --clip-img-threshold 0 --n-partitions 100 --partition 0
|
170 |
+
```
|
171 |
+
|
172 |
+
After generating all of the dataset examples, run the following command below to create a list of the examples. This is needed for the dataset onject to efficiently be able to sample examples without needing to iterate over the entire dataset directory at the start of each training run.
|
173 |
+
|
174 |
+
```
|
175 |
+
python dataset_creation/prepare_dataset.py data/instruct-pix2pix-dataset-000
|
176 |
+
```
|
177 |
+
|
178 |
+
## Evaluation
|
179 |
+
|
180 |
+
To generate plots like the ones in Figures 8 and 10 in the paper, run the following command:
|
181 |
+
|
182 |
+
```
|
183 |
+
python metrics/compute_metrics.py --ckpt /path/to/your/model.ckpt
|
184 |
+
```
|
185 |
+
|
186 |
+
## Tips
|
187 |
+
|
188 |
+
If you're not getting the quality result you want, there may be a few reasons:
|
189 |
+
1. **Is the image not changing enough?** Your Image CFG weight may be too high. This value dictates how similar the output should be to the input. It's possible your edit requires larger changes from the original image, and your Image CFG weight isn't allowing that. Alternatively, your Text CFG weight may be too low. This value dictates how much to listen to the text instruction. The default Image CFG of 1.5 and Text CFG of 7.5 are a good starting point, but aren't necessarily optimal for each edit. Try:
|
190 |
+
* Decreasing the Image CFG weight, or
|
191 |
+
* Increasing the Text CFG weight, or
|
192 |
+
2. Conversely, **is the image changing too much**, such that the details in the original image aren't preserved? Try:
|
193 |
+
* Increasing the Image CFG weight, or
|
194 |
+
* Decreasing the Text CFG weight
|
195 |
+
3. Try generating results with different random seeds by setting "Randomize Seed" and running generation multiple times. You can also try setting "Randomize CFG" to sample new Text CFG and Image CFG values each time.
|
196 |
+
4. Rephrasing the instruction sometimes improves results (e.g., "turn him into a dog" vs. "make him a dog" vs. "as a dog").
|
197 |
+
5. Increasing the number of steps sometimes improves results.
|
198 |
+
6. Do faces look weird? The Stable Diffusion autoencoder has a hard time with faces that are small in the image. Try cropping the image so the face takes up a larger portion of the frame.
|
199 |
+
|
200 |
+
## Comments
|
201 |
+
|
202 |
+
- Our codebase is based on the [Stable Diffusion codebase](https://github.com/CompVis/stable-diffusion).
|
203 |
+
|
204 |
+
## BibTeX
|
205 |
+
|
206 |
+
```
|
207 |
+
@article{brooks2022instructpix2pix,
|
208 |
+
title={InstructPix2Pix: Learning to Follow Image Editing Instructions},
|
209 |
+
author={Brooks, Tim and Holynski, Aleksander and Efros, Alexei A},
|
210 |
+
journal={arXiv preprint arXiv:2211.09800},
|
211 |
+
year={2022}
|
212 |
+
}
|
213 |
+
```
|
214 |
+
## Other ways of using InstructPix2Pix
|
215 |
+
|
216 |
+
### InstructPix2Pix on [HuggingFace](https://huggingface.co/spaces/timbrooks/instruct-pix2pix):
|
217 |
+
> A browser-based version of the demo is available as a [HuggingFace space](https://huggingface.co/spaces/timbrooks/instruct-pix2pix). For this version, you only need a browser, a picture you want to edit, and an instruction! Note that this is a shared online demo, and processing time may be slower during peak utilization.
|
218 |
+
|
219 |
+
### InstructPix2Pix on [Replicate](https://replicate.com/timothybrooks/instruct-pix2pix):
|
220 |
+
> Replicate provides a production-ready cloud API for running the InstructPix2Pix model. You can run the model from any environment using a simple API call with cURL, Python, JavaScript, or your language of choice. Replicate also provides a web interface for running the model and sharing predictions.
|
221 |
+
|
222 |
+
### InstructPix2Pix in [Imaginairy](https://github.com/brycedrennan/imaginAIry#-edit-images-with-instructions-alone-by-instructpix2pix):
|
223 |
+
> Imaginairy offers another way of easily installing InstructPix2Pix with a single command. It can run on devices without GPUs (like a Macbook!).
|
224 |
+
> ```bash
|
225 |
+
> pip install imaginairy --upgrade
|
226 |
+
> aimg edit any-image.jpg --gif "turn him into a cyborg"
|
227 |
+
> ```
|
228 |
+
> It also offers an easy way to perform a bunch of edits on an image, and can save edits out to an animated GIF:
|
229 |
+
> ```
|
230 |
+
> aimg edit --gif --surprise-me pearl-earring.jpg
|
231 |
+
> ```
|
232 |
+
> <img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/7c05c3aae2740278978c5e84962b826e58201bac/assets/girl_with_a_pearl_earring_suprise.gif" width="512">
|
233 |
+
|
234 |
+
### InstructPix2Pix in [🧨 Diffusers](https://github.com/huggingface/diffusers):
|
235 |
+
|
236 |
+
> InstructPix2Pix in Diffusers is a bit more optimized, so it may be faster and more suitable for GPUs with less memory. Below are instructions for installing the library and editing an image:
|
237 |
+
> 1. Install diffusers and relevant dependencies:
|
238 |
+
>
|
239 |
+
> ```bash
|
240 |
+
> pip install transformers accelerate torch
|
241 |
+
>
|
242 |
+
> pip install git+https://github.com/huggingface/diffusers.git
|
243 |
+
> ```
|
244 |
+
>
|
245 |
+
> 2. Load the model and edit the image:
|
246 |
+
>
|
247 |
+
> ```python
|
248 |
+
>
|
249 |
+
> import torch
|
250 |
+
> from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
|
251 |
+
>
|
252 |
+
> model_id = "timbrooks/instruct-pix2pix"
|
253 |
+
> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
|
254 |
+
> pipe.to("cuda")
|
255 |
+
> pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
256 |
+
> # `image` is an RGB PIL.Image
|
257 |
+
> images = pipe("turn him into cyborg", image=image).images
|
258 |
+
> images[0]
|
259 |
+
> ```
|
260 |
+
>
|
261 |
+
> For more information, check the docs [here](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/pix2pix).
|
__pycache__/edit_dataset.cpython-38.pyc
ADDED
Binary file (4.06 kB). View file
|
|
__pycache__/main.cpython-38.pyc
ADDED
Binary file (20.2 kB). View file
|
|
configs/generate.yaml
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
model:
|
5 |
+
base_learning_rate: 1.0e-04
|
6 |
+
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
7 |
+
params:
|
8 |
+
linear_start: 0.00085
|
9 |
+
linear_end: 0.0120
|
10 |
+
num_timesteps_cond: 1
|
11 |
+
log_every_t: 200
|
12 |
+
timesteps: 1000
|
13 |
+
first_stage_key: edited
|
14 |
+
cond_stage_key: edit
|
15 |
+
# image_size: 64
|
16 |
+
# image_size: 32
|
17 |
+
image_size: 16
|
18 |
+
channels: 4
|
19 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
20 |
+
conditioning_key: hybrid
|
21 |
+
monitor: val/loss_simple_ema
|
22 |
+
scale_factor: 0.18215
|
23 |
+
use_ema: true
|
24 |
+
load_ema: true
|
25 |
+
|
26 |
+
scheduler_config: # 10000 warmup steps
|
27 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
28 |
+
params:
|
29 |
+
warm_up_steps: [ 0 ]
|
30 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
31 |
+
f_start: [ 1.e-6 ]
|
32 |
+
f_max: [ 1. ]
|
33 |
+
f_min: [ 1. ]
|
34 |
+
|
35 |
+
unet_config:
|
36 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
37 |
+
params:
|
38 |
+
image_size: 32 # unused
|
39 |
+
in_channels: 8
|
40 |
+
out_channels: 4
|
41 |
+
model_channels: 320
|
42 |
+
attention_resolutions: [ 4, 2, 1 ]
|
43 |
+
num_res_blocks: 2
|
44 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
45 |
+
num_heads: 8
|
46 |
+
use_spatial_transformer: True
|
47 |
+
transformer_depth: 1
|
48 |
+
context_dim: 768
|
49 |
+
use_checkpoint: True
|
50 |
+
legacy: False
|
51 |
+
|
52 |
+
first_stage_config:
|
53 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
54 |
+
params:
|
55 |
+
embed_dim: 4
|
56 |
+
monitor: val/rec_loss
|
57 |
+
ddconfig:
|
58 |
+
double_z: true
|
59 |
+
z_channels: 4
|
60 |
+
resolution: 256
|
61 |
+
in_channels: 3
|
62 |
+
out_ch: 3
|
63 |
+
ch: 128
|
64 |
+
ch_mult:
|
65 |
+
- 1
|
66 |
+
- 2
|
67 |
+
- 4
|
68 |
+
- 4
|
69 |
+
num_res_blocks: 2
|
70 |
+
attn_resolutions: []
|
71 |
+
dropout: 0.0
|
72 |
+
lossconfig:
|
73 |
+
target: torch.nn.Identity
|
74 |
+
|
75 |
+
cond_stage_config:
|
76 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
77 |
+
|
78 |
+
data:
|
79 |
+
target: main.DataModuleFromConfig
|
80 |
+
params:
|
81 |
+
batch_size: 128
|
82 |
+
num_workers: 1
|
83 |
+
wrap: false
|
84 |
+
validation:
|
85 |
+
target: edit_dataset.EditDataset
|
86 |
+
params:
|
87 |
+
path: data/clip-filtered-dataset
|
88 |
+
cache_dir: data/
|
89 |
+
cache_name: data_10k
|
90 |
+
split: val
|
91 |
+
min_text_sim: 0.2
|
92 |
+
min_image_sim: 0.75
|
93 |
+
min_direction_sim: 0.2
|
94 |
+
max_samples_per_prompt: 1
|
95 |
+
min_resize_res: 512
|
96 |
+
max_resize_res: 512
|
97 |
+
crop_res: 512
|
98 |
+
output_as_edit: False
|
99 |
+
real_input: True
|
configs/train.yaml
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
model:
|
5 |
+
base_learning_rate: 1.0e-04
|
6 |
+
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
7 |
+
params:
|
8 |
+
ckpt_path: /home/ugrad/epoch=000027.ckpt
|
9 |
+
linear_start: 0.00085
|
10 |
+
linear_end: 0.0120
|
11 |
+
num_timesteps_cond: 1
|
12 |
+
log_every_t: 200
|
13 |
+
timesteps: 1000
|
14 |
+
first_stage_key: edited
|
15 |
+
cond_stage_key: edit
|
16 |
+
image_size: 32
|
17 |
+
channels: 4
|
18 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
19 |
+
conditioning_key: hybrid
|
20 |
+
monitor: val/loss_simple_ema
|
21 |
+
scale_factor: 0.18215
|
22 |
+
use_ema: true
|
23 |
+
load_ema: false
|
24 |
+
|
25 |
+
scheduler_config: # 10000 warmup steps
|
26 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
27 |
+
params:
|
28 |
+
warm_up_steps: [ 0 ]
|
29 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
30 |
+
f_start: [ 1.e-6 ]
|
31 |
+
f_max: [ 1. ]
|
32 |
+
f_min: [ 1. ]
|
33 |
+
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32 # unused
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions: [ 4, 2, 1 ]
|
42 |
+
num_res_blocks: 2
|
43 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
44 |
+
num_heads: 8
|
45 |
+
use_spatial_transformer: True
|
46 |
+
transformer_depth: 1
|
47 |
+
context_dim: 768
|
48 |
+
use_checkpoint: True
|
49 |
+
legacy: False
|
50 |
+
|
51 |
+
first_stage_config:
|
52 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
53 |
+
params:
|
54 |
+
embed_dim: 4
|
55 |
+
monitor: val/rec_loss
|
56 |
+
ddconfig:
|
57 |
+
double_z: true
|
58 |
+
z_channels: 4
|
59 |
+
resolution: 256
|
60 |
+
in_channels: 3
|
61 |
+
out_ch: 3
|
62 |
+
ch: 128
|
63 |
+
ch_mult:
|
64 |
+
- 1
|
65 |
+
- 2
|
66 |
+
- 4
|
67 |
+
- 4
|
68 |
+
num_res_blocks: 2
|
69 |
+
attn_resolutions: []
|
70 |
+
dropout: 0.0
|
71 |
+
lossconfig:
|
72 |
+
target: torch.nn.Identity
|
73 |
+
|
74 |
+
cond_stage_config:
|
75 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
76 |
+
|
77 |
+
data:
|
78 |
+
target: main.DataModuleFromConfig
|
79 |
+
params:
|
80 |
+
batch_size: 32
|
81 |
+
num_workers: 2
|
82 |
+
train:
|
83 |
+
target: edit_dataset.EditDataset
|
84 |
+
params:
|
85 |
+
path: /home/ugrad/ip2pdata
|
86 |
+
split: train
|
87 |
+
min_resize_res: 256
|
88 |
+
max_resize_res: 256
|
89 |
+
crop_res: 256
|
90 |
+
flip_prob: 0.5
|
91 |
+
# validation:
|
92 |
+
# target: edit_dataset.EditDataset
|
93 |
+
# params:
|
94 |
+
# path: data/clip-filtered-dataset
|
95 |
+
# split: val
|
96 |
+
# min_resize_res: 256
|
97 |
+
# max_resize_res: 256
|
98 |
+
# crop_res: 256
|
99 |
+
|
100 |
+
lightning:
|
101 |
+
callbacks:
|
102 |
+
image_logger:
|
103 |
+
target: main.ImageLogger
|
104 |
+
params:
|
105 |
+
batch_frequency: 2000
|
106 |
+
max_images: 2
|
107 |
+
increase_log_steps: False
|
108 |
+
|
109 |
+
trainer:
|
110 |
+
max_epochs: 2000
|
111 |
+
benchmark: True
|
112 |
+
accumulate_grad_batches: 4
|
113 |
+
check_val_every_n_epoch: 4
|
dataset_creation/generate_img_dataset.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import k_diffusion
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
from PIL import Image
|
13 |
+
from pytorch_lightning import seed_everything
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
sys.path.append("./")
|
17 |
+
sys.path.append("./stable_diffusion")
|
18 |
+
|
19 |
+
from ldm.modules.attention import CrossAttention
|
20 |
+
from ldm.util import instantiate_from_config
|
21 |
+
from metrics.clip_similarity import ClipSimilarity
|
22 |
+
|
23 |
+
|
24 |
+
################################################################################
|
25 |
+
# Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
|
26 |
+
# https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
27 |
+
|
28 |
+
|
29 |
+
def append_dims(x, target_dims):
|
30 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
31 |
+
dims_to_append = target_dims - x.ndim
|
32 |
+
if dims_to_append < 0:
|
33 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
34 |
+
return x[(...,) + (None,) * dims_to_append]
|
35 |
+
|
36 |
+
|
37 |
+
def to_d(x, sigma, denoised):
|
38 |
+
"""Converts a denoiser output to a Karras ODE derivative."""
|
39 |
+
return (x - denoised) / append_dims(sigma, x.ndim)
|
40 |
+
|
41 |
+
|
42 |
+
def get_ancestral_step(sigma_from, sigma_to):
|
43 |
+
"""Calculates the noise level (sigma_down) to step down to and the amount
|
44 |
+
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
45 |
+
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
|
46 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
47 |
+
return sigma_down, sigma_up
|
48 |
+
|
49 |
+
|
50 |
+
def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
|
51 |
+
"""Ancestral sampling with Euler method steps."""
|
52 |
+
s_in = x.new_ones([x.shape[0]])
|
53 |
+
for i in range(len(sigmas) - 1):
|
54 |
+
prompt_to_prompt = prompt2prompt_threshold > i / (len(sigmas) - 2)
|
55 |
+
for m in model.modules():
|
56 |
+
if isinstance(m, CrossAttention):
|
57 |
+
m.prompt_to_prompt = prompt_to_prompt
|
58 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
59 |
+
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
60 |
+
d = to_d(x, sigmas[i], denoised)
|
61 |
+
# Euler method
|
62 |
+
dt = sigma_down - sigmas[i]
|
63 |
+
x = x + d * dt
|
64 |
+
if sigmas[i + 1] > 0:
|
65 |
+
# Make noise the same across all samples in batch.
|
66 |
+
x = x + torch.randn_like(x[:1]) * sigma_up
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
################################################################################
|
71 |
+
|
72 |
+
|
73 |
+
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
|
74 |
+
print(f"Loading model from {ckpt}")
|
75 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
76 |
+
if "global_step" in pl_sd:
|
77 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
78 |
+
sd = pl_sd["state_dict"]
|
79 |
+
if vae_ckpt is not None:
|
80 |
+
print(f"Loading VAE from {vae_ckpt}")
|
81 |
+
vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
|
82 |
+
sd = {
|
83 |
+
k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
|
84 |
+
for k, v in sd.items()
|
85 |
+
}
|
86 |
+
model = instantiate_from_config(config.model)
|
87 |
+
m, u = model.load_state_dict(sd, strict=False)
|
88 |
+
if len(m) > 0 and verbose:
|
89 |
+
print("missing keys:")
|
90 |
+
print(m)
|
91 |
+
if len(u) > 0 and verbose:
|
92 |
+
print("unexpected keys:")
|
93 |
+
print(u)
|
94 |
+
return model
|
95 |
+
|
96 |
+
|
97 |
+
class CFGDenoiser(nn.Module):
|
98 |
+
def __init__(self, model):
|
99 |
+
super().__init__()
|
100 |
+
self.inner_model = model
|
101 |
+
|
102 |
+
def forward(self, x, sigma, uncond, cond, cfg_scale):
|
103 |
+
x_in = torch.cat([x] * 2)
|
104 |
+
sigma_in = torch.cat([sigma] * 2)
|
105 |
+
cond_in = torch.cat([uncond, cond])
|
106 |
+
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
107 |
+
return uncond + (cond - uncond) * cfg_scale
|
108 |
+
|
109 |
+
|
110 |
+
def to_pil(image: torch.Tensor) -> Image.Image:
|
111 |
+
image = 255.0 * rearrange(image.cpu().numpy(), "c h w -> h w c")
|
112 |
+
image = Image.fromarray(image.astype(np.uint8))
|
113 |
+
return image
|
114 |
+
|
115 |
+
|
116 |
+
def main():
|
117 |
+
parser = argparse.ArgumentParser()
|
118 |
+
parser.add_argument(
|
119 |
+
"--out_dir",
|
120 |
+
type=str,
|
121 |
+
required=True,
|
122 |
+
help="Path to output dataset directory.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--prompts_file",
|
126 |
+
type=str,
|
127 |
+
required=True,
|
128 |
+
help="Path to prompts .jsonl file.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--ckpt",
|
132 |
+
type=str,
|
133 |
+
default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
|
134 |
+
help="Path to stable diffusion checkpoint.",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--vae-ckpt",
|
138 |
+
type=str,
|
139 |
+
default="stable_diffusion/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
|
140 |
+
help="Path to vae checkpoint.",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--steps",
|
144 |
+
type=int,
|
145 |
+
default=100,
|
146 |
+
help="Number of sampling steps.",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--n-samples",
|
150 |
+
type=int,
|
151 |
+
default=100,
|
152 |
+
help="Number of samples to generate per prompt (before CLIP filtering).",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--max-out-samples",
|
156 |
+
type=int,
|
157 |
+
default=4,
|
158 |
+
help="Max number of output samples to save per prompt (after CLIP filtering).",
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--n-partitions",
|
162 |
+
type=int,
|
163 |
+
default=1,
|
164 |
+
help="Number of total partitions.",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--partition",
|
168 |
+
type=int,
|
169 |
+
default=0,
|
170 |
+
help="Partition index.",
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--min-p2p",
|
174 |
+
type=float,
|
175 |
+
default=0.1,
|
176 |
+
help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
|
177 |
+
)
|
178 |
+
parser.add_argument(
|
179 |
+
"--max-p2p",
|
180 |
+
type=float,
|
181 |
+
default=0.9,
|
182 |
+
help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--min-cfg",
|
186 |
+
type=float,
|
187 |
+
default=7.5,
|
188 |
+
help="Min classifier free guidance scale.",
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--max-cfg",
|
192 |
+
type=float,
|
193 |
+
default=15,
|
194 |
+
help="Max classifier free guidance scale.",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--clip-threshold",
|
198 |
+
type=float,
|
199 |
+
default=0.2,
|
200 |
+
help="CLIP threshold for text-image similarity of each image.",
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
"--clip-dir-threshold",
|
204 |
+
type=float,
|
205 |
+
default=0.2,
|
206 |
+
help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
|
207 |
+
)
|
208 |
+
parser.add_argument(
|
209 |
+
"--clip-img-threshold",
|
210 |
+
type=float,
|
211 |
+
default=0.7,
|
212 |
+
help="CLIP threshold for image-image similarity.",
|
213 |
+
)
|
214 |
+
opt = parser.parse_args()
|
215 |
+
|
216 |
+
global_seed = torch.randint(1 << 32, ()).item()
|
217 |
+
print(f"Global seed: {global_seed}")
|
218 |
+
seed_everything(global_seed)
|
219 |
+
|
220 |
+
model = load_model_from_config(
|
221 |
+
OmegaConf.load("stable_diffusion/configs/stable-diffusion/v1-inference.yaml"),
|
222 |
+
ckpt=opt.ckpt,
|
223 |
+
vae_ckpt=opt.vae_ckpt,
|
224 |
+
)
|
225 |
+
model.cuda().eval()
|
226 |
+
model_wrap = k_diffusion.external.CompVisDenoiser(model)
|
227 |
+
|
228 |
+
clip_similarity = ClipSimilarity().cuda()
|
229 |
+
|
230 |
+
out_dir = Path(opt.out_dir)
|
231 |
+
out_dir.mkdir(exist_ok=True, parents=True)
|
232 |
+
|
233 |
+
with open(opt.prompts_file) as fp:
|
234 |
+
prompts = [json.loads(line) for line in fp]
|
235 |
+
|
236 |
+
print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
|
237 |
+
prompts = np.array_split(list(enumerate(prompts)), opt.n_partitions)[opt.partition]
|
238 |
+
|
239 |
+
with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
|
240 |
+
uncond = model.get_learned_conditioning(2 * [""])
|
241 |
+
sigmas = model_wrap.get_sigmas(opt.steps)
|
242 |
+
|
243 |
+
for i, prompt in tqdm(prompts, desc="Prompts"):
|
244 |
+
prompt_dir = out_dir.joinpath(f"{i:07d}")
|
245 |
+
prompt_dir.mkdir(exist_ok=True)
|
246 |
+
|
247 |
+
with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
|
248 |
+
json.dump(prompt, fp)
|
249 |
+
|
250 |
+
cond = model.get_learned_conditioning([prompt["caption"], prompt["output"]])
|
251 |
+
results = {}
|
252 |
+
|
253 |
+
with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
|
254 |
+
|
255 |
+
while len(results) < opt.n_samples:
|
256 |
+
seed = torch.randint(1 << 32, ()).item()
|
257 |
+
if seed in results:
|
258 |
+
continue
|
259 |
+
torch.manual_seed(seed)
|
260 |
+
|
261 |
+
x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
|
262 |
+
x = repeat(x, "1 ... -> n ...", n=2)
|
263 |
+
|
264 |
+
model_wrap_cfg = CFGDenoiser(model_wrap)
|
265 |
+
p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
|
266 |
+
cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
|
267 |
+
extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
|
268 |
+
samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
|
269 |
+
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
270 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
271 |
+
|
272 |
+
x0 = x_samples_ddim[0]
|
273 |
+
x1 = x_samples_ddim[1]
|
274 |
+
|
275 |
+
clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
|
276 |
+
x0[None], x1[None], [prompt["caption"]], [prompt["output"]]
|
277 |
+
)
|
278 |
+
|
279 |
+
results[seed] = dict(
|
280 |
+
image_0=to_pil(x0),
|
281 |
+
image_1=to_pil(x1),
|
282 |
+
p2p_threshold=p2p_threshold,
|
283 |
+
cfg_scale=cfg_scale,
|
284 |
+
clip_sim_0=clip_sim_0[0].item(),
|
285 |
+
clip_sim_1=clip_sim_1[0].item(),
|
286 |
+
clip_sim_dir=clip_sim_dir[0].item(),
|
287 |
+
clip_sim_image=clip_sim_image[0].item(),
|
288 |
+
)
|
289 |
+
|
290 |
+
progress_bar.update()
|
291 |
+
|
292 |
+
# CLIP filter to get best samples for each prompt.
|
293 |
+
metadata = [
|
294 |
+
(result["clip_sim_dir"], seed)
|
295 |
+
for seed, result in results.items()
|
296 |
+
if result["clip_sim_image"] >= opt.clip_img_threshold
|
297 |
+
and result["clip_sim_dir"] >= opt.clip_dir_threshold
|
298 |
+
and result["clip_sim_0"] >= opt.clip_threshold
|
299 |
+
and result["clip_sim_1"] >= opt.clip_threshold
|
300 |
+
]
|
301 |
+
metadata.sort(reverse=True)
|
302 |
+
for _, seed in metadata[: opt.max_out_samples]:
|
303 |
+
result = results[seed]
|
304 |
+
image_0 = result.pop("image_0")
|
305 |
+
image_1 = result.pop("image_1")
|
306 |
+
image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
|
307 |
+
image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
|
308 |
+
with open(prompt_dir.joinpath(f"metadata.jsonl"), "a") as fp:
|
309 |
+
fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")
|
310 |
+
|
311 |
+
print("Done.")
|
312 |
+
|
313 |
+
|
314 |
+
if __name__ == "__main__":
|
315 |
+
main()
|
dataset_creation/generate_txt_dataset.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import datasets
|
10 |
+
import numpy as np
|
11 |
+
import openai
|
12 |
+
from tqdm.auto import tqdm
|
13 |
+
|
14 |
+
|
15 |
+
DELIMITER_0 = "\n##\n"
|
16 |
+
DELIMITER_1 = "\n%%\n"
|
17 |
+
STOP = "\nEND"
|
18 |
+
|
19 |
+
|
20 |
+
def generate(
|
21 |
+
openai_model: str,
|
22 |
+
caption: str,
|
23 |
+
num_retries: int = 3,
|
24 |
+
max_tokens: int = 256,
|
25 |
+
temperature: float = 0.7,
|
26 |
+
top_p: float = 1.0,
|
27 |
+
frequency_penalty: float = 0.1,
|
28 |
+
presence_penalty: float = 0.0,
|
29 |
+
sleep_on_error: float = 1.0,
|
30 |
+
) -> Optional[tuple[str, str]]:
|
31 |
+
for _ in range(1 + num_retries):
|
32 |
+
try:
|
33 |
+
response = openai.Completion.create(
|
34 |
+
model=openai_model,
|
35 |
+
prompt=caption + DELIMITER_0,
|
36 |
+
temperature=temperature,
|
37 |
+
max_tokens=max_tokens,
|
38 |
+
top_p=top_p,
|
39 |
+
frequency_penalty=frequency_penalty,
|
40 |
+
presence_penalty=presence_penalty,
|
41 |
+
stop=[STOP],
|
42 |
+
)
|
43 |
+
except Exception as e:
|
44 |
+
print(e)
|
45 |
+
time.sleep(sleep_on_error)
|
46 |
+
continue
|
47 |
+
output = response["choices"][0]["text"].split(DELIMITER_1)
|
48 |
+
if len(output) == 2:
|
49 |
+
instruction, edited_caption = output
|
50 |
+
results = openai.Moderation.create([instruction, edited_caption])["results"]
|
51 |
+
if results[0]["flagged"] or results[1]["flagged"]:
|
52 |
+
continue
|
53 |
+
if caption.strip().strip(".!?").lower() != edited_caption.strip().strip(".!?").lower():
|
54 |
+
return instruction, edited_caption
|
55 |
+
|
56 |
+
|
57 |
+
def main(openai_model: str, num_samples: int, num_partitions: int, partition: int, seed: int):
|
58 |
+
dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
|
59 |
+
# Other datasets we considered that may be worth trying:
|
60 |
+
# dataset = datasets.load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="train")
|
61 |
+
# dataset = datasets.load_dataset("laion/laion-coco", split="train")
|
62 |
+
|
63 |
+
np.random.seed(seed)
|
64 |
+
permutation = np.array_split(np.random.permutation(len(dataset)), num_partitions)[partition]
|
65 |
+
dataset = dataset[permutation]
|
66 |
+
captions = dataset["TEXT"]
|
67 |
+
urls = dataset["URL"]
|
68 |
+
output_path = f"data/dataset=laion-aesthetics-6.5_model={openai_model}_samples={num_samples}_partition={partition}.jsonl" # fmt: skip
|
69 |
+
print(f"Prompt file path: {output_path}")
|
70 |
+
|
71 |
+
count = 0
|
72 |
+
caption_set = set()
|
73 |
+
url_set = set()
|
74 |
+
|
75 |
+
if Path(output_path).exists():
|
76 |
+
with open(output_path, "r") as f:
|
77 |
+
for line in tqdm(f, desc="Resuming from existing prompts"):
|
78 |
+
prompt = json.loads(line)
|
79 |
+
if prompt["caption"] not in caption_set and prompt["url"] not in url_set:
|
80 |
+
caption_set.add(prompt["caption"])
|
81 |
+
url_set.add(prompt["url"])
|
82 |
+
count += 1
|
83 |
+
|
84 |
+
with open(output_path, "a") as fp:
|
85 |
+
with tqdm(total=num_samples - count, desc="Generating instructions and edited captions") as progress_bar:
|
86 |
+
for caption, url in zip(captions, urls):
|
87 |
+
if caption in caption_set or url in url_set:
|
88 |
+
continue
|
89 |
+
if openai.Moderation.create(caption)["results"][0]["flagged"]:
|
90 |
+
continue
|
91 |
+
edit_output = generate(openai_model, caption)
|
92 |
+
if edit_output is not None:
|
93 |
+
edit, output = edit_output
|
94 |
+
fp.write(f"{json.dumps(dict(caption=caption, edit=edit, output=output, url=url))}\n")
|
95 |
+
count += 1
|
96 |
+
progress_bar.update()
|
97 |
+
caption_set.add(caption)
|
98 |
+
url_set.add(url)
|
99 |
+
if count == num_samples:
|
100 |
+
break
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
parser = ArgumentParser()
|
105 |
+
parser.add_argument("--openai-api-key", required=True, type=str)
|
106 |
+
parser.add_argument("--openai-model", required=True, type=str)
|
107 |
+
parser.add_argument("--num-samples", default=10000, type=int)
|
108 |
+
parser.add_argument("--num-partitions", default=1, type=int)
|
109 |
+
parser.add_argument("--partition", default=0, type=int)
|
110 |
+
parser.add_argument("--seed", default=0, type=int)
|
111 |
+
args = parser.parse_args()
|
112 |
+
openai.api_key = args.openai_api_key
|
113 |
+
main(args.openai_model, args.num_samples, args.num_partitions, args.partition, args.seed)
|
dataset_creation/prepare_dataset.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
def main():
|
9 |
+
parser = ArgumentParser()
|
10 |
+
parser.add_argument("dataset_dir")
|
11 |
+
args = parser.parse_args()
|
12 |
+
dataset_dir = Path(args.dataset_dir)
|
13 |
+
|
14 |
+
seeds = []
|
15 |
+
with tqdm(desc="Listing dataset image seeds") as progress_bar:
|
16 |
+
for prompt_dir in dataset_dir.iterdir():
|
17 |
+
if prompt_dir.is_dir():
|
18 |
+
prompt_seeds = [image_path.name.split("_")[0] for image_path in sorted(prompt_dir.glob("*_0.jpg"))]
|
19 |
+
if len(prompt_seeds) > 0:
|
20 |
+
seeds.append((prompt_dir.name, prompt_seeds))
|
21 |
+
progress_bar.update()
|
22 |
+
seeds.sort()
|
23 |
+
|
24 |
+
with open(dataset_dir.joinpath("seeds.json"), "w") as f:
|
25 |
+
json.dump(seeds, f)
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
main()
|
dataset_creation/prepare_for_gpt.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
|
4 |
+
from generate_txt_dataset import DELIMITER_0, DELIMITER_1, STOP
|
5 |
+
|
6 |
+
|
7 |
+
def main(input_path: str, output_path: str):
|
8 |
+
with open(input_path) as f:
|
9 |
+
prompts = [json.loads(l) for l in f]
|
10 |
+
|
11 |
+
with open(output_path, "w") as f:
|
12 |
+
for prompt in prompts:
|
13 |
+
prompt_for_gpt = {
|
14 |
+
"prompt": f"{prompt['input']}{DELIMITER_0}",
|
15 |
+
"completion": f"{prompt['edit']}{DELIMITER_1}{prompt['output']}{STOP}",
|
16 |
+
}
|
17 |
+
f.write(f"{json.dumps(prompt_for_gpt)}\n")
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
parser = ArgumentParser()
|
22 |
+
parser.add_argument("--input-path", required=True, type=str)
|
23 |
+
parser.add_argument("--output-path", required=True, type=str)
|
24 |
+
args = parser.parse_args()
|
25 |
+
main(args.input_path, args.output_path)
|
edit_app.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
import einops
|
9 |
+
import gradio as gr
|
10 |
+
import k_diffusion as K
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from einops import rearrange
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from PIL import Image, ImageOps
|
17 |
+
from torch import autocast
|
18 |
+
|
19 |
+
sys.path.append("./stable_diffusion")
|
20 |
+
|
21 |
+
from stable_diffusion.ldm.util import instantiate_from_config
|
22 |
+
|
23 |
+
|
24 |
+
help_text = """
|
25 |
+
If you're not getting what you want, there may be a few reasons:
|
26 |
+
1. Is the image not changing enough? Your Image CFG weight may be too high. This value dictates how similar the output should be to the input. It's possible your edit requires larger changes from the original image, and your Image CFG weight isn't allowing that. Alternatively, your Text CFG weight may be too low. This value dictates how much to listen to the text instruction. The default Image CFG of 1.5 and Text CFG of 7.5 are a good starting point, but aren't necessarily optimal for each edit. Try:
|
27 |
+
* Decreasing the Image CFG weight, or
|
28 |
+
* Incerasing the Text CFG weight, or
|
29 |
+
2. Conversely, is the image changing too much, such that the details in the original image aren't preserved? Try:
|
30 |
+
* Increasing the Image CFG weight, or
|
31 |
+
* Decreasing the Text CFG weight
|
32 |
+
3. Try generating results with different random seeds by setting "Randomize Seed" and running generation multiple times. You can also try setting "Randomize CFG" to sample new Text CFG and Image CFG values each time.
|
33 |
+
4. Rephrasing the instruction sometimes improves results (e.g., "turn him into a dog" vs. "make him a dog" vs. "as a dog").
|
34 |
+
5. Increasing the number of steps sometimes improves results.
|
35 |
+
6. Do faces look weird? The Stable Diffusion autoencoder has a hard time with faces that are small in the image. Try:
|
36 |
+
* Cropping the image so the face takes up a larger portion of the frame.
|
37 |
+
"""
|
38 |
+
|
39 |
+
|
40 |
+
example_instructions = [
|
41 |
+
"Make it a picasso painting",
|
42 |
+
"as if it were by modigliani",
|
43 |
+
"convert to a bronze statue",
|
44 |
+
"Turn it into an anime.",
|
45 |
+
"have it look like a graphic novel",
|
46 |
+
"make him gain weight",
|
47 |
+
"what would he look like bald?",
|
48 |
+
"Have him smile",
|
49 |
+
"Put him in a cocktail party.",
|
50 |
+
"move him at the beach.",
|
51 |
+
"add dramatic lighting",
|
52 |
+
"Convert to black and white",
|
53 |
+
"What if it were snowing?",
|
54 |
+
"Give him a leather jacket",
|
55 |
+
"Turn him into a cyborg!",
|
56 |
+
"make him wear a beanie",
|
57 |
+
]
|
58 |
+
|
59 |
+
|
60 |
+
class CFGDenoiser(nn.Module):
|
61 |
+
def __init__(self, model):
|
62 |
+
super().__init__()
|
63 |
+
self.inner_model = model
|
64 |
+
|
65 |
+
def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
|
66 |
+
cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
|
67 |
+
cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
|
68 |
+
cfg_cond = {
|
69 |
+
"c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
|
70 |
+
"c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
|
71 |
+
}
|
72 |
+
out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
|
73 |
+
return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
|
74 |
+
|
75 |
+
|
76 |
+
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
|
77 |
+
print(f"Loading model from {ckpt}")
|
78 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
79 |
+
if "global_step" in pl_sd:
|
80 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
81 |
+
sd = pl_sd["state_dict"]
|
82 |
+
if vae_ckpt is not None:
|
83 |
+
print(f"Loading VAE from {vae_ckpt}")
|
84 |
+
vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
|
85 |
+
sd = {
|
86 |
+
k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
|
87 |
+
for k, v in sd.items()
|
88 |
+
}
|
89 |
+
model = instantiate_from_config(config.model)
|
90 |
+
m, u = model.load_state_dict(sd, strict=False)
|
91 |
+
if len(m) > 0 and verbose:
|
92 |
+
print("missing keys:")
|
93 |
+
print(m)
|
94 |
+
if len(u) > 0 and verbose:
|
95 |
+
print("unexpected keys:")
|
96 |
+
print(u)
|
97 |
+
return model
|
98 |
+
|
99 |
+
|
100 |
+
def main():
|
101 |
+
parser = ArgumentParser()
|
102 |
+
parser.add_argument("--resolution", default=512, type=int)
|
103 |
+
parser.add_argument("--config", default="configs/generate.yaml", type=str)
|
104 |
+
parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
|
105 |
+
parser.add_argument("--vae-ckpt", default=None, type=str)
|
106 |
+
args = parser.parse_args()
|
107 |
+
|
108 |
+
config = OmegaConf.load(args.config)
|
109 |
+
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
|
110 |
+
model.eval().cuda()
|
111 |
+
model_wrap = K.external.CompVisDenoiser(model)
|
112 |
+
model_wrap_cfg = CFGDenoiser(model_wrap)
|
113 |
+
null_token = model.get_learned_conditioning([""])
|
114 |
+
example_image = Image.open("imgs/example.jpg").convert("RGB")
|
115 |
+
|
116 |
+
def load_example(
|
117 |
+
steps: int,
|
118 |
+
randomize_seed: bool,
|
119 |
+
seed: int,
|
120 |
+
randomize_cfg: bool,
|
121 |
+
text_cfg_scale: float,
|
122 |
+
image_cfg_scale: float,
|
123 |
+
):
|
124 |
+
example_instruction = random.choice(example_instructions)
|
125 |
+
return [example_image, example_instruction] + generate(
|
126 |
+
example_image,
|
127 |
+
example_instruction,
|
128 |
+
steps,
|
129 |
+
randomize_seed,
|
130 |
+
seed,
|
131 |
+
randomize_cfg,
|
132 |
+
text_cfg_scale,
|
133 |
+
image_cfg_scale,
|
134 |
+
)
|
135 |
+
|
136 |
+
def generate(
|
137 |
+
input_image: Image.Image,
|
138 |
+
instruction: str,
|
139 |
+
steps: int,
|
140 |
+
randomize_seed: bool,
|
141 |
+
seed: int,
|
142 |
+
randomize_cfg: bool,
|
143 |
+
text_cfg_scale: float,
|
144 |
+
image_cfg_scale: float,
|
145 |
+
):
|
146 |
+
seed = random.randint(0, 100000) if randomize_seed else seed
|
147 |
+
text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
|
148 |
+
image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale
|
149 |
+
|
150 |
+
width, height = input_image.size
|
151 |
+
factor = args.resolution / max(width, height)
|
152 |
+
factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
|
153 |
+
width = int((width * factor) // 64) * 64
|
154 |
+
height = int((height * factor) // 64) * 64
|
155 |
+
input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
156 |
+
|
157 |
+
if instruction == "":
|
158 |
+
return [input_image, seed]
|
159 |
+
|
160 |
+
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
161 |
+
cond = {}
|
162 |
+
cond["c_crossattn"] = [model.get_learned_conditioning([instruction])]
|
163 |
+
input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
|
164 |
+
input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
|
165 |
+
cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
|
166 |
+
|
167 |
+
uncond = {}
|
168 |
+
uncond["c_crossattn"] = [null_token]
|
169 |
+
uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
|
170 |
+
|
171 |
+
sigmas = model_wrap.get_sigmas(steps)
|
172 |
+
|
173 |
+
extra_args = {
|
174 |
+
"cond": cond,
|
175 |
+
"uncond": uncond,
|
176 |
+
"text_cfg_scale": text_cfg_scale,
|
177 |
+
"image_cfg_scale": image_cfg_scale,
|
178 |
+
}
|
179 |
+
torch.manual_seed(seed)
|
180 |
+
z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
|
181 |
+
z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
|
182 |
+
x = model.decode_first_stage(z)
|
183 |
+
x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
|
184 |
+
x = 255.0 * rearrange(x, "1 c h w -> h w c")
|
185 |
+
edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
|
186 |
+
|
187 |
+
return [seed, text_cfg_scale, image_cfg_scale, edited_image]
|
188 |
+
|
189 |
+
def reset():
|
190 |
+
return [0, "Randomize Seed", 1371, "Fix CFG", 7.5, 1.5, None]
|
191 |
+
|
192 |
+
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
193 |
+
with gr.Row():
|
194 |
+
with gr.Column(scale=1, min_width=100):
|
195 |
+
generate_button = gr.Button("Generate")
|
196 |
+
with gr.Column(scale=1, min_width=100):
|
197 |
+
load_button = gr.Button("Load Example")
|
198 |
+
with gr.Column(scale=1, min_width=100):
|
199 |
+
reset_button = gr.Button("Reset")
|
200 |
+
with gr.Column(scale=3):
|
201 |
+
instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
|
202 |
+
|
203 |
+
with gr.Row():
|
204 |
+
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
|
205 |
+
edited_image = gr.Image(label=f"Edited Image", type="pil", interactive=False)
|
206 |
+
input_image.style(height=512, width=512)
|
207 |
+
edited_image.style(height=512, width=512)
|
208 |
+
|
209 |
+
with gr.Row():
|
210 |
+
steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
|
211 |
+
randomize_seed = gr.Radio(
|
212 |
+
["Fix Seed", "Randomize Seed"],
|
213 |
+
value="Randomize Seed",
|
214 |
+
type="index",
|
215 |
+
show_label=False,
|
216 |
+
interactive=True,
|
217 |
+
)
|
218 |
+
seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
|
219 |
+
randomize_cfg = gr.Radio(
|
220 |
+
["Fix CFG", "Randomize CFG"],
|
221 |
+
value="Fix CFG",
|
222 |
+
type="index",
|
223 |
+
show_label=False,
|
224 |
+
interactive=True,
|
225 |
+
)
|
226 |
+
text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
|
227 |
+
image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
|
228 |
+
|
229 |
+
gr.Markdown(help_text)
|
230 |
+
|
231 |
+
load_button.click(
|
232 |
+
fn=load_example,
|
233 |
+
inputs=[
|
234 |
+
steps,
|
235 |
+
randomize_seed,
|
236 |
+
seed,
|
237 |
+
randomize_cfg,
|
238 |
+
text_cfg_scale,
|
239 |
+
image_cfg_scale,
|
240 |
+
],
|
241 |
+
outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
|
242 |
+
)
|
243 |
+
generate_button.click(
|
244 |
+
fn=generate,
|
245 |
+
inputs=[
|
246 |
+
input_image,
|
247 |
+
instruction,
|
248 |
+
steps,
|
249 |
+
randomize_seed,
|
250 |
+
seed,
|
251 |
+
randomize_cfg,
|
252 |
+
text_cfg_scale,
|
253 |
+
image_cfg_scale,
|
254 |
+
],
|
255 |
+
outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
|
256 |
+
)
|
257 |
+
reset_button.click(
|
258 |
+
fn=reset,
|
259 |
+
inputs=[],
|
260 |
+
outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image],
|
261 |
+
)
|
262 |
+
|
263 |
+
demo.queue(concurrency_count=1)
|
264 |
+
demo.launch(share=True)
|
265 |
+
|
266 |
+
|
267 |
+
if __name__ == "__main__":
|
268 |
+
main()
|
edit_cli.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
import einops
|
9 |
+
import k_diffusion as K
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from einops import rearrange
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from PIL import Image, ImageOps
|
16 |
+
from torch import autocast
|
17 |
+
|
18 |
+
sys.path.append("./stable_diffusion")
|
19 |
+
|
20 |
+
from stable_diffusion.ldm.util import instantiate_from_config
|
21 |
+
|
22 |
+
|
23 |
+
class CFGDenoiser(nn.Module):
|
24 |
+
def __init__(self, model):
|
25 |
+
super().__init__()
|
26 |
+
self.inner_model = model
|
27 |
+
|
28 |
+
def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
|
29 |
+
cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
|
30 |
+
cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
|
31 |
+
cfg_cond = {
|
32 |
+
"c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
|
33 |
+
"c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
|
34 |
+
}
|
35 |
+
out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
|
36 |
+
return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
|
37 |
+
|
38 |
+
|
39 |
+
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
|
40 |
+
print(f"Loading model from {ckpt}")
|
41 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
42 |
+
if "global_step" in pl_sd:
|
43 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
44 |
+
sd = pl_sd["state_dict"]
|
45 |
+
if vae_ckpt is not None:
|
46 |
+
print(f"Loading VAE from {vae_ckpt}")
|
47 |
+
vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
|
48 |
+
sd = {
|
49 |
+
k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
|
50 |
+
for k, v in sd.items()
|
51 |
+
}
|
52 |
+
model = instantiate_from_config(config.model)
|
53 |
+
m, u = model.load_state_dict(sd, strict=False)
|
54 |
+
if len(m) > 0 and verbose:
|
55 |
+
print("missing keys:")
|
56 |
+
print(m)
|
57 |
+
if len(u) > 0 and verbose:
|
58 |
+
print("unexpected keys:")
|
59 |
+
print(u)
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
def main():
|
64 |
+
parser = ArgumentParser()
|
65 |
+
parser.add_argument("--resolution", default=512, type=int)
|
66 |
+
parser.add_argument("--steps", default=100, type=int)
|
67 |
+
parser.add_argument("--config", default="configs/generate.yaml", type=str)
|
68 |
+
parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
|
69 |
+
parser.add_argument("--vae-ckpt", default=None, type=str)
|
70 |
+
parser.add_argument("--input", required=True, type=str)
|
71 |
+
parser.add_argument("--output", required=True, type=str)
|
72 |
+
parser.add_argument("--edit", required=True, type=str)
|
73 |
+
parser.add_argument("--cfg-text", default=7.5, type=float)
|
74 |
+
parser.add_argument("--cfg-image", default=1.5, type=float)
|
75 |
+
parser.add_argument("--seed", type=int)
|
76 |
+
args = parser.parse_args()
|
77 |
+
|
78 |
+
config = OmegaConf.load(args.config)
|
79 |
+
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
|
80 |
+
model.eval().cuda()
|
81 |
+
model_wrap = K.external.CompVisDenoiser(model)
|
82 |
+
model_wrap_cfg = CFGDenoiser(model_wrap)
|
83 |
+
null_token = model.get_learned_conditioning([""])
|
84 |
+
|
85 |
+
seed = random.randint(0, 100000) if args.seed is None else args.seed
|
86 |
+
input_image = Image.open(args.input).convert("RGB")
|
87 |
+
width, height = input_image.size
|
88 |
+
factor = args.resolution / max(width, height)
|
89 |
+
factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
|
90 |
+
width = int((width * factor) // 64) * 64
|
91 |
+
height = int((height * factor) // 64) * 64
|
92 |
+
input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
93 |
+
|
94 |
+
if args.edit == "":
|
95 |
+
input_image.save(args.output)
|
96 |
+
return
|
97 |
+
|
98 |
+
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
99 |
+
cond = {}
|
100 |
+
cond["c_crossattn"] = [model.get_learned_conditioning([args.edit])]
|
101 |
+
input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
|
102 |
+
input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
|
103 |
+
cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
|
104 |
+
|
105 |
+
uncond = {}
|
106 |
+
uncond["c_crossattn"] = [null_token]
|
107 |
+
uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
|
108 |
+
|
109 |
+
sigmas = model_wrap.get_sigmas(args.steps)
|
110 |
+
|
111 |
+
extra_args = {
|
112 |
+
"cond": cond,
|
113 |
+
"uncond": uncond,
|
114 |
+
"text_cfg_scale": args.cfg_text,
|
115 |
+
"image_cfg_scale": args.cfg_image,
|
116 |
+
}
|
117 |
+
torch.manual_seed(seed)
|
118 |
+
z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
|
119 |
+
z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
|
120 |
+
x = model.decode_first_stage(z)
|
121 |
+
x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
|
122 |
+
x = 255.0 * rearrange(x, "1 c h w -> h w c")
|
123 |
+
edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
|
124 |
+
edited_image.save(args.output)
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
main()
|
edit_dataset.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torchvision
|
11 |
+
from einops import rearrange
|
12 |
+
from PIL import Image
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
|
15 |
+
|
16 |
+
class EditDataset(Dataset):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
path: str,
|
20 |
+
split: str = "train",
|
21 |
+
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
|
22 |
+
min_resize_res: int = 256,
|
23 |
+
max_resize_res: int = 256,
|
24 |
+
crop_res: int = 256,
|
25 |
+
flip_prob: float = 0.0,
|
26 |
+
):
|
27 |
+
assert split in ("train", "val", "test")
|
28 |
+
assert sum(splits) == 1
|
29 |
+
self.path = path
|
30 |
+
self.min_resize_res = min_resize_res
|
31 |
+
self.max_resize_res = max_resize_res
|
32 |
+
self.crop_res = crop_res
|
33 |
+
self.flip_prob = flip_prob
|
34 |
+
|
35 |
+
with open(Path(self.path, "seeds.json")) as f:
|
36 |
+
self.seeds = json.load(f)
|
37 |
+
|
38 |
+
split_0, split_1 = {
|
39 |
+
"train": (0.0, splits[0]),
|
40 |
+
"val": (splits[0], splits[0] + splits[1]),
|
41 |
+
"test": (splits[0] + splits[1], 1.0),
|
42 |
+
}[split]
|
43 |
+
|
44 |
+
idx_0 = math.floor(split_0 * len(self.seeds))
|
45 |
+
idx_1 = math.floor(split_1 * len(self.seeds))
|
46 |
+
self.seeds = self.seeds[idx_0:idx_1]
|
47 |
+
|
48 |
+
def __len__(self) -> int:
|
49 |
+
return len(self.seeds)
|
50 |
+
|
51 |
+
def __getitem__(self, i: int) -> dict[str, Any]:
|
52 |
+
name, seeds = self.seeds[i]
|
53 |
+
propt_dir = Path(self.path, name)
|
54 |
+
seed = seeds[torch.randint(0, len(seeds), ()).item()]
|
55 |
+
with open(propt_dir.joinpath("prompt.json")) as fp:
|
56 |
+
prompt = json.load(fp)["edit"]
|
57 |
+
|
58 |
+
image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
|
59 |
+
image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
|
60 |
+
|
61 |
+
reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
|
62 |
+
image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS)
|
63 |
+
image_1 = image_1.resize((reize_res, reize_res), Image.Resampling.LANCZOS)
|
64 |
+
|
65 |
+
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
|
66 |
+
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
|
67 |
+
|
68 |
+
crop = torchvision.transforms.RandomCrop(self.crop_res)
|
69 |
+
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
|
70 |
+
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
|
71 |
+
|
72 |
+
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
|
73 |
+
|
74 |
+
|
75 |
+
class EditDatasetEval(Dataset):
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
path: str,
|
79 |
+
split: str = "train",
|
80 |
+
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
|
81 |
+
res: int = 256,
|
82 |
+
):
|
83 |
+
assert split in ("train", "val", "test")
|
84 |
+
assert sum(splits) == 1
|
85 |
+
self.path = path
|
86 |
+
self.res = res
|
87 |
+
|
88 |
+
with open(Path(self.path, "seeds.json")) as f:
|
89 |
+
self.seeds = json.load(f)
|
90 |
+
|
91 |
+
split_0, split_1 = {
|
92 |
+
"train": (0.0, splits[0]),
|
93 |
+
"val": (splits[0], splits[0] + splits[1]),
|
94 |
+
"test": (splits[0] + splits[1], 1.0),
|
95 |
+
}[split]
|
96 |
+
|
97 |
+
idx_0 = math.floor(split_0 * len(self.seeds))
|
98 |
+
idx_1 = math.floor(split_1 * len(self.seeds))
|
99 |
+
self.seeds = self.seeds[idx_0:idx_1]
|
100 |
+
|
101 |
+
def __len__(self) -> int:
|
102 |
+
return len(self.seeds)
|
103 |
+
|
104 |
+
def __getitem__(self, i: int) -> dict[str, Any]:
|
105 |
+
name, seeds = self.seeds[i]
|
106 |
+
propt_dir = Path(self.path, name)
|
107 |
+
seed = seeds[torch.randint(0, len(seeds), ()).item()]
|
108 |
+
with open(propt_dir.joinpath("prompt.json")) as fp:
|
109 |
+
prompt = json.load(fp)
|
110 |
+
edit = prompt["edit"]
|
111 |
+
input_prompt = prompt["input"]
|
112 |
+
output_prompt = prompt["output"]
|
113 |
+
|
114 |
+
image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
|
115 |
+
|
116 |
+
reize_res = torch.randint(self.res, self.res + 1, ()).item()
|
117 |
+
image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS)
|
118 |
+
|
119 |
+
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
|
120 |
+
|
121 |
+
return dict(image_0=image_0, input_prompt=input_prompt, edit=edit, output_prompt=output_prompt)
|
environment.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
name: ip2p
|
5 |
+
channels:
|
6 |
+
- pytorch
|
7 |
+
- defaults
|
8 |
+
dependencies:
|
9 |
+
- python=3.8.5
|
10 |
+
- pip=20.3
|
11 |
+
- cudatoolkit=11.3
|
12 |
+
- pytorch=1.11.0
|
13 |
+
- torchvision=0.12.0
|
14 |
+
- numpy=1.19.2
|
15 |
+
- pip:
|
16 |
+
- albumentations==0.4.3
|
17 |
+
- datasets==2.8.0
|
18 |
+
- diffusers
|
19 |
+
- opencv-python==4.1.2.30
|
20 |
+
- pudb==2019.2
|
21 |
+
- invisible-watermark
|
22 |
+
- imageio==2.9.0
|
23 |
+
- imageio-ffmpeg==0.4.2
|
24 |
+
- pytorch-lightning==1.4.2
|
25 |
+
- omegaconf==2.1.1
|
26 |
+
- test-tube>=0.7.5
|
27 |
+
- streamlit>=0.73.1
|
28 |
+
- einops==0.3.0
|
29 |
+
- torch-fidelity==0.3.0
|
30 |
+
- transformers==4.19.2
|
31 |
+
- torchmetrics==0.6.0
|
32 |
+
- kornia==0.6
|
33 |
+
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
34 |
+
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
35 |
+
- openai
|
36 |
+
- gradio
|
37 |
+
- seaborn
|
38 |
+
- git+https://github.com/crowsonkb/k-diffusion.git
|
imgs/dataset.jpg
ADDED
![]() |
imgs/edit_app.jpg
ADDED
![]() |
imgs/example.jpg
ADDED
![]() |
imgs/prompt_app.jpg
ADDED
![]() |
logs/train_default/checkpoints/epoch=001542.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b042c2533765a40ab643d3d3ac227ae7e1f2f6c0eef0f9e727157cb60992f94
|
3 |
+
size 14580612838
|
logs/train_default/checkpoints/last.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b042c2533765a40ab643d3d3ac227ae7e1f2f6c0eef0f9e727157cb60992f94
|
3 |
+
size 14580612838
|
logs/train_default/checkpoints/trainstep_checkpoints/epoch=000333-step=000000999.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c6fe7ae000456bbce01d14b00d4d540709a35f4cefdaf393840e99a36a72902a
|
3 |
+
size 7703925478
|
logs/train_default/checkpoints/trainstep_checkpoints/epoch=000666-step=000001999.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64e0b2e2e5962a48791dc80827e50385e9b3e956f6d5cf31c37542de53bbb36f
|
3 |
+
size 7703925478
|
logs/train_default/checkpoints/trainstep_checkpoints/epoch=000999-step=000002999.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed206b27463d5108bdce74ce9d81188d94398987ae23d1d78cbbbd4b4c9074f2
|
3 |
+
size 7703925478
|
logs/train_default/checkpoints/trainstep_checkpoints/epoch=001333-step=000003999.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d5823eee77baa0123a1be40932cc0bb7667d309e0ab4ca773445b7bd996ac33f
|
3 |
+
size 7703925478
|
logs/train_default/configs/2023-06-30T02-08-15-lightning.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lightning:
|
2 |
+
callbacks:
|
3 |
+
image_logger:
|
4 |
+
target: main.ImageLogger
|
5 |
+
params:
|
6 |
+
batch_frequency: 2000
|
7 |
+
max_images: 2
|
8 |
+
increase_log_steps: false
|
9 |
+
trainer:
|
10 |
+
max_epochs: 2000
|
11 |
+
benchmark: true
|
12 |
+
accumulate_grad_batches: 4
|
13 |
+
check_val_every_n_epoch: 4
|
14 |
+
accelerator: ddp
|
15 |
+
gpus: 0,1,2,3
|
logs/train_default/configs/2023-06-30T02-08-15-project.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 0.0001
|
3 |
+
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
4 |
+
params:
|
5 |
+
ckpt_path: /home/ugrad/epoch=000027.ckpt
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.012
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: edited
|
12 |
+
cond_stage_key: edit
|
13 |
+
image_size: 32
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false
|
16 |
+
conditioning_key: hybrid
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: true
|
20 |
+
load_ema: false
|
21 |
+
scheduler_config:
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps:
|
25 |
+
- 0
|
26 |
+
cycle_lengths:
|
27 |
+
- 10000000000000
|
28 |
+
f_start:
|
29 |
+
- 1.0e-06
|
30 |
+
f_max:
|
31 |
+
- 1.0
|
32 |
+
f_min:
|
33 |
+
- 1.0
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions:
|
42 |
+
- 4
|
43 |
+
- 2
|
44 |
+
- 1
|
45 |
+
num_res_blocks: 2
|
46 |
+
channel_mult:
|
47 |
+
- 1
|
48 |
+
- 2
|
49 |
+
- 4
|
50 |
+
- 4
|
51 |
+
num_heads: 8
|
52 |
+
use_spatial_transformer: true
|
53 |
+
transformer_depth: 1
|
54 |
+
context_dim: 768
|
55 |
+
use_checkpoint: true
|
56 |
+
legacy: false
|
57 |
+
first_stage_config:
|
58 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
59 |
+
params:
|
60 |
+
embed_dim: 4
|
61 |
+
monitor: val/rec_loss
|
62 |
+
ddconfig:
|
63 |
+
double_z: true
|
64 |
+
z_channels: 4
|
65 |
+
resolution: 256
|
66 |
+
in_channels: 3
|
67 |
+
out_ch: 3
|
68 |
+
ch: 128
|
69 |
+
ch_mult:
|
70 |
+
- 1
|
71 |
+
- 2
|
72 |
+
- 4
|
73 |
+
- 4
|
74 |
+
num_res_blocks: 2
|
75 |
+
attn_resolutions: []
|
76 |
+
dropout: 0.0
|
77 |
+
lossconfig:
|
78 |
+
target: torch.nn.Identity
|
79 |
+
cond_stage_config:
|
80 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
81 |
+
data:
|
82 |
+
target: main.DataModuleFromConfig
|
83 |
+
params:
|
84 |
+
batch_size: 32
|
85 |
+
num_workers: 2
|
86 |
+
train:
|
87 |
+
target: edit_dataset.EditDataset
|
88 |
+
params:
|
89 |
+
path: /home/ugrad/ip2pdata
|
90 |
+
split: train
|
91 |
+
min_resize_res: 256
|
92 |
+
max_resize_res: 256
|
93 |
+
crop_res: 256
|
94 |
+
flip_prob: 0.5
|
logs/train_default/configs/2023-06-30T02-17-16-lightning.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lightning:
|
2 |
+
callbacks:
|
3 |
+
image_logger:
|
4 |
+
target: main.ImageLogger
|
5 |
+
params:
|
6 |
+
batch_frequency: 2000
|
7 |
+
max_images: 2
|
8 |
+
increase_log_steps: false
|
9 |
+
trainer:
|
10 |
+
max_epochs: 2000
|
11 |
+
benchmark: true
|
12 |
+
accumulate_grad_batches: 4
|
13 |
+
check_val_every_n_epoch: 4
|
14 |
+
accelerator: ddp
|
15 |
+
gpus: 0,1,2,3
|
16 |
+
resume_from_checkpoint: logs/train_default/checkpoints/last.ckpt
|
logs/train_default/configs/2023-06-30T02-17-16-project.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 0.0001
|
3 |
+
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
4 |
+
params:
|
5 |
+
ckpt_path: /home/ugrad/epoch=000027.ckpt
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.012
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: edited
|
12 |
+
cond_stage_key: edit
|
13 |
+
image_size: 32
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false
|
16 |
+
conditioning_key: hybrid
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: true
|
20 |
+
load_ema: true
|
21 |
+
scheduler_config:
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps:
|
25 |
+
- 0
|
26 |
+
cycle_lengths:
|
27 |
+
- 10000000000000
|
28 |
+
f_start:
|
29 |
+
- 1.0e-06
|
30 |
+
f_max:
|
31 |
+
- 1.0
|
32 |
+
f_min:
|
33 |
+
- 1.0
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions:
|
42 |
+
- 4
|
43 |
+
- 2
|
44 |
+
- 1
|
45 |
+
num_res_blocks: 2
|
46 |
+
channel_mult:
|
47 |
+
- 1
|
48 |
+
- 2
|
49 |
+
- 4
|
50 |
+
- 4
|
51 |
+
num_heads: 8
|
52 |
+
use_spatial_transformer: true
|
53 |
+
transformer_depth: 1
|
54 |
+
context_dim: 768
|
55 |
+
use_checkpoint: true
|
56 |
+
legacy: false
|
57 |
+
first_stage_config:
|
58 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
59 |
+
params:
|
60 |
+
embed_dim: 4
|
61 |
+
monitor: val/rec_loss
|
62 |
+
ddconfig:
|
63 |
+
double_z: true
|
64 |
+
z_channels: 4
|
65 |
+
resolution: 256
|
66 |
+
in_channels: 3
|
67 |
+
out_ch: 3
|
68 |
+
ch: 128
|
69 |
+
ch_mult:
|
70 |
+
- 1
|
71 |
+
- 2
|
72 |
+
- 4
|
73 |
+
- 4
|
74 |
+
num_res_blocks: 2
|
75 |
+
attn_resolutions: []
|
76 |
+
dropout: 0.0
|
77 |
+
lossconfig:
|
78 |
+
target: torch.nn.Identity
|
79 |
+
cond_stage_config:
|
80 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
81 |
+
data:
|
82 |
+
target: main.DataModuleFromConfig
|
83 |
+
params:
|
84 |
+
batch_size: 32
|
85 |
+
num_workers: 2
|
86 |
+
train:
|
87 |
+
target: edit_dataset.EditDataset
|
88 |
+
params:
|
89 |
+
path: /home/ugrad/ip2pdata
|
90 |
+
split: train
|
91 |
+
min_resize_res: 256
|
92 |
+
max_resize_res: 256
|
93 |
+
crop_res: 256
|
94 |
+
flip_prob: 0.5
|
logs/train_default/configs/2023-06-30T05-33-22-lightning.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lightning:
|
2 |
+
callbacks:
|
3 |
+
image_logger:
|
4 |
+
target: main.ImageLogger
|
5 |
+
params:
|
6 |
+
batch_frequency: 2000
|
7 |
+
max_images: 2
|
8 |
+
increase_log_steps: false
|
9 |
+
trainer:
|
10 |
+
max_epochs: 2000
|
11 |
+
benchmark: true
|
12 |
+
accumulate_grad_batches: 4
|
13 |
+
check_val_every_n_epoch: 4
|
14 |
+
accelerator: ddp
|
15 |
+
gpus: 0,1,2,3
|
16 |
+
resume_from_checkpoint: logs/train_default/checkpoints/last.ckpt
|
logs/train_default/configs/2023-06-30T05-33-22-project.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 0.0001
|
3 |
+
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
4 |
+
params:
|
5 |
+
ckpt_path: /home/ugrad/epoch=000027.ckpt
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.012
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: edited
|
12 |
+
cond_stage_key: edit
|
13 |
+
image_size: 32
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false
|
16 |
+
conditioning_key: hybrid
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: true
|
20 |
+
load_ema: true
|
21 |
+
scheduler_config:
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps:
|
25 |
+
- 0
|
26 |
+
cycle_lengths:
|
27 |
+
- 10000000000000
|
28 |
+
f_start:
|
29 |
+
- 1.0e-06
|
30 |
+
f_max:
|
31 |
+
- 1.0
|
32 |
+
f_min:
|
33 |
+
- 1.0
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions:
|
42 |
+
- 4
|
43 |
+
- 2
|
44 |
+
- 1
|
45 |
+
num_res_blocks: 2
|
46 |
+
channel_mult:
|
47 |
+
- 1
|
48 |
+
- 2
|
49 |
+
- 4
|
50 |
+
- 4
|
51 |
+
num_heads: 8
|
52 |
+
use_spatial_transformer: true
|
53 |
+
transformer_depth: 1
|
54 |
+
context_dim: 768
|
55 |
+
use_checkpoint: true
|
56 |
+
legacy: false
|
57 |
+
first_stage_config:
|
58 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
59 |
+
params:
|
60 |
+
embed_dim: 4
|
61 |
+
monitor: val/rec_loss
|
62 |
+
ddconfig:
|
63 |
+
double_z: true
|
64 |
+
z_channels: 4
|
65 |
+
resolution: 256
|
66 |
+
in_channels: 3
|
67 |
+
out_ch: 3
|
68 |
+
ch: 128
|
69 |
+
ch_mult:
|
70 |
+
- 1
|
71 |
+
- 2
|
72 |
+
- 4
|
73 |
+
- 4
|
74 |
+
num_res_blocks: 2
|
75 |
+
attn_resolutions: []
|
76 |
+
dropout: 0.0
|
77 |
+
lossconfig:
|
78 |
+
target: torch.nn.Identity
|
79 |
+
cond_stage_config:
|
80 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
81 |
+
data:
|
82 |
+
target: main.DataModuleFromConfig
|
83 |
+
params:
|
84 |
+
batch_size: 32
|
85 |
+
num_workers: 2
|
86 |
+
train:
|
87 |
+
target: edit_dataset.EditDataset
|
88 |
+
params:
|
89 |
+
path: /home/ugrad/ip2pdata
|
90 |
+
split: train
|
91 |
+
min_resize_res: 256
|
92 |
+
max_resize_res: 256
|
93 |
+
crop_res: 256
|
94 |
+
flip_prob: 0.5
|
logs/train_default/configs/2023-07-03T07-00-36-lightning.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lightning:
|
2 |
+
callbacks:
|
3 |
+
image_logger:
|
4 |
+
target: main.ImageLogger
|
5 |
+
params:
|
6 |
+
batch_frequency: 2000
|
7 |
+
max_images: 2
|
8 |
+
increase_log_steps: false
|
9 |
+
trainer:
|
10 |
+
max_epochs: 2000
|
11 |
+
benchmark: true
|
12 |
+
accumulate_grad_batches: 4
|
13 |
+
check_val_every_n_epoch: 4
|
14 |
+
accelerator: ddp
|
15 |
+
gpus: 0,1,2,3
|
logs/train_default/configs/2023-07-03T07-00-36-project.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 0.0001
|
3 |
+
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
4 |
+
params:
|
5 |
+
ckpt_path: /home/ugrad/epoch=000027.ckpt
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.012
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: edited
|
12 |
+
cond_stage_key: edit
|
13 |
+
image_size: 32
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false
|
16 |
+
conditioning_key: hybrid
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: true
|
20 |
+
load_ema: false
|
21 |
+
scheduler_config:
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps:
|
25 |
+
- 0
|
26 |
+
cycle_lengths:
|
27 |
+
- 10000000000000
|
28 |
+
f_start:
|
29 |
+
- 1.0e-06
|
30 |
+
f_max:
|
31 |
+
- 1.0
|
32 |
+
f_min:
|
33 |
+
- 1.0
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions:
|
42 |
+
- 4
|
43 |
+
- 2
|
44 |
+
- 1
|
45 |
+
num_res_blocks: 2
|
46 |
+
channel_mult:
|
47 |
+
- 1
|
48 |
+
- 2
|
49 |
+
- 4
|
50 |
+
- 4
|
51 |
+
num_heads: 8
|
52 |
+
use_spatial_transformer: true
|
53 |
+
transformer_depth: 1
|
54 |
+
context_dim: 768
|
55 |
+
use_checkpoint: true
|
56 |
+
legacy: false
|
57 |
+
first_stage_config:
|
58 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
59 |
+
params:
|
60 |
+
embed_dim: 4
|
61 |
+
monitor: val/rec_loss
|
62 |
+
ddconfig:
|
63 |
+
double_z: true
|
64 |
+
z_channels: 4
|
65 |
+
resolution: 256
|
66 |
+
in_channels: 3
|
67 |
+
out_ch: 3
|
68 |
+
ch: 128
|
69 |
+
ch_mult:
|
70 |
+
- 1
|
71 |
+
- 2
|
72 |
+
- 4
|
73 |
+
- 4
|
74 |
+
num_res_blocks: 2
|
75 |
+
attn_resolutions: []
|
76 |
+
dropout: 0.0
|
77 |
+
lossconfig:
|
78 |
+
target: torch.nn.Identity
|
79 |
+
cond_stage_config:
|
80 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
81 |
+
data:
|
82 |
+
target: main.DataModuleFromConfig
|
83 |
+
params:
|
84 |
+
batch_size: 32
|
85 |
+
num_workers: 2
|
86 |
+
train:
|
87 |
+
target: edit_dataset.EditDataset
|
88 |
+
params:
|
89 |
+
path: /home/ugrad/ip2pdata
|
90 |
+
split: train
|
91 |
+
min_resize_res: 256
|
92 |
+
max_resize_res: 256
|
93 |
+
crop_res: 256
|
94 |
+
flip_prob: 0.5
|
logs/train_default/configs/2023-07-03T07-11-08-lightning.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lightning:
|
2 |
+
callbacks:
|
3 |
+
image_logger:
|
4 |
+
target: main.ImageLogger
|
5 |
+
params:
|
6 |
+
batch_frequency: 2000
|
7 |
+
max_images: 2
|
8 |
+
increase_log_steps: false
|
9 |
+
trainer:
|
10 |
+
max_epochs: 2000
|
11 |
+
benchmark: true
|
12 |
+
accumulate_grad_batches: 4
|
13 |
+
check_val_every_n_epoch: 4
|
14 |
+
accelerator: ddp
|
15 |
+
gpus: 0,1,2,3
|
logs/train_default/configs/2023-07-03T07-11-08-project.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 0.0001
|
3 |
+
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
4 |
+
params:
|
5 |
+
ckpt_path: /home/ugrad/epoch=000027.ckpt
|
6 |
+
linear_start: 0.00085
|
7 |
+
linear_end: 0.012
|
8 |
+
num_timesteps_cond: 1
|
9 |
+
log_every_t: 200
|
10 |
+
timesteps: 1000
|
11 |
+
first_stage_key: edited
|
12 |
+
cond_stage_key: edit
|
13 |
+
image_size: 32
|
14 |
+
channels: 4
|
15 |
+
cond_stage_trainable: false
|
16 |
+
conditioning_key: hybrid
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_factor: 0.18215
|
19 |
+
use_ema: true
|
20 |
+
load_ema: false
|
21 |
+
scheduler_config:
|
22 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
23 |
+
params:
|
24 |
+
warm_up_steps:
|
25 |
+
- 0
|
26 |
+
cycle_lengths:
|
27 |
+
- 10000000000000
|
28 |
+
f_start:
|
29 |
+
- 1.0e-06
|
30 |
+
f_max:
|
31 |
+
- 1.0
|
32 |
+
f_min:
|
33 |
+
- 1.0
|
34 |
+
unet_config:
|
35 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions:
|
42 |
+
- 4
|
43 |
+
- 2
|
44 |
+
- 1
|
45 |
+
num_res_blocks: 2
|
46 |
+
channel_mult:
|
47 |
+
- 1
|
48 |
+
- 2
|
49 |
+
- 4
|
50 |
+
- 4
|
51 |
+
num_heads: 8
|
52 |
+
use_spatial_transformer: true
|
53 |
+
transformer_depth: 1
|
54 |
+
context_dim: 768
|
55 |
+
use_checkpoint: true
|
56 |
+
legacy: false
|
57 |
+
first_stage_config:
|
58 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
59 |
+
params:
|
60 |
+
embed_dim: 4
|
61 |
+
monitor: val/rec_loss
|
62 |
+
ddconfig:
|
63 |
+
double_z: true
|
64 |
+
z_channels: 4
|
65 |
+
resolution: 256
|
66 |
+
in_channels: 3
|
67 |
+
out_ch: 3
|
68 |
+
ch: 128
|
69 |
+
ch_mult:
|
70 |
+
- 1
|
71 |
+
- 2
|
72 |
+
- 4
|
73 |
+
- 4
|
74 |
+
num_res_blocks: 2
|
75 |
+
attn_resolutions: []
|
76 |
+
dropout: 0.0
|
77 |
+
lossconfig:
|
78 |
+
target: torch.nn.Identity
|
79 |
+
cond_stage_config:
|
80 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
81 |
+
data:
|
82 |
+
target: main.DataModuleFromConfig
|
83 |
+
params:
|
84 |
+
batch_size: 32
|
85 |
+
num_workers: 2
|
86 |
+
train:
|
87 |
+
target: edit_dataset.EditDataset
|
88 |
+
params:
|
89 |
+
path: /home/ugrad/ip2pdata
|
90 |
+
split: train
|
91 |
+
min_resize_res: 256
|
92 |
+
max_resize_res: 256
|
93 |
+
crop_res: 256
|
94 |
+
flip_prob: 0.5
|
logs/train_default/images/train/gs-002000_e-000666_b-000008_after-gen.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000008_after.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000008_before-vq.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000008_before.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000008_prompt.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"increase centerline distance by zero, reduce heading error by zero, change time of the day to morning"
|
2 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to night"
|
3 |
+
"increase centerline distance by zero, reduce heading error by zero, change time of the day to night"
|
4 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to morning"
|
5 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to night"
|
6 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to night"
|
7 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to afternoon"
|
8 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to night"
|
logs/train_default/images/train/gs-002000_e-000666_b-000009_after-gen.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000009_after.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000009_before-vq.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000009_before.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000009_prompt.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to afternoon"
|
2 |
+
"increase centerline distance by zero, reduce heading error by zero, change time of the day to morning"
|
3 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to afternoon"
|
4 |
+
"reduce centerline distance by zero, reduce heading error by zero, change time of the day to night"
|
5 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to night"
|
6 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to afternoon"
|
7 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to morning"
|
8 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to morning"
|
logs/train_default/images/train/gs-002000_e-000666_b-000010_after-gen.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000010_after.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000010_before-vq.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000010_before.png
ADDED
![]() |
logs/train_default/images/train/gs-002000_e-000666_b-000010_prompt.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to morning"
|
2 |
+
"increase centerline distance by zero, reduce heading error by zero, change time of the day to night"
|
3 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to night"
|
4 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to morning"
|
5 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to morning"
|
6 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to afternoon"
|
7 |
+
"increase centerline distance by zero, increase heading error by zero, change time of the day to afternoon"
|
8 |
+
"reduce centerline distance by zero, increase heading error by zero, change time of the day to morning"
|