diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..2f9e9ed6ae2280f787caeb77cbd91ba19d2e692a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,29 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+PusaV1/PusaV1.0_Report.pdf filter=lfs diff=lfs merge=lfs -text
+PusaV1/demos/end_frame.jpg filter=lfs diff=lfs merge=lfs -text
+PusaV1/demos/input_image.jpg filter=lfs diff=lfs merge=lfs -text
+PusaV1/demos/input_video.mp4 filter=lfs diff=lfs merge=lfs -text
+PusaV1/demos/start_frame.jpg filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/models/__pycache__/sd_unet.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+PusaV1/diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
+PusaV1/pusa_benchmark_figure_dark.png filter=lfs diff=lfs merge=lfs -text
+assets/demo0.gif filter=lfs diff=lfs merge=lfs -text
+assets/demo_T2V.gif filter=lfs diff=lfs merge=lfs -text
+assets/example.gif filter=lfs diff=lfs merge=lfs -text
+assets/example_baseline.gif filter=lfs diff=lfs merge=lfs -text
+assets/icon.png filter=lfs diff=lfs merge=lfs -text
+assets/methods_overview.gif filter=lfs diff=lfs merge=lfs -text
+demos/example1.mp4 filter=lfs diff=lfs merge=lfs -text
+demos/example2.mp4 filter=lfs diff=lfs merge=lfs -text
+demos/example3.jpg filter=lfs diff=lfs merge=lfs -text
+demos/example4.jpg filter=lfs diff=lfs merge=lfs -text
+demos/example5.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..4454d1c4bc6833e4e76eab94d32596a10010f1cc
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 Yaofang Liu, Rui Liu
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dbfda2d6fe0cc9df6a4f969f55b77c932dd31da0
--- /dev/null
+++ b/ORIGINAL_README.md
@@ -0,0 +1,382 @@
+
+
+
+
+# Pusa: Thousands Timesteps Video Diffusion Model
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+## 🔥🔥🔥🚀 Announcing Pusa V1.0 🚀🔥🔥🔥
+
+We are excited to release **Pusa V1.0**, a groundbreaking paradigm that leverages **vectorized timestep adaptation (VTA)** to enable fine-grained temporal control within a unified video diffusion framework. By finetuning the SOTA **Wan-T2V-14B** model with VTA, Pusa V1.0 achieves unprecedented efficiency, **surpassing the performance of Wan-I2V-14B with ≤ 1/200 of the training cost (\$500 vs. ≥ \$100,000)** and **≤ 1/2500 of the dataset size (4K vs. ≥ 10M samples)**. The codebase has been integrated into the `PusaV1` directory, based on `DiffSynth-Studio`.
+
+
+
+Pusa V1.0 not only sets a new standard for image-to-video generation but also unlocks many other zero-shot multi-task capabilities such as start-end frames and video extension, all without task-specific training while preserving the base model's T2V capabilities.
+
+For detailed usage and examples for Pusa V1.0, please see the **[Pusa V1.0 README](./PusaV1/README.md)**.
+
+
+## News
+#### 🔥🔥🔥 2025.07: Pusa V1.0 (Pusa-Wan) Code, Technical Report, and Dataset, all released!!! Check our [project page](https://yaofang-liu.github.io/Pusa_Web/) and [paper](https://github.com/Yaofang-Liu/Pusa-VidGen/blob/e99c3dcf866789a2db7fbe2686888ec398076a82/PusaV1/PusaV1.0_Report.pdf) for more info.
+#### 🔥🔥🔥 2025.04: Pusa V0.5 (Pusa-Mochi) released.
+
+
+
+
+
+
+ Pusa V0.5 showcases
+
+
+
+
+
+ Pusa V0.5 can still do text-to-video generation like its base model Mochi
+
+
+**Pusa can do much more; check the details below.**
+
+
+
+## Table of Contents
+- [Overview](#overview)
+- [Changelog](#changelog)
+- [Pusa V1.0 (Based on Wan)](#pusa-v10-based-on-wan)
+- [Pusa V0.5 (Based on Mochi)](#pusa-v05-based-on-mochi)
+- [Training](#training)
+- [Limitations](#limitations)
+- [Current Status and Roadmap](#current-status-and-roadmap)
+- [Related Work](#related-work)
+- [BibTeX](#bibtex)
+
+## Overview
+
+Pusa (*pu: 'sA:*, from "Thousand-Hand Guanyin" in Chinese) introduces a paradigm shift in video diffusion modeling through frame-level noise control with vectorized timesteps, departing from conventional scalar timestep approaches. This shift was first presented in our [FVDM](https://arxiv.org/abs/2410.03160) paper.
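+
+As a rough illustration of the idea (a minimal sketch, not the actual Pusa/FVDM implementation), the difference between a scalar and a vectorized timestep shows up in how noise is applied to the latent frames:
+
+```python
+# Minimal sketch: scalar vs. vectorized timesteps for a video latent.
+# Illustrative only; the real scheduler and noise schedule differ.
+import torch
+
+frames, channels, h, w = 8, 4, 32, 32
+latents = torch.randn(frames, channels, h, w)
+noise = torch.randn_like(latents)
+
+def add_noise(latents, noise, t):
+    # Simple interpolation toward noise; t broadcasts over (C, H, W).
+    return (1 - t) * latents + t * noise
+
+# Conventional video diffusion: one scalar t shared by every frame.
+noisy_scalar = add_noise(latents, noise, torch.tensor(0.5))
+
+# Frame-level noise control: one timestep per frame.
+t_vector = torch.linspace(0.1, 0.9, frames).view(frames, 1, 1, 1)
+noisy_vector = add_noise(latents, noise, t_vector)
+
+# E.g. keep frame 0 nearly clean (an image condition) while others stay noisy.
+t_vector[0] = 0.05
+noisy_conditioned = add_noise(latents, noise, t_vector)
+```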
+
+**Pusa V1.0** is based on the SOTA **Wan-T2V-14B** model and enhances it with our unique vectorized timestep adaptations (VTA), a non-destructive adaptation that fully preserves the capabilities of the base model.
+
+**Pusa V0.5** leverages this architecture and is based on [Mochi1-Preview](https://huggingface.co/genmo/mochi-1-preview). We are open-sourcing this work to foster community collaboration, enhance methodologies, and expand capabilities.
+
+
+Pusa's novel frame-level noise architecture with vectorized timesteps compared with conventional video diffusion models with a scalar timestep
+
+https://github.com/user-attachments/assets/7d751fd8-9a14-42e6-bcde-6db940df6537
+
+
+### ✨ Key Features
+
+- **Comprehensive Multi-task Support**:
+ - Text-to-Video
+ - Image-to-Video
+ - Start-End Frames
+ - Video completion/transitions
+ - Video Extension
+ - And more...
+
+- **Unprecedented Efficiency**:
+ - Surpasses Wan-I2V-14B with **≤ 1/200 of the training cost** (\$500 vs. ≥ \$100,000)
+ - Trained on a dataset **≤ 1/2500 of the size** (4K vs. ≥ 10M samples)
+ - Achieves a **VBench-I2V score of 87.32%** (vs. 86.86% for Wan-I2V-14B)
+
+- **Complete Open-Source Release**:
+ - Full codebase and training/inference scripts
+ - LoRA model weights and dataset for Pusa V1.0
+ - Detailed architecture specifications
+ - Comprehensive training methodology
+
+### 🔍 Unique Architecture
+
+- **Novel Diffusion Paradigm**: Implements frame-level noise control with vectorized timesteps, originally introduced in the [FVDM paper](https://arxiv.org/abs/2410.03160), enabling unprecedented flexibility and scalability.
+
+- **Non-destructive Modification**: Our adaptations to the base model preserve its original Text-to-Video generation capabilities; after the adaptation, only slight fine-tuning is needed.
+
+- **Universal Applicability**: The methodology can be readily applied to other leading video diffusion models including Hunyuan Video, Wan2.1, and others. *Collaborations enthusiastically welcomed!*
+
+
+## Changelog
+
+**v1.0 (July 15, 2025)**
+- Released Pusa V1.0, based on the Wan-Video models.
+- Released Technical Report, V1.0 model weights and dataset.
+- Integrated codebase as `/PusaV1`.
+- Added new examples and training scripts for Pusa V1.0 in `PusaV1/`.
+- Updated documentation for the V1.0 release.
+
+**v0.5 (June 3, 2025)**
+- Released inference scripts for Start&End Frames Generation, Multi-Frames Generation, Video Transition, and Video Extension.
+
+**v0.5 (April 10, 2025)**
+- Released our training code and details [here](https://github.com/Yaofang-Liu/Mochi-Full-Finetuner)
+- Supports multi-node/single-node full finetuning for both Pusa and Mochi
+- Released our [training dataset](https://huggingface.co/datasets/RaphaelLiu/PusaV0.5_Training)
+
+## Pusa V1.0 (Based on Wan)
+
+Pusa V1.0 leverages the powerful Wan-Video models and enhances them with our custom LoRA models and training scripts. For detailed instructions on installation, model preparation, usage examples, and training, please refer to the **[Pusa V1.0 README](./PusaV1/README.md)**.
+
+## Pusa V0.5 (Based on Mochi)
+
+
+Click to expand for Pusa V0.5 details
+
+### Installation
+
+You may install using [uv](https://github.com/astral-sh/uv):
+
+```bash
+git clone https://github.com/genmoai/models
+cd models
+pip install uv
+uv venv .venv
+source .venv/bin/activate
+uv pip install setuptools
+uv pip install -e . --no-build-isolation
+```
+
+If you want to install flash attention, you can use:
+```
+uv pip install -e .[flash] --no-build-isolation
+```
+
+### Download Weights
+
+**Option 1**: Use the Hugging Face CLI:
+```bash
+pip install huggingface_hub
+huggingface-cli download RaphaelLiu/Pusa-V0.5 --local-dir "/path/to/Pusa-V0.5"
+```
+
+**Option 2**: Download directly from [Hugging Face](https://huggingface.co/RaphaelLiu/Pusa-V0.5) to your local machine.
+
+
+## Usage
+
+### Image-to-Video Generation
+
+```bash
+python ./demos/cli_test_ti2v_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --prompt "Your_prompt_here" \
+ --image_dir "/path/to/input/image.jpg" \
+ --cond_position 0 \
+ --num_steps 30 \
+ --noise_multiplier 0
+```
+Note: We suggest trying different `cond_position` values here, and you may also adjust the level of noise added to the condition image; you are likely to get some pleasant surprises.
+
+Take `./demos/example.jpg` as an example and run with 4 GPUs:
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3 python ./demos/cli_test_ti2v_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --prompt "The camera remains still, the man is surfing on a wave with his surfboard." \
+ --image_dir "./demos/example.jpg" \
+ --cond_position 0 \
+ --num_steps 30 \
+ --noise_multiplier 0.4
+```
+You can get this result:
+
+
+
+
+
+
+You may refer to the baselines' results from the [VideoGen-Eval](https://github.com/AILab-CVC/VideoGen-Eval) benchmark for comparison:
+
+
+
+
+
+
+#### Processing A Group of Images
+```bash
+python ./demos/cli_test_ti2v_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --image_dir "/path/to/image/directory" \
+ --prompt_dir "/path/to/prompt/directory" \
+ --cond_position 1 \
+ --num_steps 30
+```
+
+For group processing, each image should have a corresponding text file with the same name in the prompt directory.
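+
+For example (the file names here are purely illustrative), the expected pairing looks like this:
+
+```python
+# Sketch of the image/prompt pairing expected for group processing
+# (hypothetical file names; directories follow --image_dir / --prompt_dir above).
+from pathlib import Path
+
+image_dir = Path("/path/to/image/directory")    # clip_001.jpg, clip_002.jpg, ...
+prompt_dir = Path("/path/to/prompt/directory")  # clip_001.txt, clip_002.txt, ...
+
+for image in sorted(image_dir.glob("*.jpg")):
+    prompt_file = prompt_dir / (image.stem + ".txt")
+    assert prompt_file.exists(), f"missing prompt for {image.name}"
+```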
+
+#### Using the Provided Shell Script
+We also provide a shell script for convenience:
+
+```bash
+# Edit cli_test_ti2v_release.sh to set your paths
+# Then run:
+bash ./demos/cli_test_ti2v_release.sh
+```
+
+### Multi-frame Condition
+
+Pusa supports generating videos from multiple keyframes (2 or more) placed at specific positions in the sequence. This is useful for both start-end frame generation and multi-keyframe interpolation.
+
+#### Start & End Frame Generation
+
+```bash
+python ./demos/cli_test_multi_frames_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --prompt "Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway." \
+ --multi_cond '{"0": ["./demos/example3.jpg", 0.3], "20": ["./demos/example5.jpg", 0.7]}' \
+ --num_steps 30
+```
+
+The `multi_cond` parameter specifies frame condition positions and their corresponding image paths and noise multipliers. In this example, the first frame (position 0) uses `./demos/example3.jpg` with noise multiplier 0.3, and frame 20 uses `./demos/example5.jpg` with noise multiplier 0.7.
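+
+The value passed to `--multi_cond` maps each frame position to an `[image_path, noise_multiplier]` pair; assuming it is parsed as standard JSON (as the examples here suggest), its structure looks like this:
+
+```python
+import json
+
+# Same condition spec as the command above.
+multi_cond = json.loads('{"0": ["./demos/example3.jpg", 0.3], "20": ["./demos/example5.jpg", 0.7]}')
+for position, (image_path, noise_multiplier) in multi_cond.items():
+    print(position, image_path, noise_multiplier)
+# 0 ./demos/example3.jpg 0.3
+# 20 ./demos/example5.jpg 0.7
+```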
+
+Alternatively, use the provided shell script:
+```bash
+# Edit parameters in cli_test_multi_frames_release.sh first
+bash ./demos/cli_test_multi_frames_release.sh
+```
+
+#### Multi-keyframe Interpolation
+
+To generate videos with more than two keyframes (e.g., start, middle, and end):
+
+```bash
+python ./demos/cli_test_multi_frames_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --prompt "Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway." \
+ --multi_cond '{"0": ["./demos/example3.jpg", 0.3], "13": ["./demos/example4.jpg", 0.7], "27": ["./demos/example5.jpg", 0.7]}' \
+ --num_steps 30
+```
+
+### Video Transition
+
+Create smooth transitions between two videos:
+
+```bash
+python ./demos/cli_test_transition_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --prompt "A fluffy Cockapoo, perched atop a vibrant pink flamingo jumps into a crystal-clear pool." \
+ --video_start_dir "./demos/example1.mp4" \
+ --video_end_dir "./demos/example2.mp4" \
+ --cond_position_start "[0]" \
+ --cond_position_end "[-3,-2,-1]" \
+ --noise_multiplier "[0.3,0.8,0.8,0.8]" \
+ --num_steps 30
+```
+
+Parameters:
+- `cond_position_start`: Frame indices from the start video to use as conditioning
+- `cond_position_end`: Frame indices from the end video to use as conditioning
+- `noise_multiplier`: Noise level multipliers for each conditioning frame
+
+Alternatively, use the provided shell script:
+```bash
+# Edit parameters in cli_test_transition_release.sh first
+bash ./demos/cli_test_transition_release.sh
+```
+
+### Video Extension
+
+Extend existing videos with generated content:
+
+```bash
+python ./demos/cli_test_extension_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --prompt "A cinematic shot captures a fluffy Cockapoo, perched atop a vibrant pink flamingo float, in a sun-drenched Los Angeles swimming pool. The crystal-clear water sparkles under the bright California sun, reflecting the playful scene." \
+ --video_dir "./demos/example1.mp4" \
+ --cond_position "[0,1,2,3]" \
+ --noise_multiplier "[0.1,0.2,0.3,0.4]" \
+ --num_steps 30
+```
+
+Parameters:
+- `cond_position`: Frame indices from the input video to use as conditioning
+- `noise_multiplier`: Noise level multipliers for each conditioning frame
+
+Alternatively, use the provided shell script:
+```bash
+# Edit parameters in cli_test_v2v_release.sh first
+bash ./demos/cli_test_v2v_release.sh
+```
+
+### Text-to-Video Generation
+```bash
+python ./demos/cli_test_ti2v_release.py \
+ --model_dir "/path/to/Pusa-V0.5" \
+ --dit_path "/path/to/Pusa-V0.5/pusa_v0_dit.safetensors" \
+ --prompt "A man is playing basketball" \
+ --num_steps 30
+```
+
+
+
+## Training
+
+For Pusa V1.0, please find the training details in the **[Pusa V1.0 README](./PusaV1/README.md#training)**.
+
+For Pusa V0.5, you can find our training code and details [here](https://github.com/Yaofang-Liu/Mochi-Full-Finetuner), which also supports training for the original Mochi model.
+
+## Limitations
+
+Pusa currently has several known limitations:
+- Video generation quality is dependent on the base model (e.g., Wan-T2V-14B for V1.0).
+- We anticipate significant quality improvements when applying our methodology to more advanced models.
+- We welcome community contributions to enhance model performance and extend its capabilities.
+
+### Currently Available
+- ✅ Model weights for Pusa V1.0 and V0.5
+- ✅ Inference code for Text-to-Video generation
+- ✅ Inference code for Image-to-Video generation
+- ✅ Inference scripts for start & end frames, multi-frames, video transition, video extension
+- ✅ Training code and details
+- ✅ Model full fine-tuning guide (for Pusa V0.5)
+- ✅ Training datasets
+- ✅ Technical Report for Pusa V1.0
+
+### TODO List
+- 🔄 Release more advanced versions with SOTA models
+- 🔄 More capabilities like long video generation
+- 🔄 ....
+
+## Related Work
+
+- [FVDM](https://arxiv.org/abs/2410.03160): Introduces the groundbreaking frame-level noise control with vectorized timestep approach that inspired Pusa.
+- [Wan-Video](https://github.com/modelscope/DiffSynth-Studio): The foundation model for Pusa V1.0.
+- [Mochi](https://huggingface.co/genmo/mochi-1-preview): The foundation model for Pusa V0.5, recognized as a leading open-source video generation system on the Artificial Analysis Leaderboard.
+
+## BibTeX
+If you use this work in your project, please cite the following references.
+```
+@misc{Liu2025pusa,
+ title={Pusa: Thousands Timesteps Video Diffusion Model},
+ author={Yaofang Liu and Rui Liu},
+ year={2025},
+ url={https://github.com/Yaofang-Liu/Pusa-VidGen},
+}
+```
+
+```
+@article{liu2024redefining,
+ title={Redefining Temporal Modeling in Video Diffusion: The Vectorized Timestep Approach},
+ author={Liu, Yaofang and Ren, Yumeng and Cun, Xiaodong and Artola, Aitor and Liu, Yang and Zeng, Tieyong and Chan, Raymond H and Morel, Jean-michel},
+ journal={arXiv preprint arXiv:2410.03160},
+ year={2024}
+}
+```
+
+
+
diff --git a/PusaV1/PusaV1.0_Report.pdf b/PusaV1/PusaV1.0_Report.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..a9d27921c943d975399fa14c43090710bb55f4ac
--- /dev/null
+++ b/PusaV1/PusaV1.0_Report.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:259aa6e00fc50f9981880432ad456e424433945db203cdd7c8ebdea0ba47ca29
+size 56655271
diff --git a/PusaV1/README.md b/PusaV1/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..52c1a4eab23849ad232aeeaaa32c568ff32bfce6
--- /dev/null
+++ b/PusaV1/README.md
@@ -0,0 +1,141 @@
+# Pusa-Video V1.0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+## 🔥🔥🔥🚀 Announcing Pusa V1.0 🚀🔥🔥🔥
+
+We are excited to release **Pusa V1.0**, a groundbreaking paradigm that leverages **vectorized timestep adaptation (VTA)** to enable fine-grained temporal control within a unified video diffusion framework. By finetuning the SOTA **Wan-T2V-14B** model with VTA, Pusa V1.0 achieves unprecedented efficiency, **surpassing Wan-I2V on VBench-I2V with only \$500 of training cost**. The codebase has been integrated into the `PusaV1` directory, based on `DiffSynth-Studio`.
+
+
+
+Pusa V1.0 not only sets a new standard for image-to-video generation but also unlocks many other zero-shot multi-task capabilities such as start-end frames and video extension, all without task-specific training while preserving the base model's T2V capabilities.
+
+Detailed usage and examples for Pusa V1.0 are provided below.
+
+
+## Installation
+
+Before using this model, follow the commands below to set up the environment; CUDA 12.4 is recommended.
+```shell
+conda create -n pusav1 python=3.10 -y
+conda activate pusav1
+cd ./PusaV1
+pip install -e .
+pip install "xfuser>=0.4.3" absl-py peft lightning pandas deepspeed wandb av
+```
+
+## Model Preparation
+
+Download the necessary models and place them into the `./model_zoo` directory. You can use the following commands to download and arrange the models correctly.
+
+```shell
+# Make sure you are in the PusaV1 directory
+# Install huggingface-cli if you don't have it
+pip install -U "huggingface_hub[cli]"
+huggingface-cli download RaphaelLiu/PusaV1 --local-dir ./model_zoo/
+cat ./model_zoo/PusaV1/pusa_v1.pt.part* > ./model_zoo/PusaV1/pusa_v1.pt
+```
+
+## Usage Examples
+
+All scripts save their output in an `outputs` directory, which will be created if it doesn't exist.
+
+### Image-to-Video Generation
+
+This script generates a video conditioned on an input image and a text prompt.
+
+```shell
+python examples/pusavideo/wan_14b_image_to_video_pusa.py \
+ --image_path "./demos/input_image.jpg" \
+ --prompt "A wide-angle shot shows a serene monk meditating perched a top of the letter E of a pile of weathered rocks that vertically spell out 'ZEN'. The rock formation is perched atop a misty mountain peak at sunrise. The warm light bathes the monk in a gentle glow, highlighting the folds of his saffron robes. The sky behind him is a soft gradient of pink and orange, creating a tranquil backdrop. The camera slowly zooms in, capturing the monk's peaceful expression and the intricate details of the rocks. The scene is bathed in a soft, ethereal light, emphasizing the spiritual atmosphere." \
+ --lora_path "./model_zoo/PusaV1/pusa_v1.pt"
+```
+
+### Video-to-Video Generation
+
+This script can be used for various video-to-video tasks such as video completion, video extension, or video transition, by providing an input video with at least 81 frames and specifying the condition settings. The generated video has 81 frames (21 latent frames) in total.
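+
+The 81-to-21 mapping follows from the 4x temporal compression of the Wan video VAE, with the first frame encoded on its own (an assumption spelled out here, not code from this repo); a quick sanity check:
+
+```python
+# Back-of-the-envelope latent-frame count, assuming 4x temporal compression
+# with the first frame encoded separately.
+def latent_frames(num_frames: int, temporal_stride: int = 4) -> int:
+    return 1 + (num_frames - 1) // temporal_stride
+
+assert latent_frames(81) == 21  # 81 input frames -> 21 latent frames, as stated above
+assert latent_frames(13) == 4   # the 13-frame condition in Example 2 -> first 4 latent frames
+```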
+
+**Example 1: Video Completion (Start-End Frames)**
+Give the start frame and the last 4 frames (encoded into a single latent frame) as conditions.
+
+```shell
+python examples/pusavideo/wan_14b_v2v_pusa.py \
+ --video_path "./demos/input_video.mp4" \
+ --prompt "piggy bank surfing a tube in teahupo'o wave dusk light cinematic shot shot in 35mm film" \
+ --cond_position "0,20" \
+ --noise_multipliers "0,0" \
+ --lora_path "./model_zoo/PusaV1/pusa_v1.pt"
+```
+
+**Example 2: Video Extension**
+Give the first 13 frames as conditions (encoded into the first 4 latent frames).
+
+```shell
+python examples/pusavideo/wan_14b_v2v_pusa.py \
+ --video_path "./demos/input_video.mp4" \
+ --prompt "piggy bank surfing a tube in teahupo'o wave dusk light cinematic shot shot in 35mm film" \
+ --cond_position "0,1,2,3" \
+ --noise_multipliers "0,0,0,0" \
+ --lora_path "./model_zoo/PusaV1/pusa_v1.pt"
+```
+
+### Multi-Frame Conditioned Generation
+
+This script generates a video conditioned on multiple input frames and a prompt.
+
+**Example: Start-End Frames**
+Give the start and end frames as image files for conditioning, and add some noise to the condition frames to generate a more coherent video.
+
+```shell
+python examples/pusavideo/wan_14b_multi_frames_pusa.py \
+ --image_paths "./demos/start_frame.jpg" "./demos/end_frame.jpg" \
+ --prompt "plastic injection machine opens releasing a soft inflatable foamy morphing sticky figure over a hand. isometric. low light. dramatic light. macro shot. real footage" \
+ --cond_position "0,20" \
+ --noise_multipliers "0.3,0.7" \
+ --lora_path "./model_zoo/PusaV1/pusa_v1.pt"
+```
+
+### Text-to-Video Generation
+
+This script generates a video from a text prompt.
+
+```shell
+python examples/pusavideo/wan_14b_text_to_video_pusa.py \
+ --prompt "A vibrant coral reef teeming with life, schools of colorful fish darting through the intricate coral formations. A majestic sea turtle glides gracefully past, its shell a mosaic of earthy tones. Sunlight filters through the clear blue water, creating a breathtaking underwater spectacle." \
+ --lora_path "./model_zoo/PusaV1/pusa_v1.pt"
+```
+
+## Training
+Our training pipeline is based on DiffSynth-Studio, which supports both full finetuning and LoRA finetuning. We used LoRA training on a custom dataset to obtain the Pusa V1.0 model. The training process consists of two stages: data preparation and training.
+
+### Prepare Dataset
+You can download our dataset from Hugging Face or prepare your own dataset following https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo.
+
+Download the `PusaV1_training` dataset to `./dataset/`:
+```shell
+huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir ./dataset/
+```
+
+### Training
+After preparing the dataset, you can start training. We provide a sample script, `train.sh`, for multi-GPU training on a single node using `torchrun` and `deepspeed`.
+
+You can find the script at `examples/pusavideo/train.sh` and modify the paths and parameters as needed. Then run it from the `PusaV1` directory:
+```shell
+bash ./examples/pusavideo/train.sh
+```
+The trained LoRA model will be saved in the `lightning_logs` directory inside your specified `--output_path`.
+
+
+
diff --git a/PusaV1/dataset/train_dataset_here b/PusaV1/dataset/train_dataset_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PusaV1/demos/end_frame.jpg b/PusaV1/demos/end_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..25aa120f48aea80f61f9f368f3add31fdc65f685
--- /dev/null
+++ b/PusaV1/demos/end_frame.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ca1ba36f88702330d881fc2019750be57aa5057413d8f3a50de4863f9c2aa1f
+size 142045
diff --git a/PusaV1/demos/input_image.jpg b/PusaV1/demos/input_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d0c01c9b1b99fa9fb09fbb3f10d1a23ae7ff8868
--- /dev/null
+++ b/PusaV1/demos/input_image.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eba146f59d1f5b19ed9db490b3f84d8c9e576907aefe561b96ffc7b8f78db7d7
+size 138642
diff --git a/PusaV1/demos/input_video.mp4 b/PusaV1/demos/input_video.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..011f49fed6f2690cf12c3de9008b1323c5cc2641
--- /dev/null
+++ b/PusaV1/demos/input_video.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7b60085459e4eb455166bed3d5267d626fe76ff035ca15d220bbadc3ce86045
+size 1935048
diff --git a/PusaV1/demos/start_frame.jpg b/PusaV1/demos/start_frame.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..772535eb5864b6bbafd42a4321b7bad06281d890
--- /dev/null
+++ b/PusaV1/demos/start_frame.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5220869dade17fefa7c1c07377ae671035458f49acc6f4ce0bcc13bdc52eab02
+size 116854
diff --git a/PusaV1/diffsynth/__init__.py b/PusaV1/diffsynth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0a45c2e2dc61f8f16354feb1b0c481776b523f
--- /dev/null
+++ b/PusaV1/diffsynth/__init__.py
@@ -0,0 +1,6 @@
+from .data import *
+from .models import *
+from .prompters import *
+from .schedulers import *
+from .pipelines import *
+from .controlnets import *
diff --git a/PusaV1/diffsynth/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40f47b2aaab5916092daa90223886a9c2f4b16b5
Binary files /dev/null and b/PusaV1/diffsynth/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..039f9225532b03628a67d61bf0c533df802b850d
Binary files /dev/null and b/PusaV1/diffsynth/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/configs/__init__.py b/PusaV1/diffsynth/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PusaV1/diffsynth/configs/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/configs/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37ab9b5d98a3433ce81a7c7edd61beeb65ec937b
Binary files /dev/null and b/PusaV1/diffsynth/configs/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/configs/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/configs/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65fb96ecbec6791ba08f4c55bfb4127c59079247
Binary files /dev/null and b/PusaV1/diffsynth/configs/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/configs/__pycache__/model_config.cpython-310.pyc b/PusaV1/diffsynth/configs/__pycache__/model_config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c059858b4db48d3d2da7c9ee423a63c6d1f6e2ac
Binary files /dev/null and b/PusaV1/diffsynth/configs/__pycache__/model_config.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/configs/__pycache__/model_config.cpython-312.pyc b/PusaV1/diffsynth/configs/__pycache__/model_config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..659f3a4a5a7bfc37daf7d8bfcd96aec4d53ef63b
Binary files /dev/null and b/PusaV1/diffsynth/configs/__pycache__/model_config.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/configs/__pycache__/model_config_pusa.cpython-312.pyc b/PusaV1/diffsynth/configs/__pycache__/model_config_pusa.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6168f8c2c96a164420c8ae6676b1d2aad609dabf
Binary files /dev/null and b/PusaV1/diffsynth/configs/__pycache__/model_config_pusa.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/configs/model_config.py b/PusaV1/diffsynth/configs/model_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8c9d841aae34e61121123cbe93e5b243157c623
--- /dev/null
+++ b/PusaV1/diffsynth/configs/model_config.py
@@ -0,0 +1,818 @@
+from typing_extensions import Literal, TypeAlias
+
+from ..models.sd_text_encoder import SDTextEncoder
+from ..models.sd_unet import SDUNet
+from ..models.sd_vae_encoder import SDVAEEncoder
+from ..models.sd_vae_decoder import SDVAEDecoder
+
+from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
+from ..models.sdxl_unet import SDXLUNet
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+
+from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from ..models.sd3_dit import SD3DiT
+from ..models.sd3_vae_decoder import SD3VAEDecoder
+from ..models.sd3_vae_encoder import SD3VAEEncoder
+
+from ..models.sd_controlnet import SDControlNet
+from ..models.sdxl_controlnet import SDXLControlNetUnion
+
+from ..models.sd_motion import SDMotionModel
+from ..models.sdxl_motion import SDXLMotionModel
+
+from ..models.svd_image_encoder import SVDImageEncoder
+from ..models.svd_unet import SVDUNet
+from ..models.svd_vae_decoder import SVDVAEDecoder
+from ..models.svd_vae_encoder import SVDVAEEncoder
+
+from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
+from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from ..models.hunyuan_dit import HunyuanDiT
+
+from ..models.flux_dit import FluxDiT
+from ..models.flux_text_encoder import FluxTextEncoder2
+from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
+from ..models.flux_controlnet import FluxControlNet
+from ..models.flux_ipadapter import FluxIpAdapter
+from ..models.flux_infiniteyou import InfiniteYouImageProjector
+
+from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
+from ..models.cog_dit import CogDiT
+
+from ..models.omnigen import OmniGenTransformer
+
+from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
+from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
+
+from ..extensions.RIFE import IFNet
+from ..extensions.ESRGAN import RRDBNet
+
+from ..models.hunyuan_video_dit import HunyuanVideoDiT
+
+from ..models.stepvideo_vae import StepVideoVAE
+from ..models.stepvideo_dit import StepVideoModel
+
+from ..models.wan_video_dit import WanModel
+from ..models.wan_video_pusa import WanModelPusa
+from ..models.wan_video_text_encoder import WanTextEncoder
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+from ..models.wan_video_vace import VaceWanModel
+
+
+model_loader_configs = [
+ # These configs are provided for detecting model type automatically.
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
+ (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
+ (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
+ (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
+ (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
+ (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
+ (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
+ (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
+ (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
+ (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
+ (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
+ (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
+ (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
+ (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
+ (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
+ (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
+ (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
+ (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
+ (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
+ (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
+ (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
+ (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
+ (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
+ (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
+ (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
+ (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
+ (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
+ (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
+ (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
+ (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
+ (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
+ (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
+ (None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
+ (None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+ (None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
+ (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
+ (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
+ (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
+ (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
+ (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
+ (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
+ (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
+ (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
+ (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
+ (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
+ (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_pusa"], [WanModelPusa], "civitai"),
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_pusa"], [WanModelPusa], "civitai"),
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_pusa"], [WanModelPusa], "civitai"),
+ (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_pusa"], [WanModelPusa], "civitai"),
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_pusa"], [WanModelPusa], "civitai"),
+ (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_pusa"], [WanModelPusa], "civitai"),
+ (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_pusa"], [WanModelPusa], "civitai"),
+ (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_pusa", "wan_video_vace"], [WanModelPusa, VaceWanModel], "civitai"),
+ (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_pusa"], [WanModelPusa], "diffusers"),
+ (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
+ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
+ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+ (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
+]
+huggingface_model_loader_configs = [
+ # These configs are provided for detecting model type automatically.
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
+ ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
+ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
+ ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
+ ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
+ ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
+ ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
+]
+patch_model_loader_configs = [
+ # These configs are provided for detecting model type automatically.
+ # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
+ ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
+]
+
+preset_models_on_huggingface = {
+ "HunyuanDiT": [
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
+ ],
+ "stable-video-diffusion-img2vid-xt": [
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
+ ],
+ "ExVideo-SVD-128f-v1": [
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
+ ],
+ # Stable Diffusion
+ "StableDiffusion_v15": [
+ ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
+ ],
+ "DreamShaper_8": [
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
+ ],
+ # Textual Inversion
+ "TextualInversion_VeryBadImageNegative_v1.3": [
+ ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
+ ],
+ # Stable Diffusion XL
+ "StableDiffusionXL_v1": [
+ ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "BluePencilXL_v200": [
+ ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "StableDiffusionXL_Turbo": [
+ ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
+ ],
+ # Stable Diffusion 3
+ "StableDiffusion3": [
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
+ ],
+ "StableDiffusion3_without_T5": [
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
+ ],
+ # ControlNet
+ "ControlNet_v11f1p_sd15_depth": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ "ControlNet_v11p_sd15_softedge": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
+ ],
+ "ControlNet_v11f1e_sd15_tile": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
+ ],
+ "ControlNet_v11p_sd15_lineart": [
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
+ ],
+ "ControlNet_union_sdxl_promax": [
+ ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ # AnimateDiff
+ "AnimateDiff_v2": [
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
+ ],
+ "AnimateDiff_xl_beta": [
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
+ ],
+
+ # Qwen Prompt
+ "QwenPrompt": [
+ ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ],
+ # Beautiful Prompt
+ "BeautifulPrompt": [
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ],
+ # Omost prompt
+    "OmostPrompt": [
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ],
+ # Translator
+ "opus-mt-zh-en": [
+ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
+ ],
+ # IP-Adapter
+ "IP-Adapter-SD": [
+ ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
+ ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
+ ],
+ "IP-Adapter-SDXL": [
+ ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
+ ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
+ ],
+ "SDXL-vae-fp16-fix": [
+ ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
+ ],
+ # Kolors
+ "Kolors": [
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
+ ],
+ # FLUX
+ "FLUX.1-dev": [
+ ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
+ ],
+ "InstantX/FLUX.1-dev-IP-Adapter": {
+ "file_list": [
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ],
+ "load_path": [
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
+ ],
+ },
+ # RIFE
+ "RIFE": [
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
+ ],
+ # CogVideo
+ "CogVideoX-5B": [
+ ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
+ ],
+ # Stable Diffusion 3.5
+ "StableDiffusion3.5-large": [
+ ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+}
+preset_models_on_modelscope = {
+ # Hunyuan DiT
+ "HunyuanDiT": [
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
+ ],
+ # Stable Video Diffusion
+ "stable-video-diffusion-img2vid-xt": [
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
+ ],
+ # ExVideo
+ "ExVideo-SVD-128f-v1": [
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
+ ],
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
+ ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
+ ],
+ # Stable Diffusion
+ "StableDiffusion_v15": [
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
+ ],
+ "DreamShaper_8": [
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
+ ],
+ "AingDiffusion_v12": [
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
+ ],
+ "Flat2DAnimerge_v45Sharp": [
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
+ ],
+ # Textual Inversion
+ "TextualInversion_VeryBadImageNegative_v1.3": [
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
+ ],
+ # Stable Diffusion XL
+ "StableDiffusionXL_v1": [
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "BluePencilXL_v200": [
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
+ ],
+ "StableDiffusionXL_Turbo": [
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
+ ],
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
+ ],
+ # Stable Diffusion 3
+ "StableDiffusion3": [
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
+ ],
+ "StableDiffusion3_without_T5": [
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
+ ],
+ # ControlNet
+ "ControlNet_v11f1p_sd15_depth": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ "ControlNet_v11p_sd15_softedge": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
+ ],
+ "ControlNet_v11f1e_sd15_tile": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
+ ],
+ "ControlNet_v11p_sd15_lineart": [
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
+ ],
+ "ControlNet_union_sdxl_promax": [
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
+ ],
+ "Annotators:Depth": [
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
+ ],
+ "Annotators:Softedge": [
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
+ ],
+ "Annotators:Lineart": [
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
+ ],
+ "Annotators:Normal": [
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
+ ],
+ "Annotators:Openpose": [
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
+ ],
+ # AnimateDiff
+ "AnimateDiff_v2": [
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
+ ],
+ "AnimateDiff_xl_beta": [
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
+ ],
+ # Qwen Prompt
+ "QwenPrompt": {
+ "file_list": [
+ ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
+ ],
+ "load_path": [
+ "models/QwenPrompt/qwen2-1.5b-instruct",
+ ],
+ },
+ # Beautiful Prompt
+ "BeautifulPrompt": {
+ "file_list": [
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
+ ],
+ "load_path": [
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
+ ],
+ },
+ # Omost prompt
+ "OmostPrompt": {
+ "file_list": [
+ ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
+ ],
+ "load_path": [
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
+ ],
+ },
+ # Translator
+ "opus-mt-zh-en": {
+ "file_list": [
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
+ ],
+ "load_path": [
+ "models/translator/opus-mt-zh-en",
+ ],
+ },
+ # IP-Adapter
+ "IP-Adapter-SD": [
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
+ ],
+ "IP-Adapter-SDXL": [
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
+ ],
+ # Kolors
+ "Kolors": {
+ "file_list": [
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
+ ],
+ "load_path": [
+ "models/kolors/Kolors/text_encoder",
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
+ ],
+ },
+ "SDXL-vae-fp16-fix": [
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
+ ],
+ # FLUX
+ "FLUX.1-dev": {
+ "file_list": [
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
+ ],
+ "load_path": [
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
+ ],
+ },
+ "FLUX.1-schnell": {
+ "file_list": [
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
+ ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
+ ],
+ "load_path": [
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
+ ],
+ },
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
+ ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
+ ],
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
+ ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
+ ],
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
+ ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
+ ],
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
+ ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
+ ],
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
+ ],
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
+ ],
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
+ ],
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
+ ],
+ "InstantX/FLUX.1-dev-IP-Adapter": {
+ "file_list": [
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
+ ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
+ ],
+ "load_path": [
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
+ ],
+ },
+    "InfiniteYou": {
+        "file_list": [
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+ ],
+        "load_path": [
+ [
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
+ ],
+ "models/InfiniteYou/image_proj_model.bin",
+ ],
+ },
+ # ESRGAN
+ "ESRGAN_x4": [
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
+ ],
+ # RIFE
+ "RIFE": [
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
+ ],
+ # Omnigen
+ "OmniGen-v1": {
+ "file_list": [
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
+ ],
+ "load_path": [
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
+ "models/OmniGen/OmniGen-v1/model.safetensors",
+ ]
+ },
+ # CogVideo
+ "CogVideoX-5B": {
+ "file_list": [
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
+ ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
+ ],
+ "load_path": [
+ "models/CogVideo/CogVideoX-5b/text_encoder",
+ "models/CogVideo/CogVideoX-5b/transformer",
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
+ ],
+ },
+ # Stable Diffusion 3.5
+ "StableDiffusion3.5-large": [
+ ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+ "StableDiffusion3.5-medium": [
+ ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+ "StableDiffusion3.5-large-turbo": [
+ ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
+ ],
+    "HunyuanVideo": {
+ "file_list": [
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
+ ],
+ "load_path": [
+ "models/HunyuanVideo/text_encoder/model.safetensors",
+ "models/HunyuanVideo/text_encoder_2",
+ "models/HunyuanVideo/vae/pytorch_model.pt",
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
+ ],
+ },
+    "HunyuanVideoI2V": {
+ "file_list": [
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
+ ],
+ "load_path": [
+ "models/HunyuanVideoI2V/text_encoder/model.safetensors",
+ "models/HunyuanVideoI2V/text_encoder_2",
+ "models/HunyuanVideoI2V/vae/pytorch_model.pt",
+ "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
+ ],
+ },
+    "HunyuanVideo-fp8": {
+ "file_list": [
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
+ ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
+ ],
+ "load_path": [
+ "models/HunyuanVideo/text_encoder/model.safetensors",
+ "models/HunyuanVideo/text_encoder_2",
+ "models/HunyuanVideo/vae/pytorch_model.pt",
+ "models/HunyuanVideo/transformers/model.fp8.safetensors"
+ ],
+ },
+}
+Preset_model_id: TypeAlias = Literal[
+ "HunyuanDiT",
+ "stable-video-diffusion-img2vid-xt",
+ "ExVideo-SVD-128f-v1",
+ "ExVideo-CogVideoX-LoRA-129f-v1",
+ "StableDiffusion_v15",
+ "DreamShaper_8",
+ "AingDiffusion_v12",
+ "Flat2DAnimerge_v45Sharp",
+ "TextualInversion_VeryBadImageNegative_v1.3",
+ "StableDiffusionXL_v1",
+ "BluePencilXL_v200",
+ "StableDiffusionXL_Turbo",
+ "ControlNet_v11f1p_sd15_depth",
+ "ControlNet_v11p_sd15_softedge",
+ "ControlNet_v11f1e_sd15_tile",
+ "ControlNet_v11p_sd15_lineart",
+ "AnimateDiff_v2",
+ "AnimateDiff_xl_beta",
+ "RIFE",
+ "BeautifulPrompt",
+ "opus-mt-zh-en",
+ "IP-Adapter-SD",
+ "IP-Adapter-SDXL",
+ "StableDiffusion3",
+ "StableDiffusion3_without_T5",
+ "Kolors",
+ "SDXL-vae-fp16-fix",
+ "ControlNet_union_sdxl_promax",
+ "FLUX.1-dev",
+ "FLUX.1-schnell",
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
+ "jasperai/Flux.1-dev-Controlnet-Depth",
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
+ "InstantX/FLUX.1-dev-IP-Adapter",
+ "InfiniteYou",
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
+ "QwenPrompt",
+ "OmostPrompt",
+ "ESRGAN_x4",
+ "OmniGen-v1",
+ "CogVideoX-5B",
+ "Annotators:Depth",
+ "Annotators:Softedge",
+ "Annotators:Lineart",
+ "Annotators:Normal",
+ "Annotators:Openpose",
+ "StableDiffusion3.5-large",
+    "StableDiffusion3.5-medium",
+    "StableDiffusion3.5-large-turbo",
+ "HunyuanVideo",
+ "HunyuanVideo-fp8",
+ "HunyuanVideoI2V",
+]
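+
+# Usage sketch (hedged): each preset entry above is either a list of
+# (repo_id, path_in_repo, local_dir) tuples or a dict with "file_list" and
+# "load_path" keys when explicit load paths are needed. A downstream downloader
+# might consume the tuples roughly as below; `download_file` is a hypothetical
+# helper, not something defined in this module:
+#
+#     for repo_id, origin_path, local_dir in preset_models_on_modelscope["ESRGAN_x4"]:
+#         download_file(repo_id, origin_path, local_dir)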
diff --git a/PusaV1/diffsynth/controlnets/__init__.py b/PusaV1/diffsynth/controlnets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3e15add6ab116bf261804b8c83c86ff4d61c41b
--- /dev/null
+++ b/PusaV1/diffsynth/controlnets/__init__.py
@@ -0,0 +1,2 @@
+from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
+from .processors import Annotator
diff --git a/PusaV1/diffsynth/controlnets/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/controlnets/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f30ac0df4ef9628001ebedb71597841b2548f51
Binary files /dev/null and b/PusaV1/diffsynth/controlnets/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/controlnets/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/controlnets/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7ab604e059bfd924e5539e3c85adbfa65f97e5c
Binary files /dev/null and b/PusaV1/diffsynth/controlnets/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-310.pyc b/PusaV1/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2006805fb082f03ddb5e712c9e305f8e2d3a1f01
Binary files /dev/null and b/PusaV1/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-312.pyc b/PusaV1/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d27bc50107c7389a5f025b25ec98d78e5516cab1
Binary files /dev/null and b/PusaV1/diffsynth/controlnets/__pycache__/controlnet_unit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/controlnets/__pycache__/processors.cpython-310.pyc b/PusaV1/diffsynth/controlnets/__pycache__/processors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0940e89f78090843378572712fc20d14f0e42c8
Binary files /dev/null and b/PusaV1/diffsynth/controlnets/__pycache__/processors.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/controlnets/__pycache__/processors.cpython-312.pyc b/PusaV1/diffsynth/controlnets/__pycache__/processors.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c382e15b82debc6353ea95cfd028359be6cfe8f6
Binary files /dev/null and b/PusaV1/diffsynth/controlnets/__pycache__/processors.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/controlnets/controlnet_unit.py b/PusaV1/diffsynth/controlnets/controlnet_unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdb4829483d208ec0295d1b5a8f82681b4251ea4
--- /dev/null
+++ b/PusaV1/diffsynth/controlnets/controlnet_unit.py
@@ -0,0 +1,91 @@
+import torch
+import numpy as np
+from .processors import Processor_id
+
+
+class ControlNetConfigUnit:
+ def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
+ self.processor_id = processor_id
+ self.model_path = model_path
+ self.scale = scale
+ self.skip_processor = skip_processor
+
+
+class ControlNetUnit:
+ def __init__(self, processor, model, scale=1.0):
+ self.processor = processor
+ self.model = model
+ self.scale = scale
+
+
+class MultiControlNetManager:
+ def __init__(self, controlnet_units=[]):
+ self.processors = [unit.processor for unit in controlnet_units]
+ self.models = [unit.model for unit in controlnet_units]
+ self.scales = [unit.scale for unit in controlnet_units]
+
+ def cpu(self):
+ for model in self.models:
+ model.cpu()
+
+ def to(self, device):
+ for model in self.models:
+ model.to(device)
+ for processor in self.processors:
+ processor.to(device)
+
+ def process_image(self, image, processor_id=None):
+ if processor_id is None:
+ processed_image = [processor(image) for processor in self.processors]
+ else:
+ processed_image = [self.processors[processor_id](image)]
+ processed_image = torch.concat([
+ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
+ for image_ in processed_image
+ ], dim=0)
+ return processed_image
+
+ def __call__(
+ self,
+ sample, timestep, encoder_hidden_states, conditionings,
+ tiled=False, tile_size=64, tile_stride=32, **kwargs
+ ):
+ res_stack = None
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
+ res_stack_ = model(
+ sample, timestep, encoder_hidden_states, conditioning, **kwargs,
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+ processor_id=processor.processor_id
+ )
+ res_stack_ = [res * scale for res in res_stack_]
+ if res_stack is None:
+ res_stack = res_stack_
+ else:
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
+ return res_stack
+
+
+class FluxMultiControlNetManager(MultiControlNetManager):
+ def __init__(self, controlnet_units=[]):
+ super().__init__(controlnet_units=controlnet_units)
+
+ def process_image(self, image, processor_id=None):
+ if processor_id is None:
+ processed_image = [processor(image) for processor in self.processors]
+ else:
+ processed_image = [self.processors[processor_id](image)]
+ return processed_image
+
+ def __call__(self, conditionings, **kwargs):
+ res_stack, single_res_stack = None, None
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
+ res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
+ res_stack_ = [res * scale for res in res_stack_]
+ single_res_stack_ = [res * scale for res in single_res_stack_]
+ if res_stack is None:
+ res_stack = res_stack_
+ single_res_stack = single_res_stack_
+ else:
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
+ single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
+ return res_stack, single_res_stack
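+
+# Usage sketch (hedged): a minimal illustration of how these classes compose.
+# `annotator` and `controlnet_model` are assumed to come from elsewhere in the
+# pipeline (an Annotator and a loaded ControlNet); nothing below is executed here.
+#
+#     unit = ControlNetUnit(processor=annotator, model=controlnet_model, scale=0.8)
+#     manager = MultiControlNetManager([unit])
+#     conditioning = manager.process_image(control_image)  # stacked (N, C, H, W) tensor
+#     res_stack = manager(latents, timestep, prompt_emb, conditionings=[conditioning])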
diff --git a/PusaV1/diffsynth/controlnets/processors.py b/PusaV1/diffsynth/controlnets/processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..06553e06d1c6d09f5a3deecfd4ea5604c5dd4352
--- /dev/null
+++ b/PusaV1/diffsynth/controlnets/processors.py
@@ -0,0 +1,62 @@
+from typing_extensions import Literal, TypeAlias
+
+
+Processor_id: TypeAlias = Literal[
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
+]
+
+class Annotator:
+ def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
+ if not skip_processor:
+ if processor_id == "canny":
+ from controlnet_aux.processor import CannyDetector
+ self.processor = CannyDetector()
+ elif processor_id == "depth":
+ from controlnet_aux.processor import MidasDetector
+ self.processor = MidasDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "softedge":
+ from controlnet_aux.processor import HEDdetector
+ self.processor = HEDdetector.from_pretrained(model_path).to(device)
+ elif processor_id == "lineart":
+ from controlnet_aux.processor import LineartDetector
+ self.processor = LineartDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "lineart_anime":
+ from controlnet_aux.processor import LineartAnimeDetector
+ self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "openpose":
+ from controlnet_aux.processor import OpenposeDetector
+ self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "normal":
+ from controlnet_aux.processor import NormalBaeDetector
+ self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
+ elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
+ self.processor = None
+ else:
+ raise ValueError(f"Unsupported processor_id: {processor_id}")
+ else:
+ self.processor = None
+
+ self.processor_id = processor_id
+ self.detect_resolution = detect_resolution
+
+    def to(self, device):
+        if hasattr(self.processor, "model") and hasattr(self.processor.model, "to"):
+            self.processor.model.to(device)
+
+ def __call__(self, image, mask=None):
+ width, height = image.size
+ if self.processor_id == "openpose":
+ kwargs = {
+ "include_body": True,
+ "include_hand": True,
+ "include_face": True
+ }
+ else:
+ kwargs = {}
+ if self.processor is not None:
+ detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
+ image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
+ image = image.resize((width, height))
+ return image
+
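+# Usage sketch (hedged): Annotator wraps the controlnet_aux detectors behind one
+# callable. The example assumes the depth weights were already downloaded to
+# models/Annotators (see the preset tables); it is not executed by this module.
+#
+#     from PIL import Image
+#     annotator = Annotator("depth", model_path="models/Annotators", device="cuda")
+#     depth_map = annotator(Image.open("input.jpg"))  # PIL image, same size as the input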
diff --git a/PusaV1/diffsynth/data/__init__.py b/PusaV1/diffsynth/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..de09a29905d289673e40e53278a0a3181232640d
--- /dev/null
+++ b/PusaV1/diffsynth/data/__init__.py
@@ -0,0 +1 @@
+from .video import VideoData, save_video, save_frames
diff --git a/PusaV1/diffsynth/data/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b82cfc3bb0343ba28e95d5e7ccbd179bbc2b6a1
Binary files /dev/null and b/PusaV1/diffsynth/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/data/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc82f17e697d2fbe41bb268c3e40efd91107ddf2
Binary files /dev/null and b/PusaV1/diffsynth/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/data/__pycache__/video.cpython-310.pyc b/PusaV1/diffsynth/data/__pycache__/video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee26f2322a2cde6d851e8c620941ba62d4a9128f
Binary files /dev/null and b/PusaV1/diffsynth/data/__pycache__/video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/data/__pycache__/video.cpython-312.pyc b/PusaV1/diffsynth/data/__pycache__/video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a1c49a3d00b96e74bfdee0311a23972070d80d6
Binary files /dev/null and b/PusaV1/diffsynth/data/__pycache__/video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/data/simple_text_image.py b/PusaV1/diffsynth/data/simple_text_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a9525e3c8a4d21418c1464fe11fc621450fd0d8
--- /dev/null
+++ b/PusaV1/diffsynth/data/simple_text_image.py
@@ -0,0 +1,41 @@
+import torch, os, torchvision
+from torchvision import transforms
+import pandas as pd
+from PIL import Image
+
+
+
+class TextImageDataset(torch.utils.data.Dataset):
+ def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
+ self.steps_per_epoch = steps_per_epoch
+ metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
+ self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
+ self.text = metadata["text"].to_list()
+ self.height = height
+ self.width = width
+ self.image_processor = transforms.Compose(
+ [
+ transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
+ transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+
+ def __getitem__(self, index):
+ data_id = torch.randint(0, len(self.path), (1,))[0]
+ data_id = (data_id + index) % len(self.path) # For fixed seed.
+ text = self.text[data_id]
+ image = Image.open(self.path[data_id]).convert("RGB")
+ target_height, target_width = self.height, self.width
+ width, height = image.size
+ scale = max(target_width / width, target_height / height)
+        shape = [round(height * scale), round(width * scale)]
+        image = torchvision.transforms.functional.resize(image, shape, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = self.image_processor(image)
+ return {"text": text, "image": image}
+
+
+ def __len__(self):
+ return self.steps_per_epoch
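+
+
+# Usage sketch (hedged): the dataset expects dataset_path/train/metadata.csv with
+# "file_name" and "text" columns. A training loop might wrap it as below; the
+# batch size and worker count are illustrative only:
+#
+#     from torch.utils.data import DataLoader
+#     dataset = TextImageDataset("data/my_dataset", height=1024, width=1024)
+#     loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
+#     batch = next(iter(loader))  # {"text": [...], "image": tensor of shape (4, 3, 1024, 1024)}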
diff --git a/PusaV1/diffsynth/data/video.py b/PusaV1/diffsynth/data/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eafa66855fa5668d42a65ac205776ed254213cf
--- /dev/null
+++ b/PusaV1/diffsynth/data/video.py
@@ -0,0 +1,148 @@
+import imageio, os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class LowMemoryVideo:
+ def __init__(self, file_name):
+ self.reader = imageio.get_reader(file_name)
+
+ def __len__(self):
+ return self.reader.count_frames()
+
+ def __getitem__(self, item):
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
+
+ def __del__(self):
+ self.reader.close()
+
+
+def split_file_name(file_name):
+    # Natural-sort key: runs of digits become integers, all other characters are kept as-is.
+ result = []
+ number = -1
+ for i in file_name:
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
+ if number == -1:
+ number = 0
+ number = number*10 + ord(i) - ord("0")
+ else:
+ if number != -1:
+ result.append(number)
+ number = -1
+ result.append(i)
+ if number != -1:
+ result.append(number)
+ result = tuple(result)
+ return result
+
+
+def search_for_images(folder):
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
+ file_list = [i[1] for i in sorted(file_list)]
+ file_list = [os.path.join(folder, i) for i in file_list]
+ return file_list
+
+
+class LowMemoryImageFolder:
+ def __init__(self, folder, file_list=None):
+ if file_list is None:
+ self.file_list = search_for_images(folder)
+ else:
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
+
+ def __len__(self):
+ return len(self.file_list)
+
+ def __getitem__(self, item):
+ return Image.open(self.file_list[item]).convert("RGB")
+
+ def __del__(self):
+ pass
+
+
+def crop_and_resize(image, height, width):
+    image = np.array(image)
+    image_height, image_width, _ = image.shape
+    if image_height / image_width < height / width:
+        # Input is wider than the target aspect ratio: center-crop horizontally.
+        cropped_width = int(image_height / height * width)
+        left = (image_width - cropped_width) // 2
+        image = image[:, left: left+cropped_width]
+        image = Image.fromarray(image).resize((width, height))
+    else:
+        # Input is taller than the target aspect ratio: center-crop vertically.
+        cropped_height = int(image_width / width * height)
+        top = (image_height - cropped_height) // 2
+        image = image[top: top+cropped_height, :]
+        image = Image.fromarray(image).resize((width, height))
+
+
+class VideoData:
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
+ if video_file is not None:
+ self.data_type = "video"
+ self.data = LowMemoryVideo(video_file, **kwargs)
+ elif image_folder is not None:
+ self.data_type = "images"
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
+ else:
+ raise ValueError("Cannot open video or image folder")
+ self.length = None
+ self.set_shape(height, width)
+
+ def raw_data(self):
+ frames = []
+ for i in range(self.__len__()):
+ frames.append(self.__getitem__(i))
+ return frames
+
+ def set_length(self, length):
+ self.length = length
+
+ def set_shape(self, height, width):
+ self.height = height
+ self.width = width
+
+ def __len__(self):
+ if self.length is None:
+ return len(self.data)
+ else:
+ return self.length
+
+ def shape(self):
+ if self.height is not None and self.width is not None:
+ return self.height, self.width
+ else:
+ height, width, _ = self.__getitem__(0).shape
+ return height, width
+
+ def __getitem__(self, item):
+ frame = self.data.__getitem__(item)
+ width, height = frame.size
+ if self.height is not None and self.width is not None:
+ if self.height != height or self.width != width:
+ frame = crop_and_resize(frame, self.height, self.width)
+ return frame
+
+ def __del__(self):
+ pass
+
+ def save_images(self, folder):
+ os.makedirs(folder, exist_ok=True)
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
+ frame = self.__getitem__(i)
+ frame.save(os.path.join(folder, f"{i}.png"))
+
+
+def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
+ for frame in tqdm(frames, desc="Saving video"):
+ frame = np.array(frame)
+ writer.append_data(frame)
+ writer.close()
+
+def save_frames(frames, save_path):
+ os.makedirs(save_path, exist_ok=True)
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
+ frame.save(os.path.join(save_path, f"{i}.png"))
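+
+
+# Usage sketch (hedged): VideoData lazily decodes frames from a video file or an
+# image folder and center-crops/resizes them on access; save_video re-encodes a
+# list of PIL frames. Paths and fps below are illustrative:
+#
+#     video = VideoData(video_file="input.mp4", height=480, width=832)
+#     frames = [video[i] for i in range(len(video))]
+#     save_video(frames, "output.mp4", fps=25)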
diff --git a/PusaV1/diffsynth/distributed/__init__.py b/PusaV1/diffsynth/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PusaV1/diffsynth/distributed/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/distributed/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0aacc987a686f4daa090141dd1570f219543dae1
Binary files /dev/null and b/PusaV1/diffsynth/distributed/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-312.pyc b/PusaV1/diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..209b2fe6bf908e781d31a9ff8a872d661b44ca6e
Binary files /dev/null and b/PusaV1/diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/distributed/xdit_context_parallel.py b/PusaV1/diffsynth/distributed/xdit_context_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1a2572539aba98ecf900ae264dbaaf29286371
--- /dev/null
+++ b/PusaV1/diffsynth/distributed/xdit_context_parallel.py
@@ -0,0 +1,129 @@
+import torch
+from typing import Optional
+from einops import rearrange
+from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+
+def sinusoidal_embedding_1d(dim, position):
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x.to(position.dtype)
+
+def pad_freqs(original_tensor, target_len):
+ seq_len, s1, s2 = original_tensor.shape
+ pad_size = target_len - seq_len
+ padding_tensor = torch.ones(
+ pad_size,
+ s1,
+ s2,
+ dtype=original_tensor.dtype,
+ device=original_tensor.device)
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+ return padded_tensor
+
+def rope_apply(x, freqs, num_heads):
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
+ s_per_rank = x.shape[1]
+
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
+
+ sp_size = get_sequence_parallel_world_size()
+ sp_rank = get_sequence_parallel_rank()
+ freqs = pad_freqs(freqs, s_per_rank * sp_size)
+ freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
+
+ x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
+ return x_out.to(x.dtype)
+
+def usp_dit_forward(self,
+ x: torch.Tensor,
+ timestep: torch.Tensor,
+ context: torch.Tensor,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ use_gradient_checkpointing: bool = False,
+ use_gradient_checkpointing_offload: bool = False,
+ **kwargs,
+ ):
+ t = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
+ context = self.text_embedding(context)
+
+ if self.has_image_input:
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
+ clip_embdding = self.img_emb(clip_feature)
+ context = torch.cat([clip_embdding, context], dim=1)
+
+ x, (f, h, w) = self.patchify(x)
+
+ freqs = torch.cat([
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ # Context Parallel
+ x = torch.chunk(
+ x, get_sequence_parallel_world_size(),
+ dim=1)[get_sequence_parallel_rank()]
+
+ for block in self.blocks:
+ if self.training and use_gradient_checkpointing:
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = block(x, context, t_mod, freqs)
+
+ x = self.head(x, t)
+
+ # Context Parallel
+ x = get_sp_group().all_gather(x, dim=1)
+
+ # unpatchify
+ x = self.unpatchify(x, (f, h, w))
+ return x
+
+
+def usp_attn_forward(self, x, freqs):
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(x))
+ v = self.v(x)
+
+ q = rope_apply(q, freqs, self.num_heads)
+ k = rope_apply(k, freqs, self.num_heads)
+ q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
+ k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
+ v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
+
+ x = xFuserLongContextAttention()(
+ None,
+ query=q,
+ key=k,
+ value=v,
+ )
+ x = x.flatten(2)
+
+ del q, k, v
+ torch.cuda.empty_cache()
+ return self.o(x)
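+
+# Usage sketch (hedged): these forwards are meant to be monkey-patched onto a
+# DiT and its attention modules when xDiT sequence parallelism is enabled. The
+# binding below is an assumption about how the surrounding pipeline wires them
+# up, not code defined in this module:
+#
+#     import types
+#     dit.forward = types.MethodType(usp_dit_forward, dit)
+#     for block in dit.blocks:
+#         block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)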
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/ESRGAN/__init__.py b/PusaV1/diffsynth/extensions/ESRGAN/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..94aff4c6fe8d75ff65e30d672dbe3e38a0d919c3
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ESRGAN/__init__.py
@@ -0,0 +1,137 @@
+import torch
+from einops import repeat
+from PIL import Image
+import numpy as np
+
+
+class ResidualDenseBlock(torch.nn.Module):
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ return x5 * 0.2 + x
+
+
+class RRDB(torch.nn.Module):
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ return out * 0.2 + x
+
+
+class RRDBNet(torch.nn.Module):
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
+ super(RRDBNet, self).__init__()
+ self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
+ self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
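+        # Each einops `repeat` below duplicates every pixel into a 2x2 block (nearest-neighbour
+        # upsampling); applied twice, the network outputs a 4x upscaled image.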
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
+ feat = self.lrelu(self.conv_up1(feat))
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
+ feat = self.lrelu(self.conv_up2(feat))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
+
+ @staticmethod
+ def state_dict_converter():
+ return RRDBNetStateDictConverter()
+
+
+class RRDBNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict, {"upcast_to_float32": True}
+
+ def from_civitai(self, state_dict):
+ return state_dict, {"upcast_to_float32": True}
+
+
+class ESRGAN(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ return ESRGAN(model_manager.fetch_model("esrgan"))
+
+ def process_image(self, image):
+ image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
+ return image
+
+ def process_images(self, images):
+ images = [self.process_image(image) for image in images]
+ images = torch.stack(images)
+ return images
+
+ def decode_images(self, images):
+ images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+ images = [Image.fromarray(image) for image in images]
+ return images
+
+ @torch.no_grad()
+ def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
+ if not isinstance(images, list):
+ images = [images]
+ is_single_image = True
+ else:
+ is_single_image = False
+
+ # Preprocess
+ input_tensor = self.process_images(images)
+
+ # Interpolate
+ output_tensor = []
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(
+ device=self.model.conv_first.weight.device,
+ dtype=self.model.conv_first.weight.dtype)
+ batch_output_tensor = self.model(batch_input_tensor)
+ output_tensor.append(batch_output_tensor.cpu())
+
+ # Output
+ output_tensor = torch.concat(output_tensor, dim=0)
+
+ # To images
+ output_images = self.decode_images(output_tensor)
+ if is_single_image:
+ output_images = output_images[0]
+ return output_images
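+
+# Minimal usage sketch (illustrative; assumes a model manager that has loaded an RRDBNet under the
+# "esrgan" key):
+#     upscaler = ESRGAN.from_model_manager(model_manager)
+#     upscaled = upscaler.upscale(Image.open("input.png"))            # single PIL image in, 4x PIL image out
+#     upscaled_list = upscaler.upscale([img_a, img_b], batch_size=2)  # a list in, a list out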
diff --git a/PusaV1/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f49b614ba1b01ebc142b57210a7233eee262c02d
Binary files /dev/null and b/PusaV1/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac8dd8569729295c1336740be2d19f4d4eba854d
Binary files /dev/null and b/PusaV1/diffsynth/extensions/ESRGAN/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/extensions/FastBlend/__init__.py b/PusaV1/diffsynth/extensions/FastBlend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf812c2085082bfa82658dd249ebca89e9fb465
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/__init__.py
@@ -0,0 +1,63 @@
+from .runners.fast import TableManager, PyramidPatchMatcher
+from PIL import Image
+import numpy as np
+import cupy as cp
+
+
+class FastBlendSmoother:
+ def __init__(self):
+ self.batch_size = 8
+ self.window_size = 64
+ self.ebsynth_config = {
+ "minimum_patch_size": 5,
+ "threads_per_block": 8,
+ "num_iter": 5,
+ "gpu_id": 0,
+ "guide_weight": 10.0,
+ "initialize": "identity",
+ "tracking_window_size": 0,
+ }
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ # TODO: fetch GPU ID from model_manager
+ return FastBlendSmoother()
+
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
+ frames_guide = [np.array(frame) for frame in frames_guide]
+ frames_style = [np.array(frame) for frame in frames_style]
+ table_manager = TableManager()
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **ebsynth_config
+ )
+ # left part
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
+ # right part
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
+ # merge
+ frames = []
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+ weight_m = -1
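+            # The centre (style) frame contributes to both the left and right window sums, so a
+            # weight of -1 cancels the duplicated copy in the blend.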
+ weight = weight_l + weight_m + weight_r
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+ frames.append(frame)
+ frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
+ return frames
+
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
+ frames = self.run(
+ original_frames, rendered_frames,
+ self.batch_size, self.window_size, self.ebsynth_config
+ )
+ mempool = cp.get_default_memory_pool()
+ pinned_mempool = cp.get_default_pinned_memory_pool()
+ mempool.free_all_blocks()
+ pinned_mempool.free_all_blocks()
+ return frames
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/FastBlend/api.py b/PusaV1/diffsynth/extensions/FastBlend/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..2db24330e375ed62065af54613b6ab956c9c64cf
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/api.py
@@ -0,0 +1,397 @@
+from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
+from .data import VideoData, get_video_fps, save_video, search_for_images
+import os
+import gradio as gr
+
+
+def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
+ frames_guide = VideoData(video_guide, video_guide_folder)
+ frames_style = VideoData(video_style, video_style_folder)
+ message = ""
+ if len(frames_guide) < len(frames_style):
+ message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
+ frames_style.set_length(len(frames_guide))
+ elif len(frames_guide) > len(frames_style):
+ message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
+ frames_guide.set_length(len(frames_style))
+ height_guide, width_guide = frames_guide.shape()
+ height_style, width_style = frames_style.shape()
+ if height_guide != height_style or width_guide != width_style:
+ message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
+ frames_style.set_shape(height_guide, width_guide)
+ return frames_guide, frames_style, message
+
+
+def smooth_video(
+ video_guide,
+ video_guide_folder,
+ video_style,
+ video_style_folder,
+ mode,
+ window_size,
+ batch_size,
+ tracking_window_size,
+ output_path,
+ fps,
+ minimum_patch_size,
+ num_iter,
+ guide_weight,
+ initialize,
+ progress = None,
+):
+ # input
+ frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
+ if len(message) > 0:
+ print(message)
+ # output
+ if output_path == "":
+ if video_style is None:
+ output_path = os.path.join(video_style_folder, "output")
+ else:
+ output_path = os.path.join(os.path.split(video_style)[0], "output")
+ os.makedirs(output_path, exist_ok=True)
+ print("No valid output_path. Your video will be saved here:", output_path)
+ elif not os.path.exists(output_path):
+ os.makedirs(output_path, exist_ok=True)
+ print("Your video will be saved here:", output_path)
+ frames_path = os.path.join(output_path, "frames")
+ video_path = os.path.join(output_path, "video.mp4")
+ os.makedirs(frames_path, exist_ok=True)
+ # process
+ if mode == "Fast" or mode == "Balanced":
+ tracking_window_size = 0
+ ebsynth_config = {
+ "minimum_patch_size": minimum_patch_size,
+ "threads_per_block": 8,
+ "num_iter": num_iter,
+ "gpu_id": 0,
+ "guide_weight": guide_weight,
+ "initialize": initialize,
+ "tracking_window_size": tracking_window_size,
+ }
+ if mode == "Fast":
+ FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
+ elif mode == "Balanced":
+ BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
+ elif mode == "Accurate":
+ AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
+ # output
+ try:
+ fps = int(fps)
+ except:
+ fps = get_video_fps(video_style) if video_style is not None else 30
+ print("Fps:", fps)
+ print("Saving video...")
+ video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
+ print("Success!")
+ print("Your frames are here:", frames_path)
+ print("Your video is here:", video_path)
+ return output_path, fps, video_path
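+
+# Programmatic usage sketch for smooth_video (illustrative paths; mirrors the Gradio defaults below):
+#     smooth_video(
+#         video_guide="guide.mp4", video_guide_folder="",
+#         video_style="style.mp4", video_style_folder="",
+#         mode="Fast", window_size=15, batch_size=8, tracking_window_size=0,
+#         output_path="", fps="", minimum_patch_size=5, num_iter=5,
+#         guide_weight=10.0, initialize="identity",
+#     )
+# The call writes the blended frames and a video.mp4 under the derived output directory and
+# returns (output_path, fps, video_path).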
+
+
+class KeyFrameMatcher:
+ def __init__(self):
+ pass
+
+ def extract_number_from_filename(self, file_name):
+ result = []
+ number = -1
+ for i in file_name:
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
+ if number == -1:
+ number = 0
+ number = number*10 + ord(i) - ord("0")
+ else:
+ if number != -1:
+ result.append(number)
+ number = -1
+ if number != -1:
+ result.append(number)
+ result = tuple(result)
+ return result
+
+ def extract_number_from_filenames(self, file_names):
+ numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
+ min_length = min(len(i) for i in numbers)
+ for i in range(min_length-1, -1, -1):
+ if len(set(number[i] for number in numbers))==len(file_names):
+ return [number[i] for number in numbers]
+ return list(range(len(file_names)))
+
+ def match_using_filename(self, file_names_a, file_names_b):
+ file_names_b_set = set(file_names_b)
+ matched_file_name = []
+ for file_name in file_names_a:
+ if file_name not in file_names_b_set:
+ matched_file_name.append(None)
+ else:
+ matched_file_name.append(file_name)
+ return matched_file_name
+
+ def match_using_numbers(self, file_names_a, file_names_b):
+ numbers_a = self.extract_number_from_filenames(file_names_a)
+ numbers_b = self.extract_number_from_filenames(file_names_b)
+ numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
+ matched_file_name = []
+ for number in numbers_a:
+ if number in numbers_b_dict:
+ matched_file_name.append(numbers_b_dict[number])
+ else:
+ matched_file_name.append(None)
+ return matched_file_name
+
+ def match_filenames(self, file_names_a, file_names_b):
+ matched_file_name = self.match_using_filename(file_names_a, file_names_b)
+ if sum([i is not None for i in matched_file_name]) > 0:
+ return matched_file_name
+ matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
+ return matched_file_name
+
+
+def detect_frames(frames_path, keyframes_path):
+ if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
+ return "Please input the directory of guide video and rendered frames"
+ elif not os.path.exists(frames_path):
+ return "Please input the directory of guide video"
+ elif not os.path.exists(keyframes_path):
+ return "Please input the directory of rendered frames"
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
+ if len(frames)==0:
+ return f"No images detected in {frames_path}"
+ if len(keyframes)==0:
+ return f"No images detected in {keyframes_path}"
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
+ max_filename_length = max([len(i) for i in frames])
+ if sum([i is not None for i in matched_keyframes])==0:
+ message = ""
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
+ message += frame + " " * (max_filename_length - len(frame) + 1)
+ message += "--> No matched keyframes\n"
+ else:
+ message = ""
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
+ message += frame + " " * (max_filename_length - len(frame) + 1)
+ if matched_keyframe is None:
+ message += "--> [to be rendered]\n"
+ else:
+ message += f"--> {matched_keyframe}\n"
+ return message
+
+
+def check_input_for_interpolating(frames_path, keyframes_path):
+ # search for images
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
+ # match frames
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
+ file_list = [file_name for file_name in matched_keyframes if file_name is not None]
+ index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
+ frames_guide = VideoData(None, frames_path)
+ frames_style = VideoData(None, keyframes_path, file_list=file_list)
+ # match shape
+ message = ""
+ height_guide, width_guide = frames_guide.shape()
+ height_style, width_style = frames_style.shape()
+ if height_guide != height_style or width_guide != width_style:
+ message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
+ frames_style.set_shape(height_guide, width_guide)
+ return frames_guide, frames_style, index_style, message
+
+
+def interpolate_video(
+ frames_path,
+ keyframes_path,
+ output_path,
+ fps,
+ batch_size,
+ tracking_window_size,
+ minimum_patch_size,
+ num_iter,
+ guide_weight,
+ initialize,
+ progress = None,
+):
+ # input
+ frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
+ if len(message) > 0:
+ print(message)
+ # output
+ if output_path == "":
+ output_path = os.path.join(keyframes_path, "output")
+ os.makedirs(output_path, exist_ok=True)
+ print("No valid output_path. Your video will be saved here:", output_path)
+ elif not os.path.exists(output_path):
+ os.makedirs(output_path, exist_ok=True)
+ print("Your video will be saved here:", output_path)
+ output_frames_path = os.path.join(output_path, "frames")
+ output_video_path = os.path.join(output_path, "video.mp4")
+ os.makedirs(output_frames_path, exist_ok=True)
+ # process
+ ebsynth_config = {
+ "minimum_patch_size": minimum_patch_size,
+ "threads_per_block": 8,
+ "num_iter": num_iter,
+ "gpu_id": 0,
+ "guide_weight": guide_weight,
+ "initialize": initialize,
+ "tracking_window_size": tracking_window_size
+ }
+ if len(index_style)==1:
+ InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
+ else:
+ InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
+ try:
+ fps = int(fps)
+ except:
+ fps = 30
+ print("Fps:", fps)
+ print("Saving video...")
+ video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
+ print("Success!")
+ print("Your frames are here:", output_frames_path)
+ print("Your video is here:", video_path)
+ return output_path, fps, video_path
+
+
+def on_ui_tabs():
+ with gr.Blocks(analytics_enabled=False) as ui_component:
+ with gr.Tab("Blend"):
+ gr.Markdown("""
+# Blend
+
+Given a guide video and a style video, this algorithm smooths the style video so that it follows the motion of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see an example. Note that this extension doesn't support long videos; please use short clips (e.g., a few seconds). The algorithm is mainly designed for 512*512 resolution; please use a larger `Minimum patch size` for higher resolutions.
+ """)
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("Guide video"):
+ video_guide = gr.Video(label="Guide video")
+ with gr.Tab("Guide video (images format)"):
+ video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
+ with gr.Column():
+ with gr.Tab("Style video"):
+ video_style = gr.Video(label="Style video")
+ with gr.Tab("Style video (images format)"):
+ video_style_folder = gr.Textbox(label="Style video (images format)", value="")
+ with gr.Column():
+ output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
+ fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
+ video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
+ btn = gr.Button(value="Blend")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("# Settings")
+ mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
+ window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
+ batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
+ tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
+ gr.Markdown("## Advanced Settings")
+ minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
+ num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
+ guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
+ initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
+ with gr.Column():
+ gr.Markdown("""
+# Reference
+
+* Output directory: the directory to save the video.
+* Inference mode
+
+|Mode|Time|Memory|Quality|Frame by frame output|Description|
+|-|-|-|-|-|-|
+|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires a lot of RAM but is fast.|
+|Balanced|■■|■|■■|Yes|Blend the frames naively.|
+|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. Performance is best when [batch size] >= [sliding window size] * 2 + 1.|
+
+* Sliding window size: our algorithm blends the frames within a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A larger sliding window makes the video more fluent but can introduce blur.
+* Batch size: a larger batch size makes the program faster but requires more VRAM.
+* Tracking window size (only for accurate mode): the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
+* Advanced settings
+  * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
+  * Number of iterations: the number of iterations of patch matching. (Default: 5)
+  * Guide weight: a parameter that determines how strongly the guide video's motion features are applied to the style video. (Default: 10)
+  * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
+ """)
+ btn.click(
+ smooth_video,
+ inputs=[
+ video_guide,
+ video_guide_folder,
+ video_style,
+ video_style_folder,
+ mode,
+ window_size,
+ batch_size,
+ tracking_window_size,
+ output_path,
+ fps,
+ minimum_patch_size,
+ num_iter,
+ guide_weight,
+ initialize
+ ],
+ outputs=[output_path, fps, video_output]
+ )
+ with gr.Tab("Interpolate"):
+ gr.Markdown("""
+# Interpolate
+
+Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and has only been tested at 512*512 resolution.
+ """)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
+ with gr.Column():
+ rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
+ with gr.Row():
+ detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
+ video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
+ rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
+ with gr.Column():
+ output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
+ fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
+ video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
+ btn_ = gr.Button(value="Interpolate")
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("# Settings")
+ batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
+ tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
+ gr.Markdown("## Advanced Settings")
+ minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
+ num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
+ guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
+ initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
+ with gr.Column():
+ gr.Markdown("""
+# Reference
+
+* Output directory: the directory to save the video.
+* Batch size: a larger batch size makes the program faster but requires more VRAM.
+* Tracking window size: the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
+* Advanced settings
+  * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than the one used for blending. (Default: 15)**
+  * Number of iterations: the number of iterations of patch matching. (Default: 5)
+  * Guide weight: a parameter that determines how strongly the guide video's motion features are applied to the style video. (Default: 10)
+  * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
+ """)
+ btn_.click(
+ interpolate_video,
+ inputs=[
+ video_guide_folder_,
+ rendered_keyframes_,
+ output_path_,
+ fps_,
+ batch_size_,
+ tracking_window_size_,
+ minimum_patch_size_,
+ num_iter_,
+ guide_weight_,
+ initialize_,
+ ],
+ outputs=[output_path_, fps_, video_output_]
+ )
+
+ return [(ui_component, "FastBlend", "FastBlend_ui")]
diff --git a/PusaV1/diffsynth/extensions/FastBlend/cupy_kernels.py b/PusaV1/diffsynth/extensions/FastBlend/cupy_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..70e2790a2c67a2dd537f4188b38ebfc785f1fb34
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/cupy_kernels.py
@@ -0,0 +1,119 @@
+import cupy as cp
+
+remapping_kernel = cp.RawKernel(r'''
+extern "C" __global__
+void remap(
+ const int height,
+ const int width,
+ const int channel,
+ const int patch_size,
+ const int pad_size,
+ const float* source_style,
+ const int* nnf,
+ float* target_style
+) {
+ const int r = (patch_size - 1) / 2;
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
+ if (x >= height or y >= width) return;
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
+ const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
+ const int min_px = x < r ? -x : -r;
+ const int max_px = x + r > height - 1 ? height - 1 - x : r;
+ const int min_py = y < r ? -y : -r;
+ const int max_py = y + r > width - 1 ? width - 1 - y : r;
+ int num = 0;
+ for (int px = min_px; px <= max_px; px++){
+ for (int py = min_py; py <= max_py; py++){
+ const int nid = (x + px) * width + y + py;
+ const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
+ const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
+ if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
+ const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
+ num++;
+ for (int c = 0; c < channel; c++){
+ target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
+ }
+ }
+ }
+ for (int c = 0; c < channel; c++){
+ target_style[z + pid * channel + c] /= num;
+ }
+}
+''', 'remap')
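+# Launch convention (see PatchMatcher.apply_nnf_to_image): grid = (ceil(H/t), ceil(W/t), batch),
+# block = (t, t). Each thread handles one (x, y) pixel of one batch element and averages the style
+# values voted by every patch that covers it, following the nearest-neighbour field (NNF).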
+
+
+patch_error_kernel = cp.RawKernel(r'''
+extern "C" __global__
+void patch_error(
+ const int height,
+ const int width,
+ const int channel,
+ const int patch_size,
+ const int pad_size,
+ const float* source,
+ const int* nnf,
+ const float* target,
+ float* error
+) {
+ const int r = (patch_size - 1) / 2;
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
+ if (x >= height or y >= width) return;
+ const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
+ const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
+ float e = 0;
+ for (int px = -r; px <= r; px++){
+ for (int py = -r; py <= r; py++){
+ const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
+ const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
+ for (int c = 0; c < channel; c++){
+ const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
+ e += diff * diff;
+ }
+ }
+ }
+ error[blockIdx.z * height * width + x * width + y] = e;
+}
+''', 'patch_error')
+
+
+pairwise_patch_error_kernel = cp.RawKernel(r'''
+extern "C" __global__
+void pairwise_patch_error(
+ const int height,
+ const int width,
+ const int channel,
+ const int patch_size,
+ const int pad_size,
+ const float* source_a,
+ const int* nnf_a,
+ const float* source_b,
+ const int* nnf_b,
+ float* error
+) {
+ const int r = (patch_size - 1) / 2;
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
+ if (x >= height or y >= width) return;
+ const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
+ const int x_a = nnf_a[z_nnf + 0];
+ const int y_a = nnf_a[z_nnf + 1];
+ const int x_b = nnf_b[z_nnf + 0];
+ const int y_b = nnf_b[z_nnf + 1];
+ float e = 0;
+ for (int px = -r; px <= r; px++){
+ for (int py = -r; py <= r; py++){
+ const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
+ const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
+ for (int c = 0; c < channel; c++){
+ const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
+ e += diff * diff;
+ }
+ }
+ }
+ error[blockIdx.z * height * width + x * width + y] = e;
+}
+''', 'pairwise_patch_error')
diff --git a/PusaV1/diffsynth/extensions/FastBlend/data.py b/PusaV1/diffsynth/extensions/FastBlend/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcaddd77de9eaf208cd083dd522e5eaa6b58f783
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/data.py
@@ -0,0 +1,146 @@
+import imageio, os
+import numpy as np
+from PIL import Image
+
+
+def read_video(file_name):
+ reader = imageio.get_reader(file_name)
+ video = []
+ for frame in reader:
+ frame = np.array(frame)
+ video.append(frame)
+ reader.close()
+ return video
+
+
+def get_video_fps(file_name):
+ reader = imageio.get_reader(file_name)
+ fps = reader.get_meta_data()["fps"]
+ reader.close()
+ return fps
+
+
+def save_video(frames_path, video_path, num_frames, fps):
+ writer = imageio.get_writer(video_path, fps=fps, quality=9)
+ for i in range(num_frames):
+ frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
+ writer.append_data(frame)
+ writer.close()
+ return video_path
+
+
+class LowMemoryVideo:
+ def __init__(self, file_name):
+ self.reader = imageio.get_reader(file_name)
+
+ def __len__(self):
+ return self.reader.count_frames()
+
+ def __getitem__(self, item):
+ return np.array(self.reader.get_data(item))
+
+ def __del__(self):
+ self.reader.close()
+
+
+def split_file_name(file_name):
+ result = []
+ number = -1
+ for i in file_name:
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
+ if number == -1:
+ number = 0
+ number = number*10 + ord(i) - ord("0")
+ else:
+ if number != -1:
+ result.append(number)
+ number = -1
+ result.append(i)
+ if number != -1:
+ result.append(number)
+ result = tuple(result)
+ return result
+
+
+def search_for_images(folder):
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
+ file_list = [i[1] for i in sorted(file_list)]
+ file_list = [os.path.join(folder, i) for i in file_list]
+ return file_list
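+# Note: sorting with split_file_name as the key gives a natural order for numbered frames,
+# e.g. "2.png" precedes "10.png" because digit runs are compared as integers.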
+
+
+def read_images(folder):
+ file_list = search_for_images(folder)
+ frames = [np.array(Image.open(i)) for i in file_list]
+ return frames
+
+
+class LowMemoryImageFolder:
+ def __init__(self, folder, file_list=None):
+ if file_list is None:
+ self.file_list = search_for_images(folder)
+ else:
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
+
+ def __len__(self):
+ return len(self.file_list)
+
+ def __getitem__(self, item):
+ return np.array(Image.open(self.file_list[item]))
+
+ def __del__(self):
+ pass
+
+
+class VideoData:
+ def __init__(self, video_file, image_folder, **kwargs):
+ if video_file is not None:
+ self.data_type = "video"
+ self.data = LowMemoryVideo(video_file, **kwargs)
+ elif image_folder is not None:
+ self.data_type = "images"
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
+ else:
+ raise ValueError("Cannot open video or image folder")
+ self.length = None
+ self.height = None
+ self.width = None
+
+ def raw_data(self):
+ frames = []
+ for i in range(self.__len__()):
+ frames.append(self.__getitem__(i))
+ return frames
+
+ def set_length(self, length):
+ self.length = length
+
+ def set_shape(self, height, width):
+ self.height = height
+ self.width = width
+
+ def __len__(self):
+ if self.length is None:
+ return len(self.data)
+ else:
+ return self.length
+
+ def shape(self):
+ if self.height is not None and self.width is not None:
+ return self.height, self.width
+ else:
+ height, width, _ = self.__getitem__(0).shape
+ return height, width
+
+ def __getitem__(self, item):
+ frame = self.data.__getitem__(item)
+ height, width, _ = frame.shape
+ if self.height is not None and self.width is not None:
+ if self.height != height or self.width != width:
+ frame = Image.fromarray(frame).resize((self.width, self.height))
+ frame = np.array(frame)
+ return frame
+
+ def __del__(self):
+ pass
diff --git a/PusaV1/diffsynth/extensions/FastBlend/patch_match.py b/PusaV1/diffsynth/extensions/FastBlend/patch_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb1f7f9e31b4b2ad77ec58ba8d32361315e0390
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/patch_match.py
@@ -0,0 +1,298 @@
+from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
+import numpy as np
+import cupy as cp
+import cv2
+
+
+class PatchMatcher:
+ def __init__(
+ self, height, width, channel, minimum_patch_size,
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
+ random_search_steps=3, random_search_range=4,
+ use_mean_target_style=False, use_pairwise_patch_error=False,
+ tracking_window_size=0
+ ):
+ self.height = height
+ self.width = width
+ self.channel = channel
+ self.minimum_patch_size = minimum_patch_size
+ self.threads_per_block = threads_per_block
+ self.num_iter = num_iter
+ self.gpu_id = gpu_id
+ self.guide_weight = guide_weight
+ self.random_search_steps = random_search_steps
+ self.random_search_range = random_search_range
+ self.use_mean_target_style = use_mean_target_style
+ self.use_pairwise_patch_error = use_pairwise_patch_error
+ self.tracking_window_size = tracking_window_size
+
+ self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
+ self.pad_size = self.patch_size_list[0] // 2
+ self.grid = (
+ (height + threads_per_block - 1) // threads_per_block,
+ (width + threads_per_block - 1) // threads_per_block
+ )
+ self.block = (threads_per_block, threads_per_block)
+
+ def pad_image(self, image):
+ return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
+
+ def unpad_image(self, image):
+ return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
+
+ def apply_nnf_to_image(self, nnf, source):
+ batch_size = source.shape[0]
+ target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
+ remapping_kernel(
+ self.grid + (batch_size,),
+ self.block,
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
+ )
+ return target
+
+ def get_patch_error(self, source, nnf, target):
+ batch_size = source.shape[0]
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
+ patch_error_kernel(
+ self.grid + (batch_size,),
+ self.block,
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
+ )
+ return error
+
+ def get_pairwise_patch_error(self, source, nnf):
+ batch_size = source.shape[0]//2
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
+ source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
+ source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
+ pairwise_patch_error_kernel(
+ self.grid + (batch_size,),
+ self.block,
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
+ )
+ error = error.repeat(2, axis=0)
+ return error
+
+ def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
+ error_guide = self.get_patch_error(source_guide, nnf, target_guide)
+ if self.use_mean_target_style:
+ target_style = self.apply_nnf_to_image(nnf, source_style)
+ target_style = target_style.mean(axis=0, keepdims=True)
+ target_style = target_style.repeat(source_guide.shape[0], axis=0)
+ if self.use_pairwise_patch_error:
+ error_style = self.get_pairwise_patch_error(source_style, nnf)
+ else:
+ error_style = self.get_patch_error(source_style, nnf, target_style)
+ error = error_guide * self.guide_weight + error_style
+ return error
+
+ def clamp_bound(self, nnf):
+ nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
+ nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
+ return nnf
+
+ def random_step(self, nnf, r):
+ batch_size = nnf.shape[0]
+ step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
+ upd_nnf = self.clamp_bound(nnf + step)
+ return upd_nnf
+
+ def neighboor_step(self, nnf, d):
+ if d==0:
+ upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
+ upd_nnf[:, :, :, 0] += 1
+ elif d==1:
+ upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
+ upd_nnf[:, :, :, 1] += 1
+ elif d==2:
+ upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
+ upd_nnf[:, :, :, 0] -= 1
+ elif d==3:
+ upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
+ upd_nnf[:, :, :, 1] -= 1
+ upd_nnf = self.clamp_bound(upd_nnf)
+ return upd_nnf
+
+ def shift_nnf(self, nnf, d):
+ if d>0:
+ d = min(nnf.shape[0], d)
+ upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
+ else:
+ d = max(-nnf.shape[0], d)
+ upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
+ return upd_nnf
+
+ def track_step(self, nnf, d):
+ if self.use_pairwise_patch_error:
+ upd_nnf = cp.zeros_like(nnf)
+ upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
+ upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
+ else:
+ upd_nnf = self.shift_nnf(nnf, d)
+ return upd_nnf
+
+ def C(self, n, m):
+ # not used
+ c = 1
+ for i in range(1, n+1):
+ c *= i
+ for i in range(1, m+1):
+ c //= i
+ for i in range(1, n-m+1):
+ c //= i
+ return c
+
+ def bezier_step(self, nnf, r):
+ # not used
+ n = r * 2 - 1
+ upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
+ for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
+ if d>0:
+ ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
+ elif d<0:
+ ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
+ upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
+ upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
+ return upd_nnf
+
+ def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
+ upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
+ upd_idx = (upd_err < err)
+ nnf[upd_idx] = upd_nnf[upd_idx]
+ err[upd_idx] = upd_err[upd_idx]
+ return nnf, err
+
+ def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ for d in cp.random.permutation(4):
+ upd_nnf = self.neighboor_step(nnf, d)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ return nnf, err
+
+ def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ for i in range(self.random_search_steps):
+ upd_nnf = self.random_step(nnf, self.random_search_range)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ return nnf, err
+
+ def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ for d in range(1, self.tracking_window_size + 1):
+ upd_nnf = self.track_step(nnf, d)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ upd_nnf = self.track_step(nnf, -d)
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+ return nnf, err
+
+ def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
+ nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
+ nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
+ nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
+ return nnf, err
+
+ def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
+ with cp.cuda.Device(self.gpu_id):
+ source_guide = self.pad_image(source_guide)
+ target_guide = self.pad_image(target_guide)
+ source_style = self.pad_image(source_style)
+ for it in range(self.num_iter):
+ self.patch_size = self.patch_size_list[it]
+ target_style = self.apply_nnf_to_image(nnf, source_style)
+ err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
+ nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
+ target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
+ return nnf, target_style
+
+
+class PyramidPatchMatcher:
+ def __init__(
+ self, image_height, image_width, channel, minimum_patch_size,
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
+ use_mean_target_style=False, use_pairwise_patch_error=False,
+ tracking_window_size=0,
+ initialize="identity"
+ ):
+ maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
+ self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
+ self.pyramid_heights = []
+ self.pyramid_widths = []
+ self.patch_matchers = []
+ self.minimum_patch_size = minimum_patch_size
+ self.num_iter = num_iter
+ self.gpu_id = gpu_id
+ self.initialize = initialize
+ for level in range(self.pyramid_level):
+ height = image_height//(2**(self.pyramid_level - 1 - level))
+ width = image_width//(2**(self.pyramid_level - 1 - level))
+ self.pyramid_heights.append(height)
+ self.pyramid_widths.append(width)
+ self.patch_matchers.append(PatchMatcher(
+ height, width, channel, minimum_patch_size=minimum_patch_size,
+ threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
+ use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
+ tracking_window_size=tracking_window_size
+ ))
+
+ def resample_image(self, images, level):
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
+ images = images.get()
+ images_resample = []
+ for image in images:
+ image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
+ images_resample.append(image_resample)
+ images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
+ return images_resample
+
+ def initialize_nnf(self, batch_size):
+ if self.initialize == "random":
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
+ nnf = cp.stack([
+ cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
+ cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
+ ], axis=3)
+ elif self.initialize == "identity":
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
+ nnf = cp.stack([
+ cp.repeat(cp.arange(height), width).reshape(height, width),
+ cp.tile(cp.arange(width), height).reshape(height, width)
+ ], axis=2)
+ nnf = cp.stack([nnf] * batch_size)
+ else:
+ raise NotImplementedError()
+ return nnf
+
+ def update_nnf(self, nnf, level):
+ # upscale
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
+ nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
+ nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
+ # check if scale is 2
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
+ if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
+ nnf = nnf.get().astype(np.float32)
+ nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
+ nnf = cp.array(np.stack(nnf), dtype=cp.int32)
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
+ return nnf
+
+ def apply_nnf_to_image(self, nnf, image):
+ with cp.cuda.Device(self.gpu_id):
+ image = self.patch_matchers[-1].pad_image(image)
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
+ return image
+
+ def estimate_nnf(self, source_guide, target_guide, source_style):
+ with cp.cuda.Device(self.gpu_id):
+ if not isinstance(source_guide, cp.ndarray):
+ source_guide = cp.array(source_guide, dtype=cp.float32)
+ if not isinstance(target_guide, cp.ndarray):
+ target_guide = cp.array(target_guide, dtype=cp.float32)
+ if not isinstance(source_style, cp.ndarray):
+ source_style = cp.array(source_style, dtype=cp.float32)
+ for level in range(self.pyramid_level):
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
+ source_guide_ = self.resample_image(source_guide, level)
+ target_guide_ = self.resample_image(target_guide, level)
+ source_style_ = self.resample_image(source_style, level)
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
+ source_guide_, target_guide_, source_style_, nnf
+ )
+ return nnf.get(), target_style.get()
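+
+# Minimal usage sketch (illustrative; guides and styles are float arrays of shape (batch, H, W, 3)):
+#     matcher = PyramidPatchMatcher(image_height=H, image_width=W, channel=3, minimum_patch_size=5)
+#     nnf, remapped_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
+# `nnf` stores, for each target pixel, the (row, col) of its matched source patch centre, and
+# `remapped_style` is the source style warped onto the target through that field.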
diff --git a/PusaV1/diffsynth/extensions/FastBlend/runners/__init__.py b/PusaV1/diffsynth/extensions/FastBlend/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..078382729690d282436411661693ce22f3dcc033
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/runners/__init__.py
@@ -0,0 +1,4 @@
+from .accurate import AccurateModeRunner
+from .fast import FastModeRunner
+from .balanced import BalancedModeRunner
+from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diff --git a/PusaV1/diffsynth/extensions/FastBlend/runners/accurate.py b/PusaV1/diffsynth/extensions/FastBlend/runners/accurate.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4a47f1981ebc1ec9a034a814dfc1130955c2e1
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/runners/accurate.py
@@ -0,0 +1,35 @@
+from ..patch_match import PyramidPatchMatcher
+import os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class AccurateModeRunner:
+ def __init__(self):
+ pass
+
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ use_mean_target_style=True,
+ **ebsynth_config
+ )
+ # run
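+        # For each target frame, every guide frame in the window [target-window_size, target+window_size]
+        # is remapped onto it (in batches below) and the remapped results are averaged into the output frame.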
+ n = len(frames_style)
+ for target in tqdm(range(n), desc=desc):
+ l, r = max(target - window_size, 0), min(target + window_size + 1, n)
+ remapped_frames = []
+ for i in range(l, r, batch_size):
+ j = min(i + batch_size, r)
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+ target_guide = np.stack([frames_guide[target]] * (j - i))
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ remapped_frames.append(target_style)
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+ frame = frame.clip(0, 255).astype("uint8")
+ if save_path is not None:
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/FastBlend/runners/balanced.py b/PusaV1/diffsynth/extensions/FastBlend/runners/balanced.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c9a2bb7e438b49c89d0786e858ccf03302fab35
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/runners/balanced.py
@@ -0,0 +1,46 @@
+from ..patch_match import PyramidPatchMatcher
+import os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class BalancedModeRunner:
+ def __init__(self):
+ pass
+
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **ebsynth_config
+ )
+ # tasks
+ n = len(frames_style)
+ tasks = []
+ for target in range(n):
+ for source in range(target - window_size, target + window_size + 1):
+ if source >= 0 and source < n and source != target:
+ tasks.append((source, target))
+ # run
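+        # Each task remaps one neighbouring `source` frame onto its `target`; results are accumulated
+        # into a running average per target frame and written out once that target's window is complete.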
+ frames = [(None, 1) for i in range(n)]
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for (source, target), result in zip(tasks_batch, target_style):
+ frame, weight = frames[target]
+ if frame is None:
+ frame = frames_style[target]
+ frames[target] = (
+ frame * (weight / (weight + 1)) + result / (weight + 1),
+ weight + 1
+ )
+ if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
+ frame = frame.clip(0, 255).astype("uint8")
+ if save_path is not None:
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
+ frames[target] = (None, 1)
diff --git a/PusaV1/diffsynth/extensions/FastBlend/runners/fast.py b/PusaV1/diffsynth/extensions/FastBlend/runners/fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ba5731475ab875929b14181e0c22f4fd466c591
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/FastBlend/runners/fast.py
@@ -0,0 +1,141 @@
+from ..patch_match import PyramidPatchMatcher
+import functools, os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class TableManager:
+ def __init__(self):
+ pass
+
+ def task_list(self, n):
+ tasks = []
+ max_level = 1
+        while (1<<max_level)<=n:
+            max_level += 1
+        for i in range(n):
+            j = i
+            for level in range(max_level):
+                if i&(1<<level):
+                    continue
+                j |= 1<<level
+                if j>=n:
+                    break
+                meta_data = {
+                    "source": i,
+                    "target": j,
+                    "level": level + 1
+                }
+                tasks.append(meta_data)
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
+ return tasks
+
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
+ n = len(frames_guide)
+ tasks = self.task_list(n)
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for task, result in zip(tasks_batch, target_style):
+ target, level = task["target"], task["level"]
+ if len(remapping_table[target])==level:
+ remapping_table[target].append((result, 1))
+ else:
+ frame, weight = remapping_table[target][level]
+ remapping_table[target][level] = (
+ frame * (weight / (weight + 1)) + result / (weight + 1),
+ weight + 1
+ )
+ return remapping_table
+
+ def remapping_table_to_blending_table(self, table):
+ for i in range(len(table)):
+ for j in range(1, len(table[i])):
+ frame_1, weight_1 = table[i][j-1]
+ frame_2, weight_2 = table[i][j]
+ frame = (frame_1 + frame_2) / 2
+ weight = weight_1 + weight_2
+ table[i][j] = (frame, weight)
+ return table
+
+ def tree_query(self, leftbound, rightbound):
+ node_list = []
+ node_index = rightbound
+ while node_index>=leftbound:
+ node_level = 0
+            while (1<<node_level)&node_index and node_index - (1<<(node_level+1)) + 1 >= leftbound:
+                node_level += 1
+            node_list.append((node_index, node_level))
+            node_index -= 1<<node_level
+        return node_list
+ tasks = []
+ for m in range(index_style[0]):
+ tasks.append((index_style[0], m, index_style[0]))
+ task_group.append(tasks)
+ # middle frames
+ for l, r in zip(index_style[:-1], index_style[1:]):
+ tasks = []
+ for m in range(l, r):
+ tasks.append((l, m, r))
+ task_group.append(tasks)
+ # last frame
+ tasks = []
+ for m in range(index_style[-1], n):
+ tasks.append((index_style[-1], m, index_style[-1]))
+ task_group.append(tasks)
+ return task_group
+
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ use_mean_target_style=False,
+ use_pairwise_patch_error=True,
+ **ebsynth_config
+ )
+ # task
+ index_dict = self.get_index_dict(index_style)
+ task_group = self.get_task_group(index_style, len(frames_guide))
+ # run
+ for tasks in task_group:
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+ source_guide, target_guide, source_style = [], [], []
+ for l, m, r in tasks_batch:
+ # l -> m
+ source_guide.append(frames_guide[l])
+ target_guide.append(frames_guide[m])
+ source_style.append(frames_style[index_dict[l]])
+ # r -> m
+ source_guide.append(frames_guide[r])
+ target_guide.append(frames_guide[m])
+ source_style.append(frames_style[index_dict[r]])
+ source_guide = np.stack(source_guide)
+ target_guide = np.stack(target_guide)
+ source_style = np.stack(source_style)
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ if save_path is not None:
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
+ weight_l, weight_r = self.get_weight(l, m, r)
+ frame = frame_l * weight_l + frame_r * weight_r
+ frame = frame.clip(0, 255).astype("uint8")
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
+
+
+class InterpolationModeSingleFrameRunner:
+ def __init__(self):
+ pass
+
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
+ # check input
+ tracking_window_size = ebsynth_config["tracking_window_size"]
+ if tracking_window_size * 2 >= batch_size:
+            raise ValueError("batch_size should be larger than tracking_window_size * 2")
+ frame_style = frames_style[0]
+ frame_guide = frames_guide[index_style[0]]
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frame_style.shape[0],
+ image_width=frame_style.shape[1],
+ channel=3,
+ **ebsynth_config
+ )
+ # run
+ frame_id, n = 0, len(frames_guide)
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
+ if i + batch_size > n:
+ l, r = max(n - batch_size, 0), n
+ else:
+ l, r = i, i + batch_size
+ source_guide = np.stack([frame_guide] * (r-l))
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
+ source_style = np.stack([frame_style] * (r-l))
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for i, frame in zip(range(l, r), target_style):
+ if i==frame_id:
+ frame = frame.clip(0, 255).astype("uint8")
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
+ frame_id += 1
+ if r < n and r-frame_id <= tracking_window_size:
+ break
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..885dcf8f76ad77865054f0c033f8541ae08b1e04
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
@@ -0,0 +1 @@
+from .blip_pretrain import *
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b24c3c17fdeff6949c3692164362abb8d8d0989
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
@@ -0,0 +1,77 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+'''
+
+import warnings
+warnings.filterwarnings("ignore")
+
+import torch
+import os
+from urllib.parse import urlparse
+from timm.models.hub import download_cached_file
+from transformers import BertTokenizer
+from .vit import VisionTransformer, interpolate_pos_embed
+
+
+def default_bert():
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
+ model_path = os.path.join(project_root, 'models', 'QualityMetric')
+ return os.path.join(model_path, "bert-base-uncased")
+
+
+def init_tokenizer(bert_model_path):
+ tokenizer = BertTokenizer.from_pretrained(bert_model_path)
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+ return tokenizer
+
+
+def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
+
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
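+    # Note: `0 or drop_path_rate` resolves to the caller's drop_path_rate for the base model,
+    # while `0.1 or drop_path_rate` always evaluates to 0.1 for the large model (quirk kept from upstream BLIP).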
+ if vit=='base':
+ vision_width = 768
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=0 or drop_path_rate
+ )
+ elif vit=='large':
+ vision_width = 1024
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=0.1 or drop_path_rate
+ )
+ return visual_encoder, vision_width
+
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+def load_checkpoint(model,url_or_filename):
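+    """Load a BLIP checkpoint from a URL or local file, interpolate the ViT position embeddings to the target
+    model, drop keys whose shapes do not match, and load the remaining weights with strict=False."""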
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['model']
+
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
+ model.visual_encoder_m)
+ for key in model.state_dict().keys():
+ if key in state_dict.keys():
+ if state_dict[key].shape!=model.state_dict()[key].shape:
+ print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
+ del state_dict[key]
+
+ msg = model.load_state_dict(state_dict,strict=False)
+ print('load checkpoint from %s'%url_or_filename)
+ return model,msg
+
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba711e2776fd086190ca940248e022b4e083819a
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
@@ -0,0 +1,44 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+'''
+
+import transformers
+transformers.logging.set_verbosity_error()
+
+from torch import nn
+import os
+from .med import BertConfig, BertModel
+from .blip import create_vit, init_tokenizer
+
+class BLIP_Pretrain(nn.Module):
+ def __init__(self,
+ med_config = "med_config.json",
+ image_size = 224,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ queue_size = 57600,
+ momentum = 0.995,
+ bert_model_path = ""
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
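+        # queue_size and momentum are accepted for compatibility with the original BLIP pre-training
+        # configuration but are not used in this trimmed, encoder-only variant.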
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
+
+ self.tokenizer = init_tokenizer(bert_model_path)
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/med.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/med.py
new file mode 100644
index 0000000000000000000000000000000000000000..426f4689833d988526c6e26cd627f30975ab7606
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/med.py
@@ -0,0 +1,947 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+'''
+
+import math
+from typing import Tuple
+
+import torch
+from torch import Tensor, device, nn
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
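+        # reshape [batch, seq_len, all_head_size] -> [batch, num_heads, seq_len, head_size]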
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+                    logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ mode=mode,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder, the model needs to be
+    initialized with the :obj:`is_decoder` configuration argument set to :obj:`True`; for cross-attention it also needs
+    the :obj:`add_cross_attention` argument set to :obj:`True`, and an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
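+
+            For example, a 2-D padding mask of shape ``[batch_size, seq_length]`` is expanded to
+            ``[batch_size, 1, 1, seq_length]`` and converted so that attended positions contribute ``0.0`` and
+            masked positions contribute ``-10000.0`` when added to the raw attention scores.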
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction='mean',
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ if reduction=='none':
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cef7b650a95f56266775cf0f18b28bc0f74987ab
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
@@ -0,0 +1,301 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+ * Based on timm code base
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.attn_gradients = None
+ self.attention_map = None
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def forward(self, x, register_hook=False):
+ B, N, C = x.shape
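+        # qkv: (B, N, 3*C) -> (3, B, num_heads, N, head_dim), so q, k, v can be unpacked along dim 0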
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ if register_hook:
+ self.save_attention_map(attn)
+ attn.register_hook(self.save_attn_gradients)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ # if use_grad_checkpointing:
+ # self.attn = checkpoint_wrapper(self.attn)
+ # self.mlp = checkpoint_wrapper(self.mlp)
+
+ def forward(self, x, register_hook=False):
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
+ use_grad_checkpointing=False, ckpt_layer=0):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer: (nn.Module): normalization layer
+ """
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
+ )
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward(self, x, register_blk=-1):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + self.pos_embed[:,:x.size(1),:]
+ x = self.pos_drop(x)
+
+ for i,blk in enumerate(self.blocks):
+ x = blk(x, register_blk==i)
+ x = self.norm(x)
+
+ return x
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=''):
+ _load_weights(self, checkpoint_path, prefix)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+ """
+ import numpy as np
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and 'opt/target/embedding/kernel' in w:
+ prefix = 'opt/target/'
+
+ if hasattr(model.patch_embed, 'backbone'):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, 'stem')
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+ for r in range(3):
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
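+        # NOTE: resize_pos_embed is not defined in this module; it must be importable in scope
+        # (e.g. timm.models.vision_transformer.resize_pos_embed) for this resizing branch to run.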
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+ block.attn.qkv.weight.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ block.attn.qkv.bias.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+ for r in range(2):
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
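+    """Bicubically resize a checkpoint's patch position embeddings to the target encoder's grid size, keeping class/extra tokens unchanged."""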
+ # interpolate position embedding
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = visual_encoder.patch_embed.num_patches
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+
+ if orig_size!=new_size:
+ # class_token and dist_token are kept unchanged
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
+
+ return new_pos_embed
+ else:
+ return pos_embed_checkpoint
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/__init__.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcfb7c02b0ce2b6a2fbe345d87c31e0d1bb3a128
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/__init__.py
@@ -0,0 +1,148 @@
+from modelscope import snapshot_download
+from typing_extensions import Literal, TypeAlias
+import os
+from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
+from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
+from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
+from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
+from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
+from diffsynth.extensions.ImageQualityMetric.mps import MPScore
+
+
+preference_model_id: TypeAlias = Literal[
+ "ImageReward",
+ "Aesthetic",
+ "PickScore",
+ "CLIP",
+ "HPSv2",
+ "HPSv2.1",
+ "MPS",
+]
+model_dict = {
+ "ImageReward": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "ImageReward/ImageReward.safetensors",
+ "ImageReward/med_config.json",
+ "bert-base-uncased/config.json",
+ "bert-base-uncased/model.safetensors",
+ "bert-base-uncased/tokenizer.json",
+ "bert-base-uncased/tokenizer_config.json",
+ "bert-base-uncased/vocab.txt",
+ ],
+ "load_path": {
+ "imagereward": "ImageReward/ImageReward.safetensors",
+ "med_config": "ImageReward/med_config.json",
+ "bert_model_path": "bert-base-uncased",
+ },
+ "model_class": ImageRewardScore
+ },
+ "Aesthetic": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
+ "clip-vit-large-patch14/config.json",
+ "clip-vit-large-patch14/merges.txt",
+ "clip-vit-large-patch14/model.safetensors",
+ "clip-vit-large-patch14/preprocessor_config.json",
+ "clip-vit-large-patch14/special_tokens_map.json",
+ "clip-vit-large-patch14/tokenizer.json",
+ "clip-vit-large-patch14/tokenizer_config.json",
+ "clip-vit-large-patch14/vocab.json",
+ ],
+ "load_path": {
+ "aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
+ "clip-large": "clip-vit-large-patch14",
+ },
+ "model_class": AestheticScore
+ },
+ "PickScore": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "PickScore_v1/*",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
+ ],
+ "load_path": {
+ "pickscore": "PickScore_v1",
+ "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
+ },
+ "model_class": PickScore
+ },
+ "CLIP": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
+ "bpe_simple_vocab_16e6.txt.gz",
+ ],
+ "load_path": {
+ "open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
+ },
+ "model_class": CLIPScore
+ },
+ "HPSv2": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "HPS_v2/HPS_v2_compressed.safetensors",
+ "bpe_simple_vocab_16e6.txt.gz",
+ ],
+ "load_path": {
+ "hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
+ },
+ "model_class": HPScore_v2,
+ "extra_kwargs": {"model_version": "v2"}
+ },
+ "HPSv2.1": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "HPS_v2/HPS_v2.1_compressed.safetensors",
+ "bpe_simple_vocab_16e6.txt.gz",
+ ],
+ "load_path": {
+ "hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
+ },
+ "model_class": HPScore_v2,
+ "extra_kwargs": {"model_version": "v21"}
+ },
+ "MPS": {
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
+ "allow_file_pattern": [
+ "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
+ ],
+ "load_path": {
+ "mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
+ "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
+ },
+ "model_class": MPScore
+ },
+}
+
+
+def download_preference_model(model_name: preference_model_id, cache_dir="models"):
+ metadata = model_dict[model_name]
+ snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
+ load_path = metadata["load_path"]
+ load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
+ return load_path
+
+
+def load_preference_model(model_name: preference_model_id, device="cuda", path=None):
+ model_class = model_dict[model_name]["model_class"]
+ extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
+ preference_model = model_class(device=device, path=path, **extra_kwargs)
+ return preference_model
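+
+
+# Minimal usage sketch (illustrative; "example.jpg" and the prompt are placeholders):
+#   path = download_preference_model("HPSv2", cache_dir="models")
+#   scorer = load_preference_model("HPSv2", device="cuda", path=path)
+#   scores = scorer.score(["example.jpg"], "a photo of a cat")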
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/aesthetic.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/aesthetic.py
new file mode 100644
index 0000000000000000000000000000000000000000..13da98a1f45ca7eea0411e18c307cc5d0154488f
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/aesthetic.py
@@ -0,0 +1,148 @@
+from typing import List, Union
+from PIL import Image
+import torch
+from transformers import AutoProcessor, AutoModel
+from safetensors.torch import load_file
+from .config import MODEL_PATHS
+
+class MLP(torch.nn.Module):
+ def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
+ super().__init__()
+ self.input_size = input_size
+ self.xcol = xcol
+ self.ycol = ycol
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(self.input_size, 1024),
+ #torch.nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ #torch.nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ #torch.nn.ReLU(),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ #torch.nn.ReLU(),
+ torch.nn.Linear(16, 1),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.layers(x)
+
+ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
+ x = batch[self.xcol]
+ y = batch[self.ycol].reshape(-1, 1)
+ x_hat = self.layers(x)
+ loss = torch.nn.functional.mse_loss(x_hat, y)
+ return loss
+
+ def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
+ x = batch[self.xcol]
+ y = batch[self.ycol].reshape(-1, 1)
+ x_hat = self.layers(x)
+ loss = torch.nn.functional.mse_loss(x_hat, y)
+ return loss
+
+ def configure_optimizers(self) -> torch.optim.Optimizer:
+ return torch.optim.Adam(self.parameters(), lr=1e-3)
+
+
+class AestheticScore(torch.nn.Module):
+    def __init__(self, device: torch.device, path: dict = MODEL_PATHS):
+ super().__init__()
+ self.device = device
+ self.aes_model_path = path.get("aesthetic_predictor")
+ # Load the MLP model
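+        # (768 = dimensionality of the CLIP ViT-L/14 image embedding fed to the predictor)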
+ self.model = MLP(768)
+ try:
+ if self.aes_model_path.endswith(".safetensors"):
+ state_dict = load_file(self.aes_model_path)
+ else:
+ state_dict = torch.load(self.aes_model_path)
+ self.model.load_state_dict(state_dict)
+ except Exception as e:
+ raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
+
+ self.model.to(device)
+ self.model.eval()
+
+ # Load the CLIP model and processor
+ clip_model_name = path.get('clip-large')
+ self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
+ self.processor = AutoProcessor.from_pretrained(clip_model_name)
+
+ def _calculate_score(self, image: torch.Tensor) -> float:
+ """Calculate the aesthetic score for a single image.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+
+ Returns:
+ float: The aesthetic score.
+ """
+ with torch.no_grad():
+ # Get image embeddings
+ image_embs = self.model2.get_image_features(image)
+ image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+
+ # Compute score
+ score = self.model(image_embs).cpu().flatten().item()
+
+ return score
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
+ """Score the images based on their aesthetic quality.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ try:
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ pil_image = Image.open(images)
+ else:
+ pil_image = images
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ return [self._calculate_score(image_inputs["pixel_values"])]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_image in images:
+ if isinstance(one_image, str):
+ pil_image = Image.open(one_image)
+ elif isinstance(one_image, Image.Image):
+ pil_image = one_image
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ scores.append(self._calculate_score(image_inputs["pixel_values"]))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ except Exception as e:
+ raise RuntimeError(f"Error in scoring images: {e}")
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/clip.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70941e0a45db61be87e21c347e97ad8bb390fff
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/clip.py
@@ -0,0 +1,97 @@
+from typing import List, Union
+from PIL import Image
+import torch
+from .open_clip import create_model_and_transforms, get_tokenizer
+from .config import MODEL_PATHS
+
+class CLIPScore(torch.nn.Module):
+    def __init__(self, device: torch.device, path: dict = MODEL_PATHS):
+        """Initialize the CLIPScore with a model and tokenizer.
+
+        Args:
+            device (torch.device): The device to load the model on.
+            path (dict): Mapping from weight names to local file paths.
+        """
+        super().__init__()
+        self.device = device
+
+ # Create model and transforms
+ self.model, _, self.preprocess_val = create_model_and_transforms(
+ "ViT-H-14",
+ # "laion2B-s32B-b79K",
+ pretrained=path.get("open_clip"),
+ precision="amp",
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ force_custom_text=False,
+ force_patch_dropout=False,
+ force_image_size=None,
+ pretrained_image=False,
+ image_mean=None,
+ image_std=None,
+ light_augmentation=True,
+ aug_cfg={},
+ output_dict=True,
+ with_score_predictor=False,
+ with_region_predictor=False,
+ )
+
+ # Initialize tokenizer
+ self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
+ self.model = self.model.to(device)
+ self.model.eval()
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
+ """Calculate the CLIP score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+
+ Returns:
+ float: The CLIP score.
+ """
+ with torch.no_grad():
+ # Process the prompt
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
+
+ # Calculate the CLIP score
+ outputs = self.model(image, text)
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
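+            # The model returns L2-normalised features, so the dot product below is a cosine similarity.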
+ logits_per_image = image_features @ text_features.T
+ clip_score = torch.diagonal(logits_per_image).cpu().numpy()
+
+ return clip_score[0].item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of CLIP scores for the images.
+ """
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ return [self._calculate_score(image, prompt)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_images in images:
+ if isinstance(one_images, str):
+ image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ elif isinstance(one_images, Image.Image):
+ image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ scores.append(self._calculate_score(image, prompt))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/config.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..60faadcb1e5554c8f8f29a64fc55c3150d8a8bbe
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/config.py
@@ -0,0 +1,23 @@
+import os
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
+model_path = os.path.join(project_root, 'models', 'QualityMetric')
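+# Resolves to <project root>/models/QualityMetric relative to this file (three directories up).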
+
+
+def get_model_path(model_name):
+ return os.path.join(model_path, model_name)
+
+
+MODEL_PATHS = {
+ "aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
+ "open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
+ "hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
+ "hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
+ "imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
+ "med_config": get_model_path("ImageReward/med_config.json"),
+ "clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
+ "clip-large": get_model_path("clip-vit-large-patch14"),
+ "mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
+ "pickscore": get_model_path("PickScore_v1")
+}
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/hps.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/hps.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b266bd261a95676ba700d38c3a63b143bbbb40
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/hps.py
@@ -0,0 +1,118 @@
+from typing import List, Union
+from PIL import Image
+import torch
+from .open_clip import create_model_and_transforms, get_tokenizer
+from safetensors.torch import load_file
+from .config import MODEL_PATHS
+
+class HPScore_v2(torch.nn.Module):
+    def __init__(self, device: torch.device, path: dict = MODEL_PATHS, model_version: str = "v2"):
+        """Initialize the HPS scorer with a model and tokenizer.
+
+        Args:
+            device (torch.device): The device to load the model on.
+            path (dict): Mapping from weight names to local file paths.
+            model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
+        """
+        super().__init__()
+        self.device = device
+
+ if model_version == "v2":
+ safetensors_path = path.get("hpsv2")
+ elif model_version == "v21":
+ safetensors_path = path.get("hpsv2.1")
+ else:
+ raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
+
+ # Create model and transforms
+ model, _, self.preprocess_val = create_model_and_transforms(
+ "ViT-H-14",
+ # "laion2B-s32B-b79K",
+ pretrained=path.get("open_clip"),
+ precision="amp",
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ force_custom_text=False,
+ force_patch_dropout=False,
+ force_image_size=None,
+ pretrained_image=False,
+ image_mean=None,
+ image_std=None,
+ light_augmentation=True,
+ aug_cfg={},
+ output_dict=True,
+ with_score_predictor=False,
+ with_region_predictor=False,
+ )
+
+ # Load model weights
+ try:
+ state_dict = load_file(safetensors_path)
+ model.load_state_dict(state_dict)
+ except Exception as e:
+ raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
+
+ # Initialize tokenizer and model
+ self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
+ model = model.to(device)
+ model.eval()
+ self.model = model
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
+ """Calculate the HPS score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+
+ Returns:
+ float: The HPS score.
+ """
+ with torch.no_grad():
+ # Process the prompt
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
+
+ # Calculate the HPS score
+ outputs = self.model(image, text)
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
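+            # The HPS score is the cosine similarity between the (L2-normalised) image and prompt embeddings.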
+ logits_per_image = image_features @ text_features.T
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
+
+ return hps_score[0].item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of HPS scores for the images.
+ """
+ try:
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ return [self._calculate_score(image, prompt)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_images in images:
+ if isinstance(one_images, str):
+ image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
+ elif isinstance(one_images, Image.Image):
+ image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ scores.append(self._calculate_score(image, prompt))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ except Exception as e:
+ raise RuntimeError(f"Error in scoring images: {e}")
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/imagereward.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/imagereward.py
new file mode 100644
index 0000000000000000000000000000000000000000..27607904b23fa1691c5a6966eb4030cd813567b0
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/imagereward.py
@@ -0,0 +1,212 @@
+import os
+import torch
+from PIL import Image
+from typing import List, Union
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from .BLIP.blip_pretrain import BLIP_Pretrain
+from torchvision.transforms import InterpolationMode
+from safetensors.torch import load_file
+from .config import MODEL_PATHS
+BICUBIC = InterpolationMode.BICUBIC
+
+def _convert_image_to_rgb(image):
+ return image.convert("RGB")
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=BICUBIC),
+ CenterCrop(n_px),
+ _convert_image_to_rgb,
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+class MLP(torch.nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(self.input_size, 1024),
+ #nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ #nn.ReLU(),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ #nn.ReLU(),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ #nn.ReLU(),
+ torch.nn.Linear(16, 1)
+ )
+
+ # initial MLP param
+ for name, param in self.layers.named_parameters():
+ if 'weight' in name:
+ torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
+ if 'bias' in name:
+ torch.nn.init.constant_(param, val=0)
+
+ def forward(self, input):
+ return self.layers(input)
+
+class ImageReward(torch.nn.Module):
+ def __init__(self, med_config, device='cpu', bert_model_path=""):
+ super().__init__()
+ self.device = device
+
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
+ self.preprocess = _transform(224)
+ self.mlp = MLP(768)
+
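+        # Statistics used to standardise the raw MLP output: score = (raw - mean) / std.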
+ self.mean = 0.16717362830052426
+ self.std = 1.0333394966054072
+
+ def score_grad(self, prompt_ids, prompt_attention_mask, image):
+ """Calculate the score with gradient for a single image and prompt.
+
+ Args:
+ prompt_ids (torch.Tensor): Tokenized prompt IDs.
+ prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
+ image (torch.Tensor): The processed image tensor.
+
+ Returns:
+ torch.Tensor: The reward score.
+ """
+ image_embeds = self.blip.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+ text_output = self.blip.text_encoder(
+ prompt_ids,
+ attention_mask=prompt_attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ txt_features = text_output.last_hidden_state[:, 0, :]
+ rewards = self.mlp(txt_features)
+ rewards = (rewards - self.mean) / self.std
+ return rewards
+
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+            images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+            prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ pil_image = Image.open(images)
+ else:
+ pil_image = images
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+ return [self._calculate_score(prompt, image).item()]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_image in images:
+ if isinstance(one_image, str):
+ pil_image = Image.open(one_image)
+ elif isinstance(one_image, Image.Image):
+ pil_image = one_image
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+ scores.append(self._calculate_score(prompt, image).item())
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+
+ def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
+ """Calculate the score for a single image and prompt.
+
+ Args:
+ prompt (str): The prompt text.
+ image (torch.Tensor): The processed image tensor.
+
+ Returns:
+ torch.Tensor: The reward score.
+ """
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
+ image_embeds = self.blip.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+ text_output = self.blip.text_encoder(
+ text_input.input_ids,
+ attention_mask=text_input.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ txt_features = text_output.last_hidden_state[:, 0, :].float()
+ rewards = self.mlp(txt_features)
+ rewards = (rewards - self.mean) / self.std
+ return rewards
+
+ def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
+ """Rank the images based on the prompt.
+
+ Args:
+ prompt (str): The prompt text.
+ generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
+
+ Returns:
+ tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
+ """
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
+ txt_set = []
+ for generation in generations_list:
+ if isinstance(generation, str):
+ pil_image = Image.open(generation)
+ elif isinstance(generation, Image.Image):
+ pil_image = generation
+ else:
+ raise TypeError("The type of parameter generations_list is illegal.")
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+ image_embeds = self.blip.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+ text_output = self.blip.text_encoder(
+ text_input.input_ids,
+ attention_mask=text_input.attention_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+ txt_set.append(text_output.last_hidden_state[:, 0, :])
+ txt_features = torch.cat(txt_set, 0).float()
+ rewards = self.mlp(txt_features)
+ rewards = (rewards - self.mean) / self.std
+ rewards = torch.squeeze(rewards)
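+        # Sort rewards in descending order; `indices[i]` below is the 1-based rank of image i (1 = best).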
+ _, rank = torch.sort(rewards, dim=0, descending=True)
+ _, indices = torch.sort(rank, dim=0)
+ indices = indices + 1
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
+
+
+class ImageRewardScore(torch.nn.Module):
+    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
+ super().__init__()
+ self.device = device if isinstance(device, torch.device) else torch.device(device)
+ model_path = path.get("imagereward")
+ med_config = path.get("med_config")
+ state_dict = load_file(model_path)
+ self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
+ self.model.load_state_dict(state_dict, strict=False)
+ self.model.eval()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ return self.model.score(images, prompt)
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/mps.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/mps.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15aad4b81026a743911512bcc569520182b31c5
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/mps.py
@@ -0,0 +1,129 @@
+from typing import Union, List
+
+import torch
+from torch import einsum
+from PIL import Image
+from safetensors.torch import load_file
+from transformers import CLIPImageProcessor, AutoTokenizer
+
+from .trainer.models import clip_model
+from .config import MODEL_PATHS
+
+class MPScore(torch.nn.Module):
+    def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS, condition: str = 'overall'):
+        """Initialize the MPS scorer with a processor, tokenizer, and model.
+
+        Args:
+            device (Union[str, torch.device]): The device to load the model on.
+            path (dict): Mapping from weight names to local file paths.
+            condition (str): Scoring aspect: 'overall', 'aesthetics', 'quality', or 'semantic'.
+        """
+        super().__init__()
+        self.device = device
+ processor_name_or_path = path.get("clip")
+ self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
+ self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
+ self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
+ state_dict = load_file(path.get("mps"))
+ self.model.load_state_dict(state_dict, strict=False)
+ self.model.to(device)
+ self.condition = condition
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
+ """Calculate the reward score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+
+ Returns:
+ float: The reward score.
+ """
+ def _tokenize(caption):
+ input_ids = self.tokenizer(
+ caption,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt"
+ ).input_ids
+ return input_ids
+
+ text_input = _tokenize(prompt).to(self.device)
+ if self.condition == 'overall':
+ condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
+ elif self.condition == 'aesthetics':
+ condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
+ elif self.condition == 'quality':
+ condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
+ elif self.condition == 'semantic':
+ condition_prompt = 'quantity, attributes, position, number, location'
+ else:
+ raise ValueError(
+ f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
+ condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
+
+ with torch.no_grad():
+ text_f, text_features = self.model.model.get_text_features(text_input)
+
+ image_f = self.model.model.get_image_features(image.half())
+ condition_f, _ = self.model.model.get_text_features(condition_batch)
+
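+            # Mask out text tokens whose (max-normalised) similarity to the condition prompt is below 0.3,
+            # then fuse image and text features via cross-attention and keep the first output token.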
+ sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
+ sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
+ sim_text_condition = sim_text_condition / sim_text_condition.max()
+ mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
+ mask = mask.repeat(1, image_f.shape[1], 1)
+ image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
+
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+ image_score = self.model.logit_scale.exp() * text_features @ image_features.T
+
+ return image_score[0].cpu().numpy().item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+
+ Returns:
+ List[float]: List of reward scores for the images.
+ """
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
+ else:
+ image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
+ return [self._calculate_score(image, prompt)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_images in images:
+ if isinstance(one_images, str):
+ image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
+ elif isinstance(one_images, Image.Image):
+ image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ scores.append(self._calculate_score(image, prompt))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1560db0b543b7b8857f39d7de435c834380666ab
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
@@ -0,0 +1,14 @@
+from .coca_model import CoCa
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
+from .factory import list_models, add_model_config, get_model_config, load_checkpoint
+from .loss import ClipLoss, DistillClipLoss, CoCaLoss
+from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
+from .openai import load_openai_model, list_openai_models
+from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
+from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
+from .tokenizer import SimpleTokenizer
+from .transform import image_transform, AugmentationCfg
+from .utils import freeze_batch_norm_2d
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
@@ -0,0 +1,458 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+import numpy as np
+from dataclasses import dataclass
+
+from .transformer import (
+ LayerNormFp32,
+ LayerNorm,
+ QuickGELU,
+ MultimodalTransformer,
+)
+from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
+
+try:
+ from transformers import (
+ BeamSearchScorer,
+ LogitsProcessorList,
+ TopPLogitsWarper,
+ TopKLogitsWarper,
+ RepetitionPenaltyLogitsProcessor,
+ MinLengthLogitsProcessor,
+ MaxLengthCriteria,
+ StoppingCriteriaList
+ )
+
+ GENERATION_TYPES = {
+ "top_k": TopKLogitsWarper,
+ "top_p": TopPLogitsWarper,
+ "beam_search": "beam_search"
+ }
+ _has_transformers = True
+except ImportError as e:
+ GENERATION_TYPES = {
+ "top_k": None,
+ "top_p": None,
+ "beam_search": "beam_search"
+ }
+ _has_transformers = False
+
+
+@dataclass
+class MultimodalCfg(CLIPTextCfg):
+ mlp_ratio: int = 4
+ dim_head: int = 64
+ heads: int = 8
+ n_queries: int = 256
+ attn_pooler_heads: int = 8
+
+
+def _build_text_decoder_tower(
+ embed_dim,
+ multimodal_cfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+):
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+ norm_layer = (
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+ )
+
+ decoder = MultimodalTransformer(
+ context_length=multimodal_cfg.context_length,
+ width=multimodal_cfg.width,
+ heads=multimodal_cfg.heads,
+ layers=multimodal_cfg.layers,
+ ls_init_value=multimodal_cfg.ls_init_value,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ return decoder
+
+
+class CoCa(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ multimodal_cfg: MultimodalCfg,
+ text_cfg: CLIPTextCfg,
+ vision_cfg: CLIPVisionCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ pad_id: int = 0,
+ ):
+ super().__init__()
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
+
+ self.text = _build_text_tower(
+ embed_dim=embed_dim,
+ text_cfg=text_cfg,
+ quick_gelu=quick_gelu,
+ cast_dtype=cast_dtype,
+ )
+
+ vocab_size = (
+ text_cfg.vocab_size # for hf models
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
+ else text_cfg.vocab_size
+ )
+
+ self.visual = _build_vision_tower(
+ embed_dim=embed_dim,
+ vision_cfg=vision_cfg,
+ quick_gelu=quick_gelu,
+ cast_dtype=cast_dtype,
+ )
+
+ self.text_decoder = _build_text_decoder_tower(
+ vocab_size,
+ multimodal_cfg=multimodal_cfg,
+ quick_gelu=quick_gelu,
+ cast_dtype=cast_dtype,
+ )
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.pad_id = pad_id
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.text.set_grad_checkpointing(enable)
+ self.text_decoder.set_grad_checkpointing(enable)
+
+ def _encode_image(self, images, normalize=True):
+ image_latent, tokens_embs = self.visual(images)
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
+ return image_latent, tokens_embs
+
+ def _encode_text(self, text, normalize=True, embed_cls=True):
+ text = text[:, :-1] if embed_cls else text # make space for CLS token
+ text_latent, token_emb = self.text(text)
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
+ return text_latent, token_emb
+
+ def encode_image(self, images, normalize=True):
+ image_latent, _ = self._encode_image(images, normalize=normalize)
+ return image_latent
+
+ def encode_text(self, text, normalize=True, embed_cls=True):
+ text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
+ return text_latent
+
+ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
+ text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
+ if image_latent is None or image_embs is None:
+ image_latent, image_embs = self._encode_image(image)
+
+ # TODO: add assertion to avoid bugs?
+ labels = text[:, -token_embs.shape[1]:]
+
+ logits = self.text_decoder(image_embs, token_embs)
+ return {
+ "image_features": image_latent,
+ "text_features": text_latent,
+ "logits": logits,
+ "labels": labels,
+ "logit_scale": self.logit_scale.exp()
+ }
+
+ def generate(
+ self,
+ image,
+ text=None,
+ seq_len=30,
+ max_seq_len=77,
+ temperature=1.,
+ generation_type="beam_search",
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
+ top_k=1, # keeps the top_k most probable tokens
+ pad_token_id=None,
+ eos_token_id=None,
+ sot_token_id=None,
+ num_beams=6,
+ num_beam_groups=3,
+ min_seq_len=5,
+ stopping_criteria=None,
+ repetition_penalty=1.0,
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
+ ):
+ # taking many ideas and components from HuggingFace GenerationMixin
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
+
+ with torch.no_grad():
+ sot_token_id = 49406 if sot_token_id is None else sot_token_id
+ eos_token_id = 49407 if eos_token_id is None else eos_token_id
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
+ logit_processor = LogitsProcessorList(
+ [
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
+ ]
+ )
+
+ if stopping_criteria is None:
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
+
+ stopping_criteria = StoppingCriteriaList(
+ stopping_criteria
+ )
+
+ device = image.device
+
+ if generation_type == "beam_search":
+ output = self._generate_beamsearch(
+ image_inputs = image,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ sot_token_id=sot_token_id,
+ num_beams=num_beams,
+ num_beam_groups=num_beam_groups,
+ min_seq_len=min_seq_len,
+ stopping_criteria=stopping_criteria,
+ logit_processor=logit_processor,
+ )
+ if fixed_output_length and output.shape[1] < seq_len:
+ return torch.cat(
+ (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
+ dim=1
+ )
+ return output
+
+ elif generation_type == "top_p":
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
+ elif generation_type == "top_k":
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
+ else:
+ raise ValueError(
+ f"generation_type has to be one of "
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
+ )
+
+ image_latent, image_embs = self._encode_image(image)
+
+ if text is None:
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
+
+ was_training = self.training
+ num_dims = len(text.shape)
+
+ if num_dims == 1:
+ text = text[None, :]
+
+ cur_len = text.shape[1]
+ self.eval()
+ out = text
+
+ while True:
+ x = out[:, -max_seq_len:]
+ cur_len = x.shape[1]
+ logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
+
+ if mask.all():
+ if not fixed_output_length:
+ break
+ else:
+ logits = logits[~mask, :]
+ filtered_logits = logit_processor(x[~mask, :], logits)
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+ if (cur_len + 1 == seq_len):
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
+ else:
+ sample[~mask, :] = torch.multinomial(probs, 1)
+
+ out = torch.cat((out, sample), dim=-1)
+
+ cur_len += 1
+
+ if stopping_criteria(out, None):
+ break
+
+ if num_dims == 1:
+ out = out.squeeze(0)
+
+ self.train(was_training)
+ return out
+
+ def _generate_beamsearch(
+ self,
+ image_inputs,
+ pad_token_id=None,
+ eos_token_id=None,
+ sot_token_id=None,
+ num_beams=6,
+ num_beam_groups=3,
+ min_seq_len=5,
+ stopping_criteria=None,
+ logit_processor=None,
+ logit_warper=None,
+ ):
+ device = image_inputs.device
+ batch_size = image_inputs.shape[0]
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
+ image_latent, image_embs = self._encode_image(image_inputs)
+
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
+ input_ids = input_ids * sot_token_id
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=num_beams,
+ device=device,
+ num_beam_groups=num_beam_groups,
+ )
+ # instantiate logits processors
+ logits_processor = (
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
+ if logit_processor is None
+ else logit_processor
+ )
+
+ batch_size = len(beam_scorer._beam_hyps)
+ num_beams = beam_scorer.num_beams
+ num_beam_groups = beam_scorer.num_beam_groups
+ num_sub_beams = num_beams // num_beam_groups
+ batch_beam_size, cur_len = input_ids.shape
+ beam_indices = None
+
+ if num_beams * batch_size != batch_beam_size:
+ raise ValueError(
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+ )
+
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
+ # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
+ # the same group don't produce same tokens everytime.
+ beam_scores[:, ::num_sub_beams] = 0
+ beam_scores = beam_scores.view((batch_size * num_beams,))
+
+ while True:
+
+ # predicted tokens in cur_len step
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
+
+ # indices which will form the beams in the next time step
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
+
+ # do one decoder step on all beams of all sentences in batch
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
+ outputs = self(
+ model_inputs['images'],
+ model_inputs['text'],
+ embed_cls=False,
+ image_latent=image_latent,
+ image_embs=image_embs
+ )
+
+ for beam_group_idx in range(num_beam_groups):
+ group_start_idx = beam_group_idx * num_sub_beams
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
+ group_size = group_end_idx - group_start_idx
+
+ # indices of beams of current group among all sentences in batch
+ batch_group_indices = []
+
+ for batch_idx in range(batch_size):
+ batch_group_indices.extend(
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
+ )
+ group_input_ids = input_ids[batch_group_indices]
+
+                # select outputs of beams of current group only
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
+ vocab_size = next_token_logits.shape[-1]
+
+ next_token_scores_processed = logits_processor(
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
+ )
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
+
+ # reshape for beam search
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
+
+ next_token_scores, next_tokens = torch.topk(
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
+ )
+
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
+ next_tokens = next_tokens % vocab_size
+
+ # stateless
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+ beam_outputs = beam_scorer.process(
+ group_input_ids,
+ next_token_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ beam_indices=process_beam_indices,
+ )
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
+ beam_idx = beam_outputs["next_beam_indices"]
+
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
+
+ # (beam_idx // group_size) -> batch_idx
+ # (beam_idx % group_size) -> offset of idx inside the group
+ reordering_indices[batch_group_indices] = (
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
+ )
+
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
+
+ # increase cur_len
+ cur_len = cur_len + 1
+ if beam_scorer.is_done or stopping_criteria(input_ids, None):
+ break
+
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
+ sequence_outputs = beam_scorer.finalize(
+ input_ids,
+ beam_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ max_length=stopping_criteria.max_length,
+ beam_indices=final_beam_indices,
+ )
+ return sequence_outputs['sequences']
+
+
+def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ else:
+ position_ids = None
+ return {
+ "text": input_ids,
+ "images": image_inputs,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ }
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/constants.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
@@ -0,0 +1,2 @@
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd51a1bb6b69e0e69147c8b7cb8d7bd4899b349
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
@@ -0,0 +1,433 @@
+import json
+import logging
+import os
+import pathlib
+import re
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
+ resize_pos_embed, get_cast_dtype
+from .coca_model import CoCa
+from .loss import ClipLoss, DistillClipLoss, CoCaLoss
+from .openai import load_openai_model
+from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
+from .transform import image_transform, AugmentationCfg
+from .tokenizer import HFTokenizer, SimpleTokenizer
+
+
+HF_HUB_PREFIX = 'hf-hub:'
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
+_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = ('.json',)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f'*{ext}'))
+
+ for cf in config_files:
+ with open(cf, 'r') as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def list_models():
+ """ enumerate available model architectures based on config files """
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """ add model config path or file and update registry """
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
+
+
+def get_model_config(model_name):
+ if model_name in _MODEL_CONFIGS:
+ return deepcopy(_MODEL_CONFIGS[model_name])
+ else:
+ return None
+
+
+def get_tokenizer(model_name, open_clip_bpe_path=None):
+ if model_name.startswith(HF_HUB_PREFIX):
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
+ else:
+ config = get_model_config(model_name)
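+        # Use the HF tokenizer when the model config names one; otherwise fall back to the
+        # bundled SimpleTokenizer, loading its BPE vocabulary from `open_clip_bpe_path`.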
+ tokenizer = HFTokenizer(
+ config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
+ return tokenizer
+
+
+def load_state_dict(checkpoint_path: str, map_location='cpu'):
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+ if next(iter(state_dict.items()))[0].startswith('module'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+ return state_dict
+
+
+def load_checkpoint(model, checkpoint_path, strict=True):
+ state_dict = load_state_dict(checkpoint_path)
+ # detect old format and make compatible with new format
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
+ state_dict = convert_to_custom_text_state_dict(state_dict)
+ resize_pos_embed(state_dict, model)
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
+ return incompatible_keys
+
+
+def create_model(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ pretrained_image: bool = False,
+ pretrained_hf: bool = True,
+ cache_dir: Optional[str] = None,
+ output_dict: Optional[bool] = None,
+ require_pretrained: bool = False,
+):
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
+ if has_hf_hub_prefix:
+ model_id = model_name[len(HF_HUB_PREFIX):]
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
+
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+ pretrained_cfg = config['preprocess_cfg']
+ model_cfg = config['model_cfg']
+ else:
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
+ checkpoint_path = None
+ pretrained_cfg = {}
+ model_cfg = None
+
+ if isinstance(device, str):
+ device = torch.device(device)
+
+ if pretrained and pretrained.lower() == 'openai':
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
+ model = load_openai_model(
+ model_name,
+ precision=precision,
+ device=device,
+ jit=jit,
+ cache_dir=cache_dir,
+ )
+
+ # to always output dict even if it is clip
+ if output_dict and hasattr(model, "output_dict"):
+ model.output_dict = True
+ else:
+ model_cfg = model_cfg or get_model_config(model_name)
+ if model_cfg is not None:
+ logging.info(f'Loaded {model_name} model config.')
+ else:
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
+ raise RuntimeError(f'Model config for {model_name} not found.')
+
+ if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models
+ model_cfg["quick_gelu"] = True
+
+ if force_patch_dropout is not None:
+ # override the default patch dropout value
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
+
+ if force_image_size is not None:
+ # override model config's image size
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
+
+ if pretrained_image:
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
+ # pretrained weight loading for timm models set via vision_cfg
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
+ else:
+ assert False, 'pretrained image towers currently only supported for timm models'
+
+ cast_dtype = get_cast_dtype(precision)
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
+
+ if custom_text:
+ if is_hf_model:
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
+ if "coca" in model_name:
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
+ else:
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+ else:
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+
+ pretrained_loaded = False
+ if pretrained:
+ checkpoint_path = ''
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
+ if pretrained_cfg:
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
+ elif os.path.exists(pretrained):
+ checkpoint_path = pretrained
+
+ if checkpoint_path:
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+ load_checkpoint(model, checkpoint_path)
+ else:
+ error_str = (
+                    f'Pretrained weights ({pretrained}) not found for model {model_name}. '
+                    f'Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
+ logging.warning(error_str)
+ raise RuntimeError(error_str)
+ pretrained_loaded = True
+ elif has_hf_hub_prefix:
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+ load_checkpoint(model, checkpoint_path)
+ pretrained_loaded = True
+
+ if require_pretrained and not pretrained_loaded:
+ # callers of create_model_from_pretrained always expect pretrained weights
+ raise RuntimeError(
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
+
+ model.to(device=device)
+ if precision in ("fp16", "bf16"):
+ convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
+
+ # set image / mean metadata from pretrained_cfg if available, or use default
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
+
+ # to always output dict even if it is clip
+ if output_dict and hasattr(model, "output_dict"):
+ model.output_dict = True
+
+ if jit:
+ model = torch.jit.script(model)
+
+ return model
+
+
+def create_loss(args):
+ if args.distill:
+ return DistillClipLoss(
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ )
+ elif "coca" in args.model.lower():
+ return CoCaLoss(
+ caption_loss_weight=args.coca_caption_loss_weight,
+ clip_loss_weight=args.coca_contrastive_loss_weight,
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ )
+ return ClipLoss(
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ )
+
+class MLP(torch.nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+ self.layers = torch.nn.Sequential(
+ torch.nn.Linear(self.input_size, 1024),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(1024, 128),
+ torch.nn.Dropout(0.2),
+ torch.nn.Linear(128, 64),
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(64, 16),
+ torch.nn.Linear(16, 1)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+# class semantic_head(torch.nn.Module):
+# def __init__(self, input_size):
+# super().__init__()
+# self.input_size = input_size # for ViT-L-14 is 1024
+# self.seg_head = torch.nn.Sequential(
+# torch.nn.Linear(input_size, 128),
+# torch.nn.Dropout(0.2),
+# torch.nn.Linear(128, 64),
+# torch.nn.Dropout(0.1),
+# torch.nn.Linear(64, 16),
+# torch.nn.Linear(16, 1),
+# )
+# self.sigmoid = torch.nn.Sigmoid()
+
+# def forward(self, x):
+# return self.sigmoid(self.seg_head(x))
+
+def create_model_and_transforms(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ pretrained_image: bool = False,
+ pretrained_hf: bool = True,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+ cache_dir: Optional[str] = None,
+ light_augmentation = False,
+ output_dict: Optional[bool] = None,
+ with_score_predictor: bool = False,
+ with_region_predictor: bool = False
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_text=force_custom_text,
+ force_patch_dropout=force_patch_dropout,
+ force_image_size=force_image_size,
+ pretrained_image=pretrained_image,
+ pretrained_hf=pretrained_hf,
+ cache_dir=cache_dir,
+ output_dict=output_dict,
+ )
+
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+ image_std = image_std or getattr(model.visual, 'image_std', None)
+
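+    # Optional auxiliary heads: an MLP scorer over the projected image embedding and a
+    # per-token linear "region" predictor over the pre-projection token features.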
+ if with_score_predictor:
+ model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
+
+ if with_region_predictor:
+ # model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
+ model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
+ # preprocess_train = image_transform_region(
+ # model.visual.image_size,
+ # is_train=True,
+ # mean=image_mean,
+ # std=image_std
+ # )
+ # preprocess_val = image_transform_region(
+ # model.visual.image_size,
+ # is_train=False,
+ # mean=image_mean,
+ # std=image_std
+ # )
+
+ if light_augmentation:
+ preprocess_val = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std,
+ resize_longest_max=True,
+ )
+ preprocess_train = preprocess_val
+ else:
+ preprocess_train = image_transform(
+ model.visual.image_size,
+ is_train=True,
+ mean=image_mean,
+ std=image_std
+ )
+ preprocess_val = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std
+ )
+
+ return model, preprocess_train, preprocess_val
+
+
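+# Illustrative usage sketch (comment only, not part of the module API): the
+# model/pretrained tags below are taken from the pretrained registry shipped
+# with this extension and may need adjusting.
+#
+#   model, preprocess_train, preprocess_val = create_model_and_transforms(
+#       'ViT-H-14',
+#       pretrained='laion2b_s32b_b79k',
+#       device='cuda',
+#       with_score_predictor=True,  # attaches the MLP score head defined above
+#   )
+#   image = preprocess_val(pil_image).unsqueeze(0).to('cuda')  # pil_image: PIL.Image
+#   score = model.score_predictor(model.encode_image(image))
+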
+def create_model_from_pretrained(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ return_transform: bool = True,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ cache_dir: Optional[str] = None,
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_text=force_custom_text,
+ force_image_size=force_image_size,
+ cache_dir=cache_dir,
+ require_pretrained=True,
+ )
+
+ if not return_transform:
+ return model
+
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+ image_std = image_std or getattr(model.visual, 'image_std', None)
+ preprocess = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std,
+ )
+
+ return model, preprocess
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e236222bafce0358445ea16953ca0b2d5a84758a
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
@@ -0,0 +1,45 @@
+# HF architecture dict:
+arch_dict = {
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
+ "roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings"
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
+ "xlm-roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings"
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
+ "mt5": {
+ "config_names": {
+ # unlimited seqlen
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
+ "context_length": "",
+ "vocab_size": "vocab_size",
+ "width": "d_model",
+ "heads": "num_heads",
+ "layers": "num_layers",
+ "layer_attr": "block",
+ "token_embeddings_attr": "embed_tokens"
+ },
+ "pooler": "mean_pooler",
+ },
+}
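+
+# Illustrative sketch of how this mapping is typically consumed (assumed usage,
+# shown here as a comment): given a Hugging Face config object, text-tower
+# hyperparameters are read via the per-architecture attribute names, e.g.
+#
+#   from transformers import AutoConfig
+#   cfg = AutoConfig.from_pretrained("roberta-base")
+#   width = getattr(cfg, arch_dict["roberta"]["config_names"]["width"])  # cfg.hidden_size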
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbccc812757bf10b122ff14096980e0e38d1d221
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
@@ -0,0 +1,176 @@
+""" huggingface model adapter
+
+Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
+"""
+
+import re
+
+import torch
+import torch.nn as nn
+from torch import TensorType
+
+try:
+ import transformers
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
+ BaseModelOutputWithPoolingAndCrossAttentions
+except ImportError as e:
+ transformers = None
+
+
+ class BaseModelOutput:
+ pass
+
+
+ class PretrainedConfig:
+ pass
+
+from .hf_configs import arch_dict
+
+
+# utils
+def _camel2snake(s):
+    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
+
+    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
+ # calculated ground-truth and cache if enabled
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+ return labels
+
+ def get_logits(self, image_features, text_features, logit_scale):
+ if self.world_size > 1:
+ all_image_features, all_text_features = gather_features(
+ image_features, text_features,
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+ if self.local_loss:
+ logits_per_image = logit_scale * image_features @ all_text_features.T
+ logits_per_text = logit_scale * text_features @ all_image_features.T
+ else:
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
+ logits_per_text = logits_per_image.T
+ else:
+ logits_per_image = logit_scale * image_features @ text_features.T
+ logits_per_text = logit_scale * text_features @ image_features.T
+
+ return logits_per_image, logits_per_text
+
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
+ device = image_features.device
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
+
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
+
+ total_loss = (
+ F.cross_entropy(logits_per_image, labels) +
+ F.cross_entropy(logits_per_text, labels)
+ ) / 2
+ return total_loss
+
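+# Minimal single-process sketch of what ClipLoss computes (comment only;
+# assumes the standard no-argument constructor with world_size=1): for N
+# paired samples the targets are the diagonal indices 0..N-1.
+#
+#   image_features = F.normalize(torch.randn(4, 512), dim=-1)
+#   text_features = F.normalize(torch.randn(4, 512), dim=-1)
+#   loss = ClipLoss()(image_features, text_features, logit_scale=torch.tensor(100.0))
+#   # equivalent to averaging F.cross_entropy over logits_per_image and
+#   # logits_per_text with labels = torch.arange(4)
+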
+class PreferenceLoss(nn.Module):
+
+ def forward(self, logits_per_image, num_images, labels):
+
+ paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
+ paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
+
+ ce_loss = F.cross_entropy(paired_logits, labels)
+ return ce_loss
+
+class HPSLoss(nn.Module):
+
+ def forward(self, text_logits, labels):
+
+ device = text_logits.device
+ text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
+ label_0, label_1 = labels.chunk(2, dim=-1)
+
+ index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
+ text_0_logits = text_0_logits[index, index]
+ text_1_logits = text_1_logits[index, index]
+ text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
+ text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
+ text_1_labels = text_0_labels + 1
+
+ text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
+ text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
+
+ text_loss = label_0 * text_0_loss + label_1 * text_1_loss
+
+ # absolute_example_weight = 1 / num_per_prompt
+ # denominator = absolute_example_weight.sum()
+ # weight_per_example = absolute_example_weight / denominator
+ # text_loss *= weight_per_example
+
+ text_loss = text_loss.sum()
+ return text_loss
+
+class RankingLoss(nn.Module):
+
+ def forward(self, logits_per_image, num_images, labels, margin = 1.0):
+ paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
+ label_list = [label for label in labels.split(num_images.tolist())]
+ # ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
+
+ paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
+ padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
+
+ # regulized_logits = torch.log(torch.sigmoid(paired_logits))
+
+ diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
+ # diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
+ # diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
+ diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
+ mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
+
+ loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
+ return loss
+
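+# Note on RankingLoss (descriptive comment based on the code above): for each
+# pair of images under the same prompt, `diff` is the score difference and
+# `diff_label` the (negated) label difference; a selected pair contributes
+# max(0, margin - diff * diff_label), a hinge that vanishes once the
+# higher-ranked image out-scores the lower-ranked one by at least `margin`.
+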
+class CoCaLoss(ClipLoss):
+ def __init__(
+ self,
+ caption_loss_weight,
+ clip_loss_weight,
+ pad_id=0, # pad_token for open_clip custom tokenizer
+ local_loss=False,
+ gather_with_grad=False,
+ cache_labels=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ ):
+ super().__init__(
+ local_loss=local_loss,
+ gather_with_grad=gather_with_grad,
+ cache_labels=cache_labels,
+ rank=rank,
+ world_size=world_size,
+ use_horovod=use_horovod
+ )
+
+ self.clip_loss_weight = clip_loss_weight
+ self.caption_loss_weight = caption_loss_weight
+ self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
+
+ def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
+ clip_loss = super().forward(image_features, text_features, logit_scale)
+ clip_loss = self.clip_loss_weight * clip_loss
+
+ caption_loss = self.caption_loss(
+ logits.permute(0, 2, 1),
+ labels,
+ )
+ caption_loss = caption_loss * self.caption_loss_weight
+
+ if output_dict:
+ return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
+
+ return clip_loss, caption_loss
+
+
+class DistillClipLoss(ClipLoss):
+
+ def dist_loss(self, teacher_logits, student_logits):
+ return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
+
+ def forward(
+ self,
+ image_features,
+ text_features,
+ logit_scale,
+ dist_image_features,
+ dist_text_features,
+ dist_logit_scale,
+ output_dict=False,
+ ):
+ logits_per_image, logits_per_text = \
+ self.get_logits(image_features, text_features, logit_scale)
+
+ dist_logits_per_image, dist_logits_per_text = \
+ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
+
+ labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
+
+ contrastive_loss = (
+ F.cross_entropy(logits_per_image, labels) +
+ F.cross_entropy(logits_per_text, labels)
+ ) / 2
+
+ distill_loss = (
+ self.dist_loss(dist_logits_per_image, logits_per_image) +
+ self.dist_loss(dist_logits_per_text, logits_per_text)
+ ) / 2
+
+ if output_dict:
+ return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
+
+ return contrastive_loss, distill_loss
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/model.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e347c42fc8df6464ca28e59adadba61e53a38add
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/model.py
@@ -0,0 +1,461 @@
+""" CLIP Model
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+from dataclasses import dataclass
+import logging
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+from .hf_model import HFTextEncoder
+from .modified_resnet import ModifiedResNet
+from .timm_model import TimmModel
+from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
+from .utils import to_2tuple
+
+
+@dataclass
+class CLIPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ head_width: int = 64
+ mlp_ratio: float = 4.0
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ ls_init_value: Optional[float] = None # layer scale initial value
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
+ n_queries: int = 256 # n_queries for attentional pooler
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
+ timm_proj_bias: bool = False # enable bias final projection
+ timm_drop: float = 0. # head dropout
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
+ output_tokens: bool = False
+
+
+@dataclass
+class CLIPTextCfg:
+ context_length: int = 77
+ vocab_size: int = 49408
+ width: int = 512
+ heads: int = 8
+ layers: int = 12
+ ls_init_value: Optional[float] = None # layer scale initial value
+ hf_model_name: str = None
+ hf_tokenizer_name: str = None
+ hf_model_pretrained: bool = True
+ proj: str = 'mlp'
+ pooler_type: str = 'mean_pooler'
+ embed_cls: bool = False
+ pad_id: int = 0
+ output_tokens: bool = False
+
+
+def get_cast_dtype(precision: str):
+ cast_dtype = None
+ if precision == 'bf16':
+ cast_dtype = torch.bfloat16
+ elif precision == 'fp16':
+ cast_dtype = torch.float16
+ return cast_dtype
+
+
+def _build_vision_tower(
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None
+):
+ if isinstance(vision_cfg, dict):
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+ # memory efficient in recent PyTorch releases (>= 1.10).
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+
+ if vision_cfg.timm_model_name:
+ visual = TimmModel(
+ vision_cfg.timm_model_name,
+ pretrained=vision_cfg.timm_model_pretrained,
+ pool=vision_cfg.timm_pool,
+ proj=vision_cfg.timm_proj,
+ proj_bias=vision_cfg.timm_proj_bias,
+ drop=vision_cfg.timm_drop,
+ drop_path=vision_cfg.timm_drop_path,
+ embed_dim=embed_dim,
+ image_size=vision_cfg.image_size,
+ )
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
+ elif isinstance(vision_cfg.layers, (tuple, list)):
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
+ visual = ModifiedResNet(
+ layers=vision_cfg.layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ image_size=vision_cfg.image_size,
+ width=vision_cfg.width,
+ )
+ else:
+ vision_heads = vision_cfg.width // vision_cfg.head_width
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+ visual = VisionTransformer(
+ image_size=vision_cfg.image_size,
+ patch_size=vision_cfg.patch_size,
+ width=vision_cfg.width,
+ layers=vision_cfg.layers,
+ heads=vision_heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ ls_init_value=vision_cfg.ls_init_value,
+ patch_dropout=vision_cfg.patch_dropout,
+ input_patchnorm=vision_cfg.input_patchnorm,
+ global_average_pool=vision_cfg.global_average_pool,
+ attentional_pool=vision_cfg.attentional_pool,
+ n_queries=vision_cfg.n_queries,
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
+ output_tokens=vision_cfg.output_tokens,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ return visual
+
+
+def _build_text_tower(
+ embed_dim: int,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+):
+ if isinstance(text_cfg, dict):
+ text_cfg = CLIPTextCfg(**text_cfg)
+
+ if text_cfg.hf_model_name:
+ text = HFTextEncoder(
+ text_cfg.hf_model_name,
+ output_dim=embed_dim,
+ proj=text_cfg.proj,
+ pooler_type=text_cfg.pooler_type,
+ pretrained=text_cfg.hf_model_pretrained,
+ output_tokens=text_cfg.output_tokens,
+ )
+ else:
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+
+ text = TextTransformer(
+ context_length=text_cfg.context_length,
+ vocab_size=text_cfg.vocab_size,
+ width=text_cfg.width,
+ heads=text_cfg.heads,
+ layers=text_cfg.layers,
+ ls_init_value=text_cfg.ls_init_value,
+ output_dim=embed_dim,
+ embed_cls=text_cfg.embed_cls,
+ output_tokens=text_cfg.output_tokens,
+ pad_id=text_cfg.pad_id,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ return text
+
+
+class CLIP(nn.Module):
+ output_dict: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ output_dict: bool = False,
+ ):
+ super().__init__()
+ self.output_dict = output_dict
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.transformer = text.transformer
+ self.vocab_size = text.vocab_size
+ self.token_embedding = text.token_embedding
+ self.positional_embedding = text.positional_embedding
+ self.ln_final = text.ln_final
+ self.text_projection = text.text_projection
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ locked_layers = []
+ locked_layers.append(self.token_embedding)
+ self.positional_embedding.requires_grad = False
+ if unlocked_layers > 0:
+ locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
+ else:
+ locked_layers.append(self.transformer)
+ locked_layers.append(self.ln_final)
+ self.text_projection.requires_grad = False
+
+ # freeze layers
+ for module in locked_layers:
+ for n, p in module.named_parameters():
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.transformer.grad_checkpointing = enable
+
+ def encode_image(self, image, normalize: bool = False):
+ features = self.visual(image)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ cast_dtype = self.transformer.get_cast_dtype()
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+ return F.normalize(x, dim=-1) if normalize else x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ if self.output_dict:
+ return {
+ "image_features": image_features,
+ "text_features": text_features,
+ "logit_scale": self.logit_scale.exp()
+ }
+ return image_features, text_features, self.logit_scale.exp()
+
+
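+# Illustrative sketch (comment only; tensors are assumed to be preprocessed
+# images and tokenized text): turning the forward() outputs into similarity
+# logits, mirroring ClipLoss.get_logits for a single process.
+#
+#   image_features, text_features, logit_scale = clip_model(images, token_ids)
+#   logits_per_image = logit_scale * image_features @ text_features.T
+#   probs = logits_per_image.softmax(dim=-1)
+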
+class CustomTextCLIP(nn.Module):
+ output_dict: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ output_dict: bool = False,
+ ):
+ super().__init__()
+ self.output_dict = output_dict
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ self.text.lock(unlocked_layers, freeze_layer_norm)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.text.set_grad_checkpointing(enable)
+
+ def encode_image(self, image, normalize: bool = False):
+ features = self.visual(image)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ features = self.text(text)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ if self.output_dict:
+ return {
+ "image_features": image_features,
+ "text_features": text_features,
+ "logit_scale": self.logit_scale.exp()
+ }
+ return image_features, text_features, self.logit_scale.exp()
+
+
+def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
+
+ def _convert_weights(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.to(dtype)
+ if l.bias is not None:
+ l.bias.data = l.bias.data.to(dtype)
+
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.to(dtype)
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.to(dtype)
+
+ model.apply(_convert_weights)
+
+
+convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
+
+
+# used to maintain checkpoint compatibility
+def convert_to_custom_text_state_dict(state_dict: dict):
+ if 'text_projection' in state_dict:
+ # old format state_dict, move text tower -> .text
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if any(k.startswith(p) for p in (
+ 'text_projection',
+ 'positional_embedding',
+ 'token_embedding',
+ 'transformer',
+ 'ln_final',
+ )):
+ k = 'text.' + k
+ new_state_dict[k] = v
+ return new_state_dict
+ return state_dict
+
+
+def build_model_from_openai_state_dict(
+ state_dict: dict,
+ quick_gelu=True,
+ cast_dtype=torch.float16,
+):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len(
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_size = vision_patch_size * grid_size
+ else:
+ counts: list = [
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_size = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ vision_cfg = CLIPVisionCfg(
+ layers=vision_layers,
+ width=vision_width,
+ patch_size=vision_patch_size,
+ image_size=image_size,
+ )
+ text_cfg = CLIPTextCfg(
+ context_length=context_length,
+ vocab_size=vocab_size,
+ width=transformer_width,
+ heads=transformer_heads,
+ layers=transformer_layers,
+ )
+ model = CLIP(
+ embed_dim,
+ vision_cfg=vision_cfg,
+ text_cfg=text_cfg,
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
+ cast_dtype=cast_dtype,
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
+ model.load_state_dict(state_dict)
+ return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device('cpu')):
+ model.eval()
+ image_size = model.visual.image_size
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
+ model = torch.jit.trace_module(
+ model,
+ inputs=dict(
+ forward=(example_images, example_text),
+ encode_text=(example_text,),
+ encode_image=(example_images,)
+ ))
+ model.visual.image_size = image_size
+ return model
+
+
+def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
+ # Rescale the grid of position embeddings when loading from state_dict
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
+ return
+ grid_size = to_2tuple(model.visual.grid_size)
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+ if new_seq_len == old_pos_embed.shape[0]:
+ return
+
+ if extra_tokens:
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+ else:
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+ pos_emb_img = F.interpolate(
+ pos_emb_img,
+ size=grid_size,
+ mode=interpolation,
+ antialias=antialias,
+ align_corners=False,
+ )
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+ if pos_emb_tok is not None:
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+ else:
+ new_pos_embed = pos_emb_img
+ state_dict['visual.positional_embedding'] = new_pos_embed
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/model_configs/ViT-H-14.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 32,
+ "width": 1280,
+ "head_width": 80,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 24
+ }
+}
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8d3aeda91ecb394303becbbfccc8acd8cddcd9
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
@@ -0,0 +1,181 @@
+from collections import OrderedDict
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .utils import freeze_batch_norm_2d
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.act1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.act2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.act3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.act1(self.bn1(self.conv1(x)))
+ out = self.act2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.act3(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0.,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.act1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.act2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.act3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ # FIXME support for non-transformer
+ pass
+
+ def stem(self, x):
+ x = self.act1(self.bn1(self.conv1(x)))
+ x = self.act2(self.bn2(self.conv2(x)))
+ x = self.act3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
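+
+
+# Descriptive note (comment only): for a 224x224 input the stem plus the four
+# residual stages reduce the spatial grid by 32x, and AttentionPool2d then
+# collapses the resulting 7x7 grid into a single [batch, output_dim] embedding.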
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/openai.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
@@ -0,0 +1,144 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
+from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+
+def list_openai_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list_pretrained_models_by_tag('openai')
+
+
+def load_openai_model(
+ name: str,
+ precision: Optional[str] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ jit: bool = True,
+ cache_dir: Optional[str] = None,
+):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+ precision: str
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
+ device : Union[str, torch.device]
+ The device to put the loaded model
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+ cache_dir : Optional[str]
+ The directory to cache the downloaded model weights
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if precision is None:
+ precision = 'fp32' if device == 'cpu' else 'fp16'
+
+ if get_pretrained_url(name, 'openai'):
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ # Build a non-jit model from the OpenAI jitted model state dict
+ cast_dtype = get_cast_dtype(precision)
+ try:
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
+ except KeyError:
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
+
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
+ model = model.to(device)
+ if precision.startswith('amp') or precision == 'fp32':
+ model.float()
+ elif precision == 'bf16':
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
+
+ return model
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 (typically for CPU)
+ if precision == 'fp32':
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+ model.float()
+
+ # ensure image_size attr available at consistent location for both jit and non-jit
+ model.visual.image_size = model.input_resolution.item()
+ return model
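+
+
+# Illustrative usage sketch (comment only; 'ViT-L-14' is one of the tags in the
+# openai pretrained registry):
+#
+#   model = load_openai_model('ViT-L-14', device='cuda', jit=False)
+#   with torch.no_grad():
+#       feats = model.encode_image(images)  # images already preprocessed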
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..87e7e527497d643fdf6ac931ac73b6e887a90d0d
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
@@ -0,0 +1,376 @@
+import hashlib
+import os
+import urllib
+import warnings
+from functools import partial
+from typing import Dict, Union
+
+from tqdm import tqdm
+
+from .version import __version__
+
+try:
+ from huggingface_hub import hf_hub_download
+ hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
+ _has_hf_hub = True
+except ImportError:
+ hf_hub_download = None
+ _has_hf_hub = False
+
+
+def _pcfg(url='', hf_hub='', mean=None, std=None):
+ return dict(
+ url=url,
+ hf_hub=hf_hub,
+ mean=mean,
+ std=std,
+ )
+
+
+_RN50 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
+ cc12m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
+)
+
+_RN50_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
+ cc12m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
+)
+
+_RN101 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
+)
+
+_RN101_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
+)
+
+_RN50x4 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
+)
+
+_RN50x16 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
+)
+
+_RN50x64 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
+)
+
+_VITB32 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+ laion2b_e16=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
+)
+
+_VITB32_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+)
+
+_VITB16 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
+ # laion400m_32k=_pcfg(
+ # url="",
+ # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ # laion400m_64k=_pcfg(
+ # url="",
+ # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
+)
+
+_VITB16_PLUS_240 = dict(
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
+)
+
+_VITL14 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
+ laion2b_s32b_b82k=_pcfg(
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+)
+
+_VITL14_336 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
+)
+
+_VITH14 = dict(
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
+)
+
+_VITg14 = dict(
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
+)
+
+_VITbigG14 = dict(
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
+)
+
+_robertaViTB32 = dict(
+ laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
+)
+
+_xlmRobertaBaseViTB32 = dict(
+ laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
+)
+
+_xlmRobertaLargeFrozenViTH14 = dict(
+ frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
+)
+
+_convnext_base = dict(
+ laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
+)
+
+_convnext_base_w = dict(
+ laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
+ laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
+)
+
+_convnext_base_w_320 = dict(
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
+ laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
+)
+
+_convnext_large_d = dict(
+ laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
+)
+
+_convnext_large_d_320 = dict(
+ laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
+ laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
+)
+
+_convnext_xxlarge = dict(
+ laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
+ laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
+ laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
+)
+
+_coca_VITB32 = dict(
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
+)
+
+_coca_VITL14 = dict(
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
+)
+
+
+_PRETRAINED = {
+ "RN50": _RN50,
+ "RN50-quickgelu": _RN50_quickgelu,
+ "RN101": _RN101,
+ "RN101-quickgelu": _RN101_quickgelu,
+ "RN50x4": _RN50x4,
+ "RN50x16": _RN50x16,
+ "RN50x64": _RN50x64,
+ "ViT-B-32": _VITB32,
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
+ "ViT-B-16": _VITB16,
+ "ViT-B-16-plus-240": _VITB16_PLUS_240,
+ "ViT-L-14": _VITL14,
+ "ViT-L-14-336": _VITL14_336,
+ "ViT-H-14": _VITH14,
+ "ViT-g-14": _VITg14,
+ "ViT-bigG-14": _VITbigG14,
+ "roberta-ViT-B-32": _robertaViTB32,
+ "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
+ "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
+ "convnext_base": _convnext_base,
+ "convnext_base_w": _convnext_base_w,
+ "convnext_base_w_320": _convnext_base_w_320,
+ "convnext_large_d": _convnext_large_d,
+ "convnext_large_d_320": _convnext_large_d_320,
+ "convnext_xxlarge": _convnext_xxlarge,
+ "coca_ViT-B-32": _coca_VITB32,
+ "coca_ViT-L-14": _coca_VITL14,
+}
+
+
+def _clean_tag(tag: str):
+ # normalize pretrained tags
+ return tag.lower().replace('-', '_')
+
+
+def list_pretrained(as_str: bool = False):
+ """ returns list of pretrained models
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
+ """
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
+
+
+def list_pretrained_models_by_tag(tag: str):
+ """ return all models having the specified pretrain tag """
+ models = []
+ tag = _clean_tag(tag)
+ for k in _PRETRAINED.keys():
+ if tag in _PRETRAINED[k]:
+ models.append(k)
+ return models
+
+
+def list_pretrained_tags_by_model(model: str):
+ """ return all pretrain tags for the specified model architecture """
+ tags = []
+ if model in _PRETRAINED:
+ tags.extend(_PRETRAINED[model].keys())
+ return tags
+
+
+def is_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return False
+ return _clean_tag(tag) in _PRETRAINED[model]
+
+
+def get_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return {}
+ model_pretrained = _PRETRAINED[model]
+ return model_pretrained.get(_clean_tag(tag), {})
+
+
+def get_pretrained_url(model: str, tag: str):
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
+ return cfg.get('url', '')
+
+
+def download_pretrained_from_url(
+ url: str,
+ cache_dir: Union[str, None] = None,
+):
+ if not cache_dir:
+ cache_dir = os.path.expanduser("~/.cache/clip")
+ os.makedirs(cache_dir, exist_ok=True)
+ filename = os.path.basename(url)
+
+ if 'openaipublic' in url:
+ expected_sha256 = url.split("/")[-2]
+ elif 'mlfoundations' in url:
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
+ else:
+ expected_sha256 = ''
+
+ download_target = os.path.join(cache_dir, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if expected_sha256:
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+ else:
+ return download_target
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def has_hf_hub(necessary=False):
+ if not _has_hf_hub and necessary:
+ # if no HF Hub module installed, and it is necessary to continue, raise error
+ raise RuntimeError(
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+ return _has_hf_hub
+
+
+def download_pretrained_from_hf(
+ model_id: str,
+ filename: str = 'open_clip_pytorch_model.bin',
+ revision=None,
+ cache_dir: Union[str, None] = None,
+):
+ has_hf_hub(True)
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
+ return cached_file
+
+
+def download_pretrained(
+ cfg: Dict,
+ force_hf_hub: bool = False,
+ cache_dir: Union[str, None] = None,
+):
+ target = ''
+ if not cfg:
+ return target
+
+ download_url = cfg.get('url', '')
+ download_hf_hub = cfg.get('hf_hub', '')
+ if download_hf_hub and force_hf_hub:
+ # use HF hub even if url exists
+ download_url = ''
+
+ if download_url:
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
+ elif download_hf_hub:
+ has_hf_hub(True)
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
+ model_id, filename = os.path.split(download_hf_hub)
+ if filename:
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
+ else:
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+
+ return target
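+
+
+# Illustrative sketch (comment only; the tag below exists in the registry
+# above): resolving a pretrained config to a local checkpoint path, via either
+# a direct URL download or a Hugging Face Hub pull.
+#
+#   cfg = get_pretrained_cfg('ViT-H-14', 'laion2b_s32b_b79k')
+#   checkpoint_path = download_pretrained(cfg)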
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c0631c81dcb43829b7374fac09406ecefcb436
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
@@ -0,0 +1,243 @@
+import argparse
+import json
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Optional, Tuple
+
+import torch
+
+try:
+ from huggingface_hub import (
+ create_repo,
+ get_hf_file_metadata,
+ hf_hub_download,
+ hf_hub_url,
+ repo_type_and_id_from_hf_id,
+ upload_folder,
+ )
+ from huggingface_hub.utils import EntryNotFoundError
+ _has_hf_hub = True
+except ImportError:
+ _has_hf_hub = False
+
+from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
+from .tokenizer import HFTokenizer
+
+
+def save_config_for_hf(
+ model,
+ config_path: str,
+ model_config: Optional[dict]
+):
+ preprocess_cfg = {
+ 'mean': model.visual.image_mean,
+ 'std': model.visual.image_std,
+ }
+ hf_config = {
+ 'model_cfg': model_config,
+ 'preprocess_cfg': preprocess_cfg,
+ }
+
+ with config_path.open('w') as f:
+ json.dump(hf_config, f, indent=2)
+
+
+def save_for_hf(
+ model,
+ tokenizer: HFTokenizer,
+ model_config: dict,
+ save_directory: str,
+ weights_filename='open_clip_pytorch_model.bin',
+ config_filename='open_clip_config.json',
+):
+ save_directory = Path(save_directory)
+ save_directory.mkdir(exist_ok=True, parents=True)
+
+ weights_path = save_directory / weights_filename
+ torch.save(model.state_dict(), weights_path)
+
+ tokenizer.save_pretrained(save_directory)
+
+ config_path = save_directory / config_filename
+ save_config_for_hf(model, config_path, model_config=model_config)
+
+
+def push_to_hf_hub(
+ model,
+ tokenizer,
+ model_config: Optional[dict],
+ repo_id: str,
+ commit_message: str = 'Add model',
+ token: Optional[str] = None,
+ revision: Optional[str] = None,
+ private: bool = False,
+ create_pr: bool = False,
+ model_card: Optional[dict] = None,
+):
+ if not isinstance(tokenizer, HFTokenizer):
+ # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
+ tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
+
+ # Create repo if it doesn't exist yet
+ repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
+
+ # Infer complete repo_id from repo_url
+ # Can be different from the input `repo_id` if repo_owner was implicit
+ _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
+ repo_id = f"{repo_owner}/{repo_name}"
+
+ # Check if README file already exist in repo
+ try:
+ get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
+ has_readme = True
+ except EntryNotFoundError:
+ has_readme = False
+
+ # Dump model and push to Hub
+ with TemporaryDirectory() as tmpdir:
+ # Save model weights and config.
+ save_for_hf(
+ model,
+ tokenizer=tokenizer,
+ model_config=model_config,
+ save_directory=tmpdir,
+ )
+
+ # Add readme if it does not exist
+ if not has_readme:
+ model_card = model_card or {}
+ model_name = repo_id.split('/')[-1]
+ readme_path = Path(tmpdir) / "README.md"
+ readme_text = generate_readme(model_card, model_name)
+ readme_path.write_text(readme_text)
+
+ # Upload model and return
+ return upload_folder(
+ repo_id=repo_id,
+ folder_path=tmpdir,
+ revision=revision,
+ create_pr=create_pr,
+ commit_message=commit_message,
+ )
+
+
+def push_pretrained_to_hf_hub(
+ model_name,
+ pretrained: str,
+ repo_id: str,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ commit_message: str = 'Add model',
+ token: Optional[str] = None,
+ revision: Optional[str] = None,
+ private: bool = False,
+ create_pr: bool = False,
+ model_card: Optional[dict] = None,
+):
+ model, preprocess_eval = create_model_from_pretrained(
+ model_name,
+ pretrained=pretrained,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+
+ model_config = get_model_config(model_name)
+ assert model_config
+
+ tokenizer = get_tokenizer(model_name)
+
+ push_to_hf_hub(
+ model=model,
+ tokenizer=tokenizer,
+ model_config=model_config,
+ repo_id=repo_id,
+ commit_message=commit_message,
+ token=token,
+ revision=revision,
+ private=private,
+ create_pr=create_pr,
+ model_card=model_card,
+ )
+
+
+def generate_readme(model_card: dict, model_name: str):
+ readme_text = "---\n"
+ readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
+ readme_text += "library_tag: open_clip\n"
+ readme_text += f"license: {model_card.get('license', 'mit')}\n"
+ if 'details' in model_card and 'Dataset' in model_card['details']:
+ readme_text += 'datasets:\n'
+ readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
+ readme_text += "---\n"
+ readme_text += f"# Model card for {model_name}\n"
+ if 'description' in model_card:
+ readme_text += f"\n{model_card['description']}\n"
+ if 'details' in model_card:
+ readme_text += f"\n## Model Details\n"
+ for k, v in model_card['details'].items():
+ if isinstance(v, (list, tuple)):
+ readme_text += f"- **{k}:**\n"
+ for vi in v:
+ readme_text += f" - {vi}\n"
+ elif isinstance(v, dict):
+ readme_text += f"- **{k}:**\n"
+ for ki, vi in v.items():
+ readme_text += f" - {ki}: {vi}\n"
+ else:
+ readme_text += f"- **{k}:** {v}\n"
+ if 'usage' in model_card:
+ readme_text += f"\n## Model Usage\n"
+ readme_text += model_card['usage']
+ readme_text += '\n'
+
+ if 'comparison' in model_card:
+ readme_text += f"\n## Model Comparison\n"
+ readme_text += model_card['comparison']
+ readme_text += '\n'
+
+ if 'citation' in model_card:
+ readme_text += f"\n## Citation\n"
+ if not isinstance(model_card['citation'], (list, tuple)):
+ citations = [model_card['citation']]
+ else:
+ citations = model_card['citation']
+ for c in citations:
+ readme_text += f"```bibtex\n{c}\n```\n"
+
+ return readme_text
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
+ parser.add_argument(
+ "--model", type=str, help="Name of the model to use.",
+ )
+ parser.add_argument(
+ "--pretrained", type=str,
+ help="Use a pretrained CLIP model weights with the specified tag or file path.",
+ )
+ parser.add_argument(
+ "--repo-id", type=str,
+ help="Destination HF Hub repo-id ie 'organization/model_id'.",
+ )
+ parser.add_argument(
+ '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
+ help='Override default image mean value of dataset')
+ parser.add_argument(
+ '--image-std', type=float, nargs='+', default=None, metavar='STD',
+ help='Override default image std deviation of dataset')
+ args = parser.parse_args()
+
+ print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
+
+ # FIXME add support to pass model_card json / template from file via cmd line
+
+ push_pretrained_to_hf_hub(
+ args.model,
+ args.pretrained,
+ args.repo_id,
+ image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
+ image_std=args.image_std,
+ )
+
+ print(f'{args.model} saved.')
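For reference, a minimal usage sketch of generate_readme from the module above; the card fields and model name are illustrative placeholders, and the import assumes the PusaV1 directory is on sys.path.

# Hypothetical sketch: render the auto-generated README for a model card.
from diffsynth.extensions.ImageQualityMetric.open_clip.push_to_hf_hub import generate_readme

model_card = {
    "license": "mit",                                       # placeholder license
    "description": "An illustrative OpenCLIP checkpoint.",  # placeholder description
    "details": {"Dataset": "LAION-2B", "Params": "151M"},   # placeholder metadata
}
print(generate_readme(model_card, "my-clip-model"))  # YAML front matter + markdown body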
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc71a693f9a42ec01fd88d307661bc382b4d05bc
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
@@ -0,0 +1,127 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+try:
+ import timm
+ from timm.models.layers import Mlp, to_2tuple
+ try:
+ # old timm imports < 0.8.1
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
+ except ImportError:
+ # new timm imports >= 0.8.1
+ from timm.layers import RotAttentionPool2d
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
+except ImportError:
+ timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+ """ timm model adapter
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
+ """
+
+ def __init__(
+ self,
+ model_name,
+ embed_dim,
+ image_size=224,
+ pool='avg',
+ proj='linear',
+ proj_bias=False,
+ drop=0.,
+ drop_path=None,
+ pretrained=False,
+ ):
+ super().__init__()
+ if timm is None:
+ raise RuntimeError("Please `pip install timm` to use timm models.")
+
+ self.image_size = to_2tuple(image_size)
+ timm_kwargs = {}
+ if drop_path is not None:
+ timm_kwargs['drop_path_rate'] = drop_path
+ self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
+ feature_ndim = 1 if not feat_size else 2
+ if pool in ('abs_attn', 'rot_attn'):
+ assert feature_ndim == 2
+ # if attn pooling used, remove both classifier and default pool
+ self.trunk.reset_classifier(0, global_pool='')
+ else:
+ # reset global pool if pool config set, otherwise leave as network default
+ reset_kwargs = dict(global_pool=pool) if pool else {}
+ self.trunk.reset_classifier(0, **reset_kwargs)
+ prev_chs = self.trunk.num_features
+
+ head_layers = OrderedDict()
+ if pool == 'abs_attn':
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
+ prev_chs = embed_dim
+ elif pool == 'rot_attn':
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+ prev_chs = embed_dim
+ else:
+ assert proj, 'projection layer needed if non-attention pooling is used.'
+
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+ if proj == 'linear':
+ head_layers['drop'] = nn.Dropout(drop)
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
+ elif proj == 'mlp':
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
+
+ self.head = nn.Sequential(head_layers)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ """ lock modules
+ Args:
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+ """
+ if not unlocked_groups:
+ # lock full model
+ for param in self.trunk.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self.trunk)
+ else:
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+ try:
+ # FIXME import here until API stable and in an official release
+ from timm.models.helpers import group_parameters, group_modules
+ except ImportError:
+ raise RuntimeError(
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
+ matcher = self.trunk.group_matcher()
+ gparams = group_parameters(self.trunk, matcher)
+ max_layer_id = max(gparams.keys())
+ max_layer_id = max_layer_id - unlocked_groups
+ for group_idx in range(max_layer_id + 1):
+ group = gparams[group_idx]
+ for param in group:
+ self.trunk.get_parameter(param).requires_grad = False
+ if freeze_bn_stats:
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+ freeze_batch_norm_2d(self.trunk, gmodules)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ try:
+ self.trunk.set_grad_checkpointing(enable)
+ except Exception as e:
+ logging.warning(f'grad checkpointing not supported for this timm image tower ({e}), continuing without...')
+
+ def forward(self, x):
+ x = self.trunk(x)
+ x = self.head(x)
+ return x
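As a usage reference for the adapter above, a minimal sketch assuming timm is installed and that 'resnet50' is available in the local timm version; the shapes in the comments hold for that backbone only.

# Hypothetical sketch: TimmModel as a CLIP-style vision tower.
import torch
from diffsynth.extensions.ImageQualityMetric.open_clip.timm_model import TimmModel

tower = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
images = torch.randn(2, 3, 224, 224)   # dummy batch of two RGB images
features = tower(images)               # trunk pools to (2, 2048), head projects to (2, 512)
print(features.shape)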
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..22ec4880b13ec73594d5c19b3d3be83aadb55aba
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
@@ -0,0 +1,211 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+# https://stackoverflow.com/q/62691279
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+@lru_cache()
+def default_bpe():
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
+ quality_metric_path = os.path.join(project_root, 'models', 'QualityMetric')
+ return os.path.join(quality_metric_path, "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ if not special_tokens:
+ special_tokens = ['<start_of_text>', '<end_of_text>']
+ else:
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t:t for t in special_tokens}
+ special = "|".join(special_tokens)
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+'</w>'
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except ValueError:  # `first` not found in the remainder of the word
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
+
+ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self.encoder[""]
+ eot_token = self.encoder[""]
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ tokens[-1] = eot_token
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+
+class HFTokenizer:
+ """HuggingFace tokenizer wrapper"""
+
+ def __init__(self, tokenizer_name: str):
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+ def save_pretrained(self, dest):
+ self.tokenizer.save_pretrained(dest)
+
+ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
+ # same cleaning as for default tokenizer, except lowercasing
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
+ if isinstance(texts, str):
+ texts = [texts]
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
+ input_ids = self.tokenizer(
+ texts,
+ return_tensors='pt',
+ max_length=context_length,
+ padding='max_length',
+ truncation=True,
+ ).input_ids
+ return input_ids
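A minimal sketch of the tokenizer above, assuming the bpe_simple_vocab_16e6.txt.gz file resolved by default_bpe() is present under models/QualityMetric and that the package is importable.

# Hypothetical sketch: tokenize prompts into fixed-length CLIP token tensors.
from diffsynth.extensions.ImageQualityMetric.open_clip.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
tokens = tokenizer(["a photo of a cat", "a photo of a dog"], context_length=77)
print(tokens.shape)   # torch.Size([2, 77]); each row is <start_of_text> ... <end_of_text>, zero-padded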
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/transform.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4e21fa5b515f2412049f9274bd06fbe77fb9b9
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
@@ -0,0 +1,216 @@
+import warnings
+from dataclasses import dataclass, asdict
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+from functools import partial
+from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
+ CenterCrop
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+
+
+@dataclass
+class AugmentationCfg:
+ scale: Tuple[float, float] = (0.9, 1.0)
+ ratio: Optional[Tuple[float, float]] = None
+ color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
+ interpolation: Optional[str] = None
+ re_prob: Optional[float] = None
+ re_count: Optional[int] = None
+ use_timm: bool = False
+
+
+class ResizeMaxSize(nn.Module):
+
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
+ super().__init__()
+ if not isinstance(max_size, int):
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
+ self.max_size = max_size
+ self.interpolation = interpolation
+ self.fn = min if fn == 'min' else max
+ self.fill = fill
+
+ def forward(self, img):
+ if isinstance(img, torch.Tensor):
+ height, width = img.shape[1:]
+ else:
+ width, height = img.size
+ scale = self.max_size / float(max(height, width))
+ if scale != 1.0:
+ new_size = tuple(round(dim * scale) for dim in (height, width))
+ img = F.resize(img, new_size, self.interpolation)
+ pad_h = self.max_size - new_size[0]
+ pad_w = self.max_size - new_size[1]
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
+ return img
+
+
+def _convert_to_rgb_or_rgba(image):
+ if image.mode == 'RGBA':
+ return image
+ else:
+ return image.convert('RGB')
+
+# def transform_and_split(merged, transform_fn, normalize_fn):
+# transformed = transform_fn(merged)
+# crop_img, crop_label = torch.split(transformed, [3,1], dim=0)
+
+# # crop_img = _convert_to_rgb(crop_img)
+# crop_img = normalize_fn(ToTensor()(crop_img))
+# return crop_img, crop_label
+
+class MaskAwareNormalize(nn.Module):
+ def __init__(self, mean, std):
+ super().__init__()
+ self.normalize = Normalize(mean=mean, std=std)
+
+ def forward(self, tensor):
+ if tensor.shape[0] == 4:
+ return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0)
+ else:
+ return self.normalize(tensor)
+
+def image_transform(
+ image_size: int,
+ is_train: bool,
+ mean: Optional[Tuple[float, ...]] = None,
+ std: Optional[Tuple[float, ...]] = None,
+ resize_longest_max: bool = False,
+ fill_color: int = 0,
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+):
+ mean = mean or OPENAI_DATASET_MEAN
+ if not isinstance(mean, (list, tuple)):
+ mean = (mean,) * 3
+
+ std = std or OPENAI_DATASET_STD
+ if not isinstance(std, (list, tuple)):
+ std = (std,) * 3
+
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+ image_size = image_size[0]
+
+ if isinstance(aug_cfg, dict):
+ aug_cfg = AugmentationCfg(**aug_cfg)
+ else:
+ aug_cfg = aug_cfg or AugmentationCfg()
+ normalize = MaskAwareNormalize(mean=mean, std=std)
+ if is_train:
+ aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
+ use_timm = aug_cfg_dict.pop('use_timm', False)
+ if use_timm:
+ assert False, "not tested for augmentation with mask"
+ from timm.data import create_transform # timm can still be optional
+ if isinstance(image_size, (tuple, list)):
+ assert len(image_size) >= 2
+ input_size = (3,) + image_size[-2:]
+ else:
+ input_size = (3, image_size, image_size)
+ # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
+ aug_cfg_dict.setdefault('interpolation', 'random')
+ aug_cfg_dict.setdefault('color_jitter', None) # disable by default
+ train_transform = create_transform(
+ input_size=input_size,
+ is_training=True,
+ hflip=0.,
+ mean=mean,
+ std=std,
+ re_mode='pixel',
+ **aug_cfg_dict,
+ )
+ else:
+ train_transform = Compose([
+ _convert_to_rgb_or_rgba,
+ ToTensor(),
+ RandomResizedCrop(
+ image_size,
+ scale=aug_cfg_dict.pop('scale'),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ normalize,
+ ])
+ if aug_cfg_dict:
+ warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
+ return train_transform
+ else:
+ transforms = [
+ _convert_to_rgb_or_rgba,
+ ToTensor(),
+ ]
+ if resize_longest_max:
+ transforms.extend([
+ ResizeMaxSize(image_size, fill=fill_color)
+ ])
+ else:
+ transforms.extend([
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+ CenterCrop(image_size),
+ ])
+ transforms.extend([
+ normalize,
+ ])
+ return Compose(transforms)
+
+
+# def image_transform_region(
+# image_size: int,
+# is_train: bool,
+# mean: Optional[Tuple[float, ...]] = None,
+# std: Optional[Tuple[float, ...]] = None,
+# resize_longest_max: bool = False,
+# fill_color: int = 0,
+# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+# ):
+# mean = mean or OPENAI_DATASET_MEAN
+# if not isinstance(mean, (list, tuple)):
+# mean = (mean,) * 3
+
+# std = std or OPENAI_DATASET_STD
+# if not isinstance(std, (list, tuple)):
+# std = (std,) * 3
+
+# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+# image_size = image_size[0]
+
+# if isinstance(aug_cfg, dict):
+# aug_cfg = AugmentationCfg(**aug_cfg)
+# else:
+# aug_cfg = aug_cfg or AugmentationCfg()
+# normalize = Normalize(mean=mean, std=std)
+# if is_train:
+# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
+
+# transform = Compose([
+# RandomResizedCrop(
+# image_size,
+# scale=aug_cfg_dict.pop('scale'),
+# interpolation=InterpolationMode.BICUBIC,
+# ),
+# ])
+# train_transform = Compose([
+# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize)
+# ])
+# return train_transform
+# else:
+# if resize_longest_max:
+# transform = [
+# ResizeMaxSize(image_size, fill=fill_color)
+# ]
+# val_transform = Compose([
+# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
+# ])
+# else:
+# transform = [
+# Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+# CenterCrop(image_size),
+# ]
+# val_transform = Compose([
+# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize),
+# ])
+# return val_transform
\ No newline at end of file
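A minimal sketch of the evaluation path of image_transform above; the input image is a synthetic placeholder.

# Hypothetical sketch: build and apply the eval preprocessing pipeline.
from PIL import Image
from diffsynth.extensions.ImageQualityMetric.open_clip.transform import image_transform

preprocess = image_transform(image_size=224, is_train=False)
img = Image.new('RGB', (640, 480))   # placeholder image
tensor = preprocess(img)             # convert -> ToTensor -> Resize/CenterCrop -> normalize
print(tensor.shape)                  # torch.Size([3, 224, 224])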
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7465c1b20bf388a17e0f4f80f7b8eee3b564af92
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
@@ -0,0 +1,727 @@
+from collections import OrderedDict
+import math
+from typing import Callable, Optional, Sequence, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+from .utils import to_2tuple
+
+
+class LayerNormFp32(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchDropout(nn.Module):
+ """
+ https://arxiv.org/abs/2212.00794
+ """
+
+ def __init__(self, prob, exclude_first_token=True):
+ super().__init__()
+ assert 0 <= prob < 1.
+ self.prob = prob
+ self.exclude_first_token = exclude_first_token # exclude CLS token
+
+ def forward(self, x):
+ if not self.training or self.prob == 0.:
+ return x
+
+ if self.exclude_first_token:
+ cls_tokens, x = x[:, :1], x[:, 1:]
+ else:
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+ batch = x.size()[0]
+ num_tokens = x.size()[1]
+
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ keep_prob = 1 - self.prob
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+ rand = torch.randn(batch, num_tokens)
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ if self.exclude_first_token:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ scaled_cosine=False,
+ scale_heads=False,
+ logit_scale_max=math.log(1. / 0.01),
+ attn_drop=0.,
+ proj_drop=0.
+ ):
+ super().__init__()
+ self.scaled_cosine = scaled_cosine
+ self.scale_heads = scale_heads
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+ self.logit_scale_max = logit_scale_max
+
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
+ if qkv_bias:
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
+ else:
+ self.in_proj_bias = None
+
+ if self.scaled_cosine:
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
+ else:
+ self.logit_scale = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ if self.scale_heads:
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
+ else:
+ self.head_scale = None
+ self.out_proj = nn.Linear(dim, dim)
+ self.out_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
+ L, N, C = x.shape
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+
+ if self.logit_scale is not None:
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
+ attn = attn.view(-1, L, L)
+ else:
+ q = q * self.scale
+ attn = torch.bmm(q, k.transpose(-1, -2))
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+ attn_mask = new_attn_mask
+ attn += attn_mask
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = torch.bmm(attn, v)
+ if self.head_scale is not None:
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
+ x = x.view(-1, L, C)
+ x = x.transpose(0, 1).reshape(L, N, C)
+ x = self.out_proj(x)
+ x = self.out_drop(x)
+ return x
+
+
+class AttentionalPooler(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ context_dim: int,
+ n_head: int = 8,
+ n_queries: int = 256,
+ norm_layer: Callable = LayerNorm
+ ):
+ super().__init__()
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
+ self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
+ self.ln_q = norm_layer(d_model)
+ self.ln_k = norm_layer(context_dim)
+
+ def forward(self, x: torch.Tensor):
+ x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
+ N = x.shape[1]
+ q = self.ln_q(self.query)
+ out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
+ return out.permute(1, 0, 2) # LND -> NLD
+
+ def _repeat(self, query, N: int):
+ return query.unsqueeze(1).repeat(1, N, 1)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ is_cross_attention: bool = False,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+ if is_cross_attention:
+ self.ln_1_kv = norm_layer(d_model)
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ def attention(
+ self,
+ q_x: torch.Tensor,
+ k_x: Optional[torch.Tensor] = None,
+ v_x: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ ):
+ k_x = k_x if k_x is not None else q_x
+ v_x = v_x if v_x is not None else q_x
+
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
+ return self.attn(
+ q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
+ )[0]
+
+ def forward(
+ self,
+ q_x: torch.Tensor,
+ k_x: Optional[torch.Tensor] = None,
+ v_x: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ ):
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
+
+ x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class CustomResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ scale_cosine_attn: bool = False,
+ scale_heads: bool = False,
+ scale_attn: bool = False,
+ scale_fc: bool = False,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = Attention(
+ d_model, n_head,
+ scaled_cosine=scale_cosine_attn,
+ scale_heads=scale_heads,
+ )
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.grad_checkpointing = False
+
+ self.resblocks = nn.ModuleList([
+ ResidualAttentionBlock(
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
+ for _ in range(layers)
+ ])
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.resblocks[0].mlp.c_fc.weight.dtype
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ for r in self.resblocks:
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+ x = checkpoint(r, x, None, None, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ output_tokens: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float,
+ ls_init_value: float = None,
+ global_average_pool: bool = False,
+ attentional_pool: bool = False,
+ n_queries: int = 256,
+ attn_pooler_heads: int = 8,
+ output_dim: int = 512,
+ patch_dropout: float = 0.,
+ input_patchnorm: bool = False,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ output_tokens: bool = False
+ ):
+ super().__init__()
+ self.output_tokens = output_tokens
+ image_height, image_width = self.image_size = to_2tuple(image_size)
+ patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
+ self.output_dim = output_dim
+
+ # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
+ self.input_patchnorm = input_patchnorm
+
+ if input_patchnorm:
+ patch_input_dim = patch_height * patch_width * 3
+ self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
+ self.conv1 = nn.Linear(patch_input_dim, width)
+ else:
+ self.patchnorm_pre_ln = nn.Identity()
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ # class embeddings and positional embeddings
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
+
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
+
+ self.ln_pre = norm_layer(width)
+ self.transformer = Transformer(
+ width,
+ layers,
+ heads,
+ mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ self.global_average_pool = global_average_pool
+ if attentional_pool:
+ self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
+ self.ln_post = norm_layer(output_dim)
+ self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
+ else:
+ self.attn_pool = None
+ self.ln_post = norm_layer(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ self.init_parameters()
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ if unlocked_groups != 0:
+ groups = [
+ [
+ self.conv1,
+ self.class_embedding,
+ self.positional_embedding,
+ self.ln_pre,
+ ],
+ *self.transformer.resblocks[:-1],
+ [
+ self.transformer.resblocks[-1],
+ self.ln_post,
+ ],
+ self.proj,
+ ]
+
+ def _unlock(x):
+ if isinstance(x, Sequence):
+ for g in x:
+ _unlock(g)
+ else:
+ if isinstance(x, torch.nn.Parameter):
+ x.requires_grad = True
+ else:
+ for p in x.parameters():
+ p.requires_grad = True
+
+ _unlock(groups[-unlocked_groups:])
+
+ def init_parameters(self):
+ # FIXME OpenAI CLIP did not define an init for the VisualTransformer
+ # TODO experiment if default PyTorch init, below, or alternate init is best.
+
+ # nn.init.normal_(self.class_embedding, std=self.scale)
+ # nn.init.normal_(self.positional_embedding, std=self.scale)
+ #
+ # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ # attn_std = self.transformer.width ** -0.5
+ # fc_std = (2 * self.transformer.width) ** -0.5
+ # for block in self.transformer.resblocks:
+ # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+ #
+ # if self.text_projection is not None:
+ # nn.init.normal_(self.text_projection, std=self.scale)
+ pass
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.global_average_pool:
+ return x.mean(dim=1), x
+ else:
+ return x[:, 0], x[:, 1:]
+
+ def forward(self, x: torch.Tensor, skip_pool: bool = False):
+
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+ if self.input_patchnorm:
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+ x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
+ x = x.permute(0, 2, 4, 1, 3, 5)
+ x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
+ x = self.patchnorm_pre_ln(x)
+ x = self.conv1(x)
+ else:
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ # class embeddings and positional embeddings
+ x = torch.cat(
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.patch_dropout(x)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if skip_pool:
+ return x
+
+ if self.attn_pool is not None:
+ x = self.attn_pool(x)
+ x = self.ln_post(x)
+ pooled, tokens = self._global_pool(x)
+ else:
+ pooled, tokens = self._global_pool(x)
+ pooled = self.ln_post(pooled)
+
+ if self.proj is not None:
+ pooled = pooled @ self.proj
+
+ if self.output_tokens:
+ return pooled, tokens
+
+ return pooled
+
+
+class TextTransformer(nn.Module):
+ output_tokens: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ context_length: int = 77,
+ vocab_size: int = 49408,
+ width: int = 512,
+ heads: int = 8,
+ layers: int = 12,
+ ls_init_value: float = None,
+ output_dim: int = 512,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ embed_cls: bool = False,
+ pad_id: int = 0,
+ output_tokens: bool = False,
+ ):
+ super().__init__()
+ self.output_tokens = output_tokens
+ self.num_pos = self.context_length = context_length
+ self.vocab_size = vocab_size
+ self.width = width
+ self.output_dim = output_dim
+ self.heads = heads
+ self.pad_id = pad_id
+
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+ if embed_cls:
+ self.cls_emb = nn.Parameter(torch.empty(width))
+ self.num_pos += 1
+ else:
+ self.cls_emb = None
+
+ self.token_embedding = nn.Embedding(vocab_size, width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
+ self.transformer = Transformer(
+ width=width,
+ layers=layers,
+ heads=heads,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ self.ln_final = norm_layer(width)
+
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
+
+ self.init_parameters()
+
+ def init_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+ if self.cls_emb is not None:
+ nn.init.normal_(self.cls_emb, std=0.01)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.num_pos, self.num_pos)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def build_cls_mask(self, text, cast_dtype: torch.dtype):
+ cls_mask = (text != self.pad_id).unsqueeze(1)
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
+ additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
+ additive_mask.fill_(0)
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+ return additive_mask
+
+ def _repeat(self, t, N: int):
+ return t.reshape(1, 1, -1).repeat(N, 1, 1)
+
+ def forward(self, text):
+ cast_dtype = self.transformer.get_cast_dtype()
+ seq_len = text.shape[1]
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+ attn_mask = self.attn_mask
+ if self.cls_emb is not None:
+ seq_len += 1
+ x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
+ cls_mask = self.build_cls_mask(text, cast_dtype)
+ attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
+
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ if self.cls_emb is not None:
+ pooled, tokens = x[:, -1], x[:, :-1]
+ pooled = self.ln_final(pooled)
+ else:
+ x = self.ln_final(x)
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
+
+ if self.text_projection is not None:
+ pooled = pooled @ self.text_projection
+
+ if self.output_tokens:
+ return pooled, tokens
+
+ return pooled
+
+
+class MultimodalTransformer(Transformer):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ context_length: int = 77,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ output_dim: int = 512,
+ ):
+
+ super().__init__(
+ width=width,
+ layers=layers,
+ heads=heads,
+ mlp_ratio=mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ self.context_length = context_length
+ self.cross_attn = nn.ModuleList([
+ ResidualAttentionBlock(
+ width,
+ heads,
+ mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ is_cross_attention=True,
+ )
+ for _ in range(layers)
+ ])
+
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
+
+ self.ln_final = norm_layer(width)
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+ def init_parameters(self):
+ proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
+ attn_std = self.width ** -0.5
+ fc_std = (2 * self.width) ** -0.5
+ for block in self.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+ for block in self.cross_attn:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def forward(self, image_embs, text_embs):
+ text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
+ image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
+ seq_len = text_embs.shape[0]
+
+ for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+ text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
+ text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
+ else:
+ text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
+ text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
+
+ x = text_embs.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ if self.text_projection is not None:
+ x = x @ self.text_projection
+
+ return x
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.grad_checkpointing = enable
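A minimal sketch of the text tower above, run on random token ids purely to show tensor shapes; the sizes are the class defaults, not values used elsewhere in this repository.

# Hypothetical sketch: forward pass through TextTransformer.
import torch
from diffsynth.extensions.ImageQualityMetric.open_clip.transformer import TextTransformer

text_model = TextTransformer(context_length=77, vocab_size=49408, width=512, heads=8, layers=12)
dummy_ids = torch.randint(0, 49408, (2, 77))   # two fake token sequences
pooled = text_model(dummy_ids)                 # pooled text embedding, shape (2, 512)
print(pooled.shape)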
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/utils.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
@@ -0,0 +1,60 @@
+from itertools import repeat
+import collections.abc
+
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=''):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
+ name (str): Full module name (prefix)
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ is_match = True
+ if module_match:
+ is_match = name in module_match
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for child_name, child in module.named_children():
+ full_child_name = '.'.join([name, child_name]) if name else child_name
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+ if new_child is not child:
+ res.add_module(child_name, new_child)
+ return res
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = lambda n, x: _ntuple(n)(x)
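A minimal sketch of the two helpers above; the small Sequential network is a placeholder.

# Hypothetical sketch: tuple broadcasting and BatchNorm freezing.
import torch.nn as nn
from diffsynth.extensions.ImageQualityMetric.open_clip.utils import to_2tuple, freeze_batch_norm_2d

print(to_2tuple(224))          # (224, 224)
print(to_2tuple((224, 320)))   # iterables are returned unchanged

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
frozen = freeze_batch_norm_2d(net)   # the BatchNorm2d child is swapped for FrozenBatchNorm2d
print(frozen)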
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/version.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..48aa744fb053599044caf0253b889b5cfe5b78e7
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/open_clip/version.py
@@ -0,0 +1 @@
+__version__ = '2.16.0'
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/pickscore.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/pickscore.py
new file mode 100644
index 0000000000000000000000000000000000000000..7370e099724997d98f1c4ad3fc5f14c861202665
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/pickscore.py
@@ -0,0 +1,112 @@
+import torch
+from PIL import Image
+from transformers import AutoProcessor, AutoModel
+from typing import List, Union
+import os
+from .config import MODEL_PATHS
+
+class PickScore(torch.nn.Module):
+ def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
+ """Initialize the PickScore model with a processor and model.
+
+ Args:
+ device (Union[str, torch.device]): The device to load the model on.
+ path (dict): Mapping of model names to local checkpoint paths.
+ """
+ super().__init__()
+ self.device = device if isinstance(device, torch.device) else torch.device(device)
+ processor_name_or_path = path.get("clip")
+ model_pretrained_name_or_path = path.get("pickscore")
+ self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
+ self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
+
+ def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
+ """Calculate the score for a single image and prompt.
+
+ Args:
+ image (torch.Tensor): The processed image tensor.
+ prompt (str): The prompt text.
+ softmax (bool): Whether to apply softmax to the scores.
+
+ Returns:
+ float: The score for the image.
+ """
+ with torch.no_grad():
+ # Prepare text inputs
+ text_inputs = self.processor(
+ text=prompt,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ # Embed images and text
+ image_embs = self.model.get_image_features(pixel_values=image)
+ image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
+ text_embs = self.model.get_text_features(**text_inputs)
+ text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
+
+ # Compute score
+ score = (text_embs @ image_embs.T)[0]
+ if softmax:
+ # Apply logit scale and softmax
+ score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
+
+ return score.cpu().item()
+
+ @torch.no_grad()
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
+ """Score the images based on the prompt.
+
+ Args:
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
+ prompt (str): The prompt text.
+ softmax (bool): Whether to apply softmax to the scores.
+
+ Returns:
+ List[float]: List of scores for the images.
+ """
+ try:
+ if isinstance(images, (str, Image.Image)):
+ # Single image
+ if isinstance(images, str):
+ pil_image = Image.open(images)
+ else:
+ pil_image = images
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
+ elif isinstance(images, list):
+ # Multiple images
+ scores = []
+ for one_image in images:
+ if isinstance(one_image, str):
+ pil_image = Image.open(one_image)
+ elif isinstance(one_image, Image.Image):
+ pil_image = one_image
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+
+ # Prepare image inputs
+ image_inputs = self.processor(
+ images=pil_image,
+ padding=True,
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ ).to(self.device)
+
+ scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
+ return scores
+ else:
+ raise TypeError("The type of parameter images is illegal.")
+ except Exception as e:
+ raise RuntimeError(f"Error in scoring images: {e}")
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/__init__.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4f59d6c0977e578ab67ec92c916c7e38842715
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
@@ -0,0 +1 @@
+from .models import *
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e2b69d17f6f4603d115e79a6122318f059b385
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py
@@ -0,0 +1,3 @@
+from .base_model import *
+from .clip_model import *
+from .cross_modeling import *
\ No newline at end of file
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f28caf67460a517bd9cb7cbdbd806d7b072541f
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py
@@ -0,0 +1,7 @@
+from dataclasses import dataclass
+
+
+
+@dataclass
+class BaseModelConfig:
+ pass
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a1b37095e6b70e4722856a65fbfb30277eab03a
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py
@@ -0,0 +1,146 @@
+from dataclasses import dataclass
+from transformers import CLIPModel as HFCLIPModel
+from transformers import AutoTokenizer
+
+from torch import nn, einsum
+
+from .base_model import BaseModelConfig
+
+from transformers import CLIPConfig
+from typing import Any, Optional, Tuple, Union
+import torch
+
+from .cross_modeling import Cross_model
+
+import json, os
+
+class XCLIPModel(HFCLIPModel):
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ # pooled_output = text_outputs[1]
+ # text_features = self.text_projection(pooled_output)
+ last_hidden_state = text_outputs[0]
+ text_features = self.text_projection(last_hidden_state)
+
+ pooled_output = text_outputs[1]
+ text_features_EOS = self.text_projection(pooled_output)
+
+
+ # del last_hidden_state, text_outputs
+ # gc.collect()
+
+ return text_features, text_features_EOS
+
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ # pooled_output = vision_outputs[1] # pooled_output
+ # image_features = self.visual_projection(pooled_output)
+ last_hidden_state = vision_outputs[0]
+ image_features = self.visual_projection(last_hidden_state)
+
+ return image_features
+
+
+
+@dataclass
+class ClipModelConfig(BaseModelConfig):
+ _target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
+ pretrained_model_name_or_path: str ="checkpoints/clip-vit-base-patch32"
+
+
+class CLIPModel(nn.Module):
+ def __init__(self, ckpt, config_file=False):
+ super().__init__()
+ if config_file is None:
+ self.model = XCLIPModel.from_pretrained(ckpt)
+ else:
+ with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
+ config = json.load(f)
+ config = CLIPConfig(**config)
+ self.model = XCLIPModel._from_config(config)
+ self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
+
+ def get_text_features(self, *args, **kwargs):
+ return self.model.get_text_features(*args, **kwargs)
+
+ def get_image_features(self, *args, **kwargs):
+ return self.model.get_image_features(*args, **kwargs)
+
+ def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
+ outputs = ()
+
+ text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024
+ outputs += text_EOS,
+
+ image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024
+ condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024
+
+ sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
+ sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
+ sim_text_condition = sim_text_condition / sim_text_condition.max()
+ mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77
+
+ mask = mask.repeat(1, image_f.shape[1], 1) # B*257*77
+ bc = image_f.shape[0] // 2
+
+ sim0 = self.cross_model(image_f[:bc, :, :], text_f, mask.half())
+ sim1 = self.cross_model(image_f[bc:, :, :], text_f, mask.half())
+ outputs += sim0[:,0,:],
+ outputs += sim1[:,0,:],
+
+ return outputs
+
+ @property
+ def logit_scale(self):
+ return self.model.logit_scale
+
+ def save(self, path):
+ self.model.save_pretrained(path)
+
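For orientation, a minimal usage sketch of the CLIPModel wrapper added above. Everything below is an assumption for illustration: a hypothetical local checkpoint whose CLIP projection width is 1024 (e.g. a ViT-H/14 variant) so it matches Cross_model(dim=1024), fp16 weights on a CUDA device (forward() casts image and mask inputs to half), and the PusaV1 directory on PYTHONPATH so the diffsynth package in this diff is importable.

import torch
from PIL import Image
from transformers import CLIPProcessor

from diffsynth.extensions.ImageQualityMetric.trainer.models.clip_model import CLIPModel

ckpt = "checkpoints/clip-vit-h-14"  # hypothetical path; projection dim must be 1024
processor = CLIPProcessor.from_pretrained(ckpt)
model = CLIPModel(ckpt, config_file=None).half().cuda().eval()

prompt = ["a photo of a cat"]
condition = ["quality"]                                   # short condition phrase (a few tokens)
images = [Image.new("RGB", (224, 224)), Image.new("RGB", (224, 224))]  # two candidates per prompt

text_ids = processor(text=prompt, padding="max_length", max_length=77,
                     truncation=True, return_tensors="pt").input_ids.cuda()
cond_ids = processor(text=condition, padding="max_length", max_length=5,
                     truncation=True, return_tensors="pt").input_ids.cuda()
pixels = processor(images=images, return_tensors="pt").pixel_values.cuda()

with torch.no_grad():
    text_eos, sim0, sim1 = model(text_inputs=text_ids, image_inputs=pixels,
                                 condition_inputs=cond_ids)
# text_eos: projected EOS text embedding; sim0 / sim1: the cross-attended
# [CLS]-position features for the first and second image of each pair.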
diff --git a/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..938f1b706e16aa0666210e91fb215304653df4cb
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py
@@ -0,0 +1,292 @@
+import torch
+from torch import einsum, nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+# normalization
+# LayerNorm with a learnable weight but a fixed all-zero bias (registered as a buffer)
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(dim))
+ self.register_buffer("bias", torch.zeros(dim))
+
+ def forward(self, x):
+ return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)
+
+# residual
+
+
+class Residual(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, *args, **kwargs):
+ return self.fn(x, *args, **kwargs) + x
+
+
+# rotary positional embedding
+# https://arxiv.org/abs/2104.09864
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, max_seq_len, *, device):
+ seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
+ freqs = einsum("i , j -> i j", seq, self.inv_freq)
+ return torch.cat((freqs, freqs), dim=-1)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(pos, t):
+ return (t * pos.cos()) + (rotate_half(t) * pos.sin())
+
+
+# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
+# https://arxiv.org/abs/2002.05202
+
+
+class SwiGLU(nn.Module):
+ def forward(self, x):
+ x, gate = x.chunk(2, dim=-1)
+ return F.silu(gate) * x
+
+
+# parallel attention and feedforward with residual
+# discovered by Wang et al + EleutherAI from GPT-J fame
+
+class ParallelTransformerBlock(nn.Module):
+ def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
+ super().__init__()
+ self.norm = LayerNorm(dim)
+
+ attn_inner_dim = dim_head * heads
+ ff_inner_dim = dim * ff_mult
+ self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ self.rotary_emb = RotaryEmbedding(dim_head)
+
+ self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
+ self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
+
+ self.ff_out = nn.Sequential(
+ SwiGLU(),
+ nn.Linear(ff_inner_dim, dim, bias=False)
+ )
+
+ self.register_buffer("pos_emb", None, persistent=False)
+
+
+ def get_rotary_embedding(self, n, device):
+ if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
+ return self.pos_emb[:n]
+
+ pos_emb = self.rotary_emb(n, device=device)
+ self.register_buffer("pos_emb", pos_emb, persistent=False)
+ return pos_emb
+
+ def forward(self, x, attn_mask=None):
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+ """
+
+ n, device, h = x.shape[1], x.device, self.heads
+
+ # pre layernorm
+
+ x = self.norm(x)
+
+ # attention queries, keys, values, and feedforward inner
+
+ q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
+
+ # split heads
+ # they use multi-query single-key-value attention, yet another Noam Shazeer paper
+ # they found no performance loss past a certain scale, and more efficient decoding obviously
+ # https://arxiv.org/abs/1911.02150
+
+ q = rearrange(q, "b n (h d) -> b h n d", h=h)
+
+ # rotary embeddings
+
+ positions = self.get_rotary_embedding(n, device)
+ q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
+
+ # scale
+
+ q = q * self.scale
+
+ # similarity
+
+ sim = einsum("b h i d, b j d -> b h i j", q, k)
+
+
+ # extra attention mask - for masking out attention from text CLS token to padding
+
+ if exists(attn_mask):
+ attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
+ sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
+
+ # attention
+
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+ attn = sim.softmax(dim=-1)
+
+ # aggregate values
+
+ out = einsum("b h i j, b j d -> b h i d", attn, v)
+
+ # merge heads
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.attn_out(out) + self.ff_out(ff)
+
+# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ context_dim=None,
+ dim_head=64,
+ heads=12,
+ parallel_ff=False,
+ ff_mult=4,
+ norm_context=False
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ inner_dim = heads * dim_head
+ context_dim = default(context_dim, dim)
+
+ self.norm = LayerNorm(dim)
+ self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ # whether to have parallel feedforward
+
+ ff_inner_dim = ff_mult * dim
+
+ self.ff = nn.Sequential(
+ nn.Linear(dim, ff_inner_dim * 2, bias=False),
+ SwiGLU(),
+ nn.Linear(ff_inner_dim, dim, bias=False)
+ ) if parallel_ff else None
+
+ def forward(self, x, context, mask):
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+ """
+
+ # pre-layernorm, for queries and context
+
+ x = self.norm(x)
+ context = self.context_norm(context)
+
+ # get queries
+
+ q = self.to_q(x)
+ q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
+
+ # scale
+
+ q = q * self.scale
+
+ # get key / values
+
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+
+ # query / key similarity
+
+ sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+ # attention
+ mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)
+ sim = sim + mask # context mask
+ sim = sim - sim.amax(dim=-1, keepdim=True)
+ attn = sim.softmax(dim=-1)
+
+ # aggregate
+
+ out = einsum('b h i j, b j d -> b h i d', attn, v)
+
+ # merge and combine heads
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ out = self.to_out(out)
+
+ # add parallel feedforward (for multimodal layers)
+
+ if exists(self.ff):
+ out = out + self.ff(x)
+
+ return out
+
+
+class Cross_model(nn.Module):
+ def __init__(
+ self,
+ dim=512,
+ layer_num=4,
+ dim_head=64,
+ heads=8,
+ ff_mult=4
+ ):
+ super().__init__()
+
+ self.layers = nn.ModuleList([])
+
+
+ for ind in range(layer_num):
+ self.layers.append(nn.ModuleList([
+ Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
+ Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))
+ ]))
+
+ def forward(
+ self,
+ query_tokens,
+ context_tokens,
+ mask
+ ):
+
+ for cross_attn, self_attn_ff in self.layers:
+ query_tokens = cross_attn(query_tokens, context_tokens, mask)
+ query_tokens = self_attn_ff(query_tokens)
+
+ return query_tokens
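As a quick sanity check of Cross_model in isolation (dummy tensors only; the 257/77 lengths and dim=1024 simply mirror how clip_model.py calls it, and the mask is additive: 0 keeps a context position, -inf drops it):

import torch
from diffsynth.extensions.ImageQualityMetric.trainer.models.cross_modeling import Cross_model

b, n_img, n_txt, dim = 2, 257, 77, 1024
image_tokens = torch.randn(b, n_img, dim)   # queries (vision tokens)
text_tokens = torch.randn(b, n_txt, dim)    # context (text tokens)

mask = torch.zeros(b, n_img, n_txt)         # additive mask: 0 = keep
mask[:, :, 40:] = float("-inf")             # e.g. drop text positions >= 40

model = Cross_model(dim=dim, layer_num=4, heads=16)
out = model(image_tokens, text_tokens, mask)
print(out.shape)                            # torch.Size([2, 257, 1024])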
diff --git a/PusaV1/diffsynth/extensions/RIFE/__init__.py b/PusaV1/diffsynth/extensions/RIFE/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e76c391f0b085b3628592990a868ac09f37cced7
--- /dev/null
+++ b/PusaV1/diffsynth/extensions/RIFE/__init__.py
@@ -0,0 +1,242 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from PIL import Image
+
+
+# cache of base sampling grids for warping, keyed by flow device and size
+backwarp_tenGrid = {}
+
+
+def warp(tenInput, tenFlow, device):
+ k = (str(tenFlow.device), str(tenFlow.size()))
+ if k not in backwarp_tenGrid:
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+ backwarp_tenGrid[k] = torch.cat(
+ [tenHorizontal, tenVertical], 1).to(device)
+
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, bias=True),
+ nn.PReLU(out_planes)
+ )
+
+
+class IFBlock(nn.Module):
+ def __init__(self, in_planes, c=64):
+ super(IFBlock, self).__init__()
+ self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
+ self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
+ self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
+ self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
+ self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
+ self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
+ self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
+
+ def forward(self, x, flow, scale=1):
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
+ feat = self.conv0(torch.cat((x, flow), 1))
+ feat = self.convblock0(feat) + feat
+ feat = self.convblock1(feat) + feat
+ feat = self.convblock2(feat) + feat
+ feat = self.convblock3(feat) + feat
+ flow = self.conv1(feat)
+ mask = self.conv2(feat)
+ flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
+ mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+ return flow, mask
+
+
+class IFNet(nn.Module):
+ def __init__(self, **kwargs):
+ super(IFNet, self).__init__()
+ self.block0 = IFBlock(7+4, c=90)
+ self.block1 = IFBlock(7+4, c=90)
+ self.block2 = IFBlock(7+4, c=90)
+ self.block_tea = IFBlock(10+4, c=90)
+
+ def forward(self, x, scale_list=(4, 2, 1), training=False):
+ if not training:
+ channel = x.shape[1] // 2
+ img0 = x[:, :channel]
+ img1 = x[:, channel:]
+ flow_list = []
+ merged = []
+ mask_list = []
+ warped_img0 = img0
+ warped_img1 = img1
+ flow = (x[:, :4]).detach() * 0
+ mask = (x[:, :1]).detach() * 0
+ block = [self.block0, self.block1, self.block2]
+ for i in range(3):
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
+ mask = mask + (m0 + (-m1)) / 2
+ mask_list.append(mask)
+ flow_list.append(flow)
+ warped_img0 = warp(img0, flow[:, :2], device=x.device)
+ warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
+ merged.append((warped_img0, warped_img1))
+ '''
+ c0 = self.contextnet(img0, flow[:, :2])
+ c1 = self.contextnet(img1, flow[:, 2:4])
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+ res = tmp[:, 1:4] * 2 - 1
+ '''
+ for i in range(3):
+ mask_list[i] = torch.sigmoid(mask_list[i])
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
+ return flow_list, mask_list[2], merged
+
+ @staticmethod
+ def state_dict_converter():
+ return IFNetStateDictConverter()
+
+
+class IFNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict), {"upcast_to_float32": True}
+
+
+class RIFEInterpolater:
+ def __init__(self, model, device="cuda"):
+ self.model = model
+ self.device = device
+ # IFNet does not support float16, so keep float32
+ self.torch_dtype = torch.float32
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
+
+ def process_image(self, image):
+ width, height = image.size
+ if width % 32 != 0 or height % 32 != 0:
+ width = (width + 31) // 32 * 32
+ height = (height + 31) // 32 * 32
+ image = image.resize((width, height))
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+ return image
+
+ def process_images(self, images):
+ images = [self.process_image(image) for image in images]
+ images = torch.stack(images)
+ return images
+
+ def decode_images(self, images):
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+ images = [Image.fromarray(image) for image in images]
+ return images
+
+ def add_interpolated_images(self, images, interpolated_images):
+ output_images = []
+ for image, interpolated_image in zip(images, interpolated_images):
+ output_images.append(image)
+ output_images.append(interpolated_image)
+ output_images.append(images[-1])
+ return output_images
+
+
+ @torch.no_grad()
+ def interpolate_(self, images, scale=1.0):
+ input_tensor = self.process_images(images)
+ input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
+ input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
+ output_images = self.decode_images(merged[2].cpu())
+ if output_images[0].size != images[0].size:
+ output_images = [image.resize(images[0].size) for image in output_images]
+ return output_images
+
+
+ @torch.no_grad()
+ def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
+ # Preprocess
+ processed_images = self.process_images(images)
+
+ for iter in range(num_iter):
+ # Input
+ input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
+
+ # Interpolate
+ output_tensor = []
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+ output_tensor.append(merged[2].cpu())
+
+ # Output
+ output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
+ processed_images = self.add_interpolated_images(processed_images, output_tensor)
+ processed_images = torch.stack(processed_images)
+
+ # To images
+ output_images = self.decode_images(processed_images)
+ if output_images[0].size != images[0].size:
+ output_images = [image.resize(images[0].size) for image in output_images]
+ return output_images
+
+
+class RIFESmoother(RIFEInterpolater):
+ def __init__(self, model, device="cuda"):
+ super(RIFESmoother, self).__init__(model, device=device)
+
+ @staticmethod
+ def from_model_manager(model_manager):
+ return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
+
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+ output_tensor = []
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+ output_tensor.append(merged[2].cpu())
+ output_tensor = torch.concat(output_tensor, dim=0)
+ return output_tensor
+
+ @torch.no_grad()
+ def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
+ # Preprocess
+ processed_images = self.process_images(rendered_frames)
+
+ for iter in range(num_iter):
+ # Input
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+ # Interpolate
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
+
+ # Blend
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
+
+ # Add to frames
+ processed_images[1:-1] = output_tensor
+
+ # To images
+ output_images = self.decode_images(processed_images)
+ if output_images[0].size != rendered_frames[0].size:
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+ return output_images
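A minimal smoke test of the interpolation path above, using an untrained IFNet with random weights (so the in-between frames are meaningless); it only illustrates the expected call pattern and the frame-doubling behaviour of interpolate(). In practice the weights would come through the model manager via RIFEInterpolater.from_model_manager.

import torch
from PIL import Image
from diffsynth.extensions.RIFE import IFNet, RIFEInterpolater

model = IFNet().eval()                         # random weights, illustration only
interpolater = RIFEInterpolater(model, device="cpu")

# three key frames; sides must be (or will be rounded up to) multiples of 32
frames = [Image.new("RGB", (256, 256), color=c) for c in ("red", "green", "blue")]

out = interpolater.interpolate(frames, scale=1.0, batch_size=4, num_iter=1)
print(len(frames), "->", len(out))             # 3 -> 5: one new frame per adjacent pair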
diff --git a/PusaV1/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6266cd9c47bd4c6e38730379ac9470b63826bf34
Binary files /dev/null and b/PusaV1/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6d2a964ad0523d9376acd5aecc913b51d020e7c
Binary files /dev/null and b/PusaV1/diffsynth/extensions/RIFE/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/extensions/__init__.py b/PusaV1/diffsynth/extensions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PusaV1/diffsynth/extensions/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/extensions/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65aeef098e6eaa8bdd61145c89e431959b7ffe85
Binary files /dev/null and b/PusaV1/diffsynth/extensions/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/extensions/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/extensions/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a09fbba812eb8c953f86e276247e4c0d0f327952
Binary files /dev/null and b/PusaV1/diffsynth/extensions/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__init__.py b/PusaV1/diffsynth/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac121281da6236373bfa40a1ec8d8511d8aee6b
--- /dev/null
+++ b/PusaV1/diffsynth/models/__init__.py
@@ -0,0 +1,2 @@
+from .model_manager import *
+
diff --git a/PusaV1/diffsynth/models/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8858b7409f2a2195b829fce2b6b90c09b94fbdcf
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13f2e1205110bbbde211c5f24335dfce1c885518
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/attention.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f98cf212eff3b2a1f3a54daca32b7b94d75eef89
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/attention.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/attention.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/attention.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61403fb5e94b3bd71a45773bfbc2b1d1dfaa17fe
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/attention.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/cog_dit.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/cog_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10f0f36cd0449b4c9eb7374a3e8b966fb4406da0
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/cog_dit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/cog_dit.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/cog_dit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..714669a50fefdac6c63a60c259666f78bb789640
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/cog_dit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/cog_vae.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/cog_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..896d17caeb69625e3b4919c8c4d04225452f8ad8
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/cog_vae.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/cog_vae.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/cog_vae.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..877be5a0081e3a382cf261ede940d78cfb5f2393
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/cog_vae.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/downloader.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/downloader.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a30588691862757ded855f47395847d3a0d459d2
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/downloader.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/downloader.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/downloader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4883991733118db8a4f5d3c9ca3b743da86ce4d7
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/downloader.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_controlnet.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/flux_controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84fbd104dd9a8307a968f9168adbb19dda446b11
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_controlnet.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_controlnet.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/flux_controlnet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea48b946754c49d939c85905ded881008f7eab46
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_controlnet.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_dit.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/flux_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f94832c0a3c85e5ef135c3da4bdba9ba4b880d0
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_dit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_dit.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/flux_dit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7dc5e9e6bf0fd64326bc3d1774a88c4d0a49dd2
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_dit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_infiniteyou.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/flux_infiniteyou.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..190f353e820cb805778966618fdae86f009e791d
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_infiniteyou.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_infiniteyou.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/flux_infiniteyou.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83b945bb1a9c306f6833fe92abbd64c00654428e
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_infiniteyou.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_ipadapter.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/flux_ipadapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0904fdc2d0e9ee1e5d101bf0009d399c9e718388
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_ipadapter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_ipadapter.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/flux_ipadapter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55ae9e9078fa8df9a44670a41f10921b9b0b7018
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_ipadapter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/flux_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32203c0557e1ccec6ff621f99696b38b01b52a67
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/flux_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3dacdfb0058fc6a44f948618a9b42ba74e08897
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_vae.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/flux_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c0f839f0761691753db12f18c971eaa04efc305
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_vae.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/flux_vae.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/flux_vae.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5b56f9797377de53e408b4be9b8756c2271dc8f
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/flux_vae.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_dit.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51a6d8e3b25eb113846fce91b6de56272201184b
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_dit.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a9c55580e2c57e58fef7920ac3ed2757c5dc927
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b4684def92b6dde071bf75c45a1edcde266184e
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e229df5cd39239f91afa27ef786961b5b2007c17
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_dit_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb104a2ae8260991a3a5a72bb62dc582159c9a7a
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6a347bacde0f5c73bbf833a667c713651f71646
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_dit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b89c19d0ef4a4da6bdc110dd2169925ed2b4a23f
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f17d48255bdfabe888d5d7dc757110e0911137ee
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c48b70d20e821ee84934ea9035d04ff16990584f
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f0edb5ef701c7457bde72d90698369a7b7f689b
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_decoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b137dacd400c180f433736b2f6e3caf5ab0383b8
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34c3cfa0eed78d759092a5003c69053784d2dc3a
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/hunyuan_video_vae_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/kolors_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/kolors_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7a7c6f8284c6137d6990ed9f9b875a9b6f2dda8
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/kolors_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/kolors_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/kolors_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdade0149f6693f4205d24a143ad7d362217ce0d
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/kolors_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/lora.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/lora.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b024dc389d604e55b2dbc723dcb5bf9064ac8ef
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/lora.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/lora.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/lora.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76d786f7809cd27550a22a8de8af3361a630a449
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/lora.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/model_manager.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/model_manager.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6916ca13c03d27bc7ffaacfbd2a74748fadb570
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/model_manager.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/model_manager.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/model_manager.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..971201625e602fe1fdde7b47271b3522613685d9
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/model_manager.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/model_manager_pusa.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/model_manager_pusa.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edd33b00f6a392cbfab1442524d8cd555be3523c
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/model_manager_pusa.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/omnigen.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/omnigen.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..440526a0aaff5e96b4459df080833dc4fa9de165
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/omnigen.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/omnigen.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/omnigen.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33108a325ab5ceb0c315bf2fb84be6ec21c0609f
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/omnigen.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_dit.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48c49cf0c8396d1ccdd5243594facea6446fac2b
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd3_dit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_dit.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_dit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b254e8f876b241504ff0ebb6d1e1b8afc1982bc
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd3_dit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..420926582fc4ce4620ab86d9ba4bee26315253c9
--- /dev/null
+++ b/PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39ab75e482e3da05d306cf13f5e1d86dd8c45831bc90228bd5293a379b41e1af
+size 104184
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d78b9f82bea0c300e7d06084bcb82a56605dcf0
--- /dev/null
+++ b/PusaV1/diffsynth/models/__pycache__/sd3_text_encoder.cpython-312.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40550740e7ee3d35cdc63bd1043ebbe3c63c2dda1e0d07ad9bcecf694290ef84
+size 113549
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9669dba1c29788102ff0f3636fdfebf43c23e66
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2be4fc707d41e6b4e86ac6be99f8a935b725492c
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd3_vae_decoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e9655fcc0c2f736d85ce3f3298b554da363beb9
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e1fa91d3b3155b70dcaa6f85911bd7048388222
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd3_vae_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_controlnet.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd_controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..893271ac57cb3aa3422b8714023c8e06edb2340c
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_controlnet.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_controlnet.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd_controlnet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4548e6dfa6b17c2a0b64d4c64ad133760aa89d3
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_controlnet.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_ipadapter.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd_ipadapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..22ff62d235a73f164d605c69bc4347f4319c5931
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_ipadapter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_ipadapter.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd_ipadapter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8820fb26ab8c700c0be9160519153a2b592910a
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_ipadapter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_motion.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd_motion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71453b5e80c5de2720d5f3948109aef06268f8a1
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_motion.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_motion.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd_motion.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..738e5e4e1eedee045ba18ab554e847f42700292c
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_motion.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6270f321289f48d4a74991be00695e1928e7e67
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dae50c07b8fa9955ca31ccafe2baad80be88ff6
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_unet.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd_unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1790739c1379d23fd8ffadc916492125fa90b36
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_unet.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_unet.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd_unet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10ed19bf38610123a9cdf3576836611c0ff43dbc
--- /dev/null
+++ b/PusaV1/diffsynth/models/__pycache__/sd_unet.cpython-312.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61b7ffb444331874aa8e68fbc4fc2359bb6b5be500a646dab2e81cb7de132e28
+size 112672
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_vae_decoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd_vae_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ab766bb57af5d4132ff137831a574b64e343f96
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_vae_decoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_vae_decoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd_vae_decoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f74f6d878860fb22d2f6c2063e16ecf50b8aed8
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_vae_decoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_vae_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sd_vae_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34bb4001069bccacf6c152650251faaaf1e2d2fa
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_vae_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sd_vae_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sd_vae_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3854470cc1e85a14d6965452d4e261ab660c4f6c
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sd_vae_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_controlnet.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1bd5982a09865cd43e4cc0edc0299dcd49fb56b
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_controlnet.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_controlnet.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_controlnet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84007bc92f2b868b58c271f9a73d42fa1ff96a07
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_controlnet.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a89c57ebce1f2aa4d8ac6ac1d4894afdc63c66ef
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb016b8e2f77cf9c021b34f10f2db9ba1adbfadc
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_ipadapter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_motion.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_motion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bbca706c3cab95a047fbfbd9f2a7f4a4df3105e
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_motion.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_motion.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_motion.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba62bb35b117a8a33ded99ce594c2d14bc08cf54
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_motion.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c3fce6fbd5a9891b132031eacd2673b7ade490e
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8272f1d1024134ff1fa61b5a2f0ac0e509e54cfd
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7af7946f67e7b8fff4c47ce918f591ff56272c00
--- /dev/null
+++ b/PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0369fb6fa563ce7937f6d426e836531a8011d0dc8aa3a267df6de9fe45e65ff
+size 268335
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0abeb2c7c41b18e388062dddcc9ec69a0e4bca9
--- /dev/null
+++ b/PusaV1/diffsynth/models/__pycache__/sdxl_unet.cpython-312.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66459950a997bc47ae45bd24bd960659b2d4c1791c80b796096bd9888659bd62
+size 259748
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4756c2254157754394bc4caa035acf0d409401a
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc840c5885b29a5879614340bb69bc482030f765
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10a25ea0525951442ce834f00c9a419784a4d9d0
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c222ba6cba4c410c63d99b09de0ce056a945fd2c
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/stepvideo_dit.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/stepvideo_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de5f66f568846e42ba829bf3d75609d894ebdc53
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/stepvideo_dit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/stepvideo_dit.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/stepvideo_dit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d5c5add288a4a4174242baec0690cf0c9efb1f0
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/stepvideo_dit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..578905a0ad24205da5a7363d0a2695e0ac768028
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd559a5ceaaeafd94061f71a496583de7471391d
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/stepvideo_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/stepvideo_vae.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/stepvideo_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62367b4967a49f5b221b5f588c3242c673794ac0
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/stepvideo_vae.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/stepvideo_vae.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/stepvideo_vae.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b32065c3c518190e2e5fadb91367ee6fc8f68464
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/stepvideo_vae.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_image_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/svd_image_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a30a96d31622d128ff632323f749f399e681ca0
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/svd_image_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_image_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/svd_image_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3547fa7d4ddaba6c75c22068249bd36e665342a4
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/svd_image_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df479c0109302b9a0c871665873371e9ae383d83
--- /dev/null
+++ b/PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f776d68576d6c28e69252b7adea681a0ec7cbf7772d57ce78b99ca6e31b222ee
+size 216144
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d8bd03daa54936b0870a192146cf5012ca52625
--- /dev/null
+++ b/PusaV1/diffsynth/models/__pycache__/svd_unet.cpython-312.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8cfce243c941d3757b77facc8d0561e954d56590d0f13fbdaa305fb4d14f8243
+size 224997
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_vae_decoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/svd_vae_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0ef497da65302986e7394521ca9a46f1217d6d0
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/svd_vae_decoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_vae_decoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/svd_vae_decoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a228ad599fa3af8fd2fa398dd56d05b8fd832d4b
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/svd_vae_decoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_vae_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/svd_vae_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93aae7f09a1019a8bb7bc3a8f10afdc38bfdab2b
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/svd_vae_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/svd_vae_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/svd_vae_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba53092d0409d1c55f824de831cd75e4a5d2a280
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/svd_vae_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/tiler.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/tiler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef7757cd46062b8dac46ab024b614e8e958214bb
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/tiler.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/tiler.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/tiler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..219d4daedd99d705d281e05937d5d4e7d3f59581
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/tiler.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/utils.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5fdb843a5db8c38119154fb105b196fb0532275
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/utils.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/utils.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aadbb453b9f2800904d5ba2b314f964d6f3d4892
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/utils.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_dit.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_dit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aceee66ee3538a23dc0fc393c255ee6a712d8043
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_dit.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_dit.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_dit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9bfa19f81496186d14e74ead5cccef664402385e
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_dit.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7af71cc837fca71953036d076ed576cb5ea5c5de
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69ef188d63348a188c3bd04d5b61af9a8cacdbb3
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_image_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e30501169a458789ae33a3a7c8aa9565040d999
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db836849985aec6c14ca39b9e01a665ca61e60a9
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_motion_controller.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_pusa.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_pusa.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c4ff76aaac1ad7b889044efcab884c0ec2f497b
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_pusa.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_pusa.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_pusa.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3356e6cb688afb6af8130c52a27d4ffd90cda2a1
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_pusa.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b249ce47dd75d37621050688d03692597c395a3
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b9e09c0e75a202485641e393b4cbbb25f78e618
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_text_encoder.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_vace.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_vace.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d9b632687b1aa9c746d429c70eb24b4bce0d5a7
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_vace.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_vace.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_vace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00f50fc1f6e545852a303083f5e3458799f81e1a
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_vace.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_vae.cpython-310.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aeb8326bc010ec1ca7c5ed274f4ccdafed96b525
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_vae.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/models/__pycache__/wan_video_vae.cpython-312.pyc b/PusaV1/diffsynth/models/__pycache__/wan_video_vae.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae8f5ce565b823b4ce58544d2a78b0aeffee5151
Binary files /dev/null and b/PusaV1/diffsynth/models/__pycache__/wan_video_vae.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/models/attention.py b/PusaV1/diffsynth/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb90e1ed1a28a0541a8d9df8313997a7d3f14da7
--- /dev/null
+++ b/PusaV1/diffsynth/models/attention.py
@@ -0,0 +1,89 @@
+import torch
+from einops import rearrange
+
+
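+# Naive scaled dot-product attention (softmax(QK^T / sqrt(d)) V); used as a fallback when an attention bias must be added explicitly.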
+def low_version_attention(query, key, value, attn_bias=None):
+ scale = 1 / query.shape[-1] ** 0.5
+ query = query * scale
+ attn = torch.matmul(query, key.transpose(-2, -1))
+ if attn_bias is not None:
+ attn = attn + attn_bias
+ attn = attn.softmax(-1)
+ return attn @ value
+
+
+class Attention(torch.nn.Module):
+
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
+ super().__init__()
+ dim_inner = head_dim * num_heads
+ kv_dim = kv_dim if kv_dim is not None else q_dim
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+
+ def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
+ batch_size = q.shape[0]
+ ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
+ hidden_states = hidden_states + scale * ip_hidden_states
+ return hidden_states
+
+ def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ q = self.to_q(hidden_states)
+ k = self.to_k(encoder_hidden_states)
+ v = self.to_v(encoder_hidden_states)
+
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if qkv_preprocessor is not None:
+ q, k, v = qkv_preprocessor(q, k, v)
+
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+ if ipadapter_kwargs is not None:
+ hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ hidden_states = hidden_states.to(q.dtype)
+
+ hidden_states = self.to_out(hidden_states)
+
+ return hidden_states
+
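+ # Alternative attention path: folds heads into the batch dimension and uses xformers' memory_efficient_attention, falling back to the naive implementation when an attention bias is provided.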
+ def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ q = self.to_q(hidden_states)
+ k = self.to_k(encoder_hidden_states)
+ v = self.to_v(encoder_hidden_states)
+
+ q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
+ k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
+ v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
+
+ if attn_mask is not None:
+ hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
+ else:
+ import xformers.ops as xops
+ hidden_states = xops.memory_efficient_attention(q, k, v)
+ hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
+
+ hidden_states = hidden_states.to(q.dtype)
+ hidden_states = self.to_out(hidden_states)
+
+ return hidden_states
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
+ return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
\ No newline at end of file
diff --git a/PusaV1/diffsynth/models/cog_dit.py b/PusaV1/diffsynth/models/cog_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e93c4c38684c6815c099774dd4e3c8291462cd78
--- /dev/null
+++ b/PusaV1/diffsynth/models/cog_dit.py
@@ -0,0 +1,408 @@
+import torch
+from einops import rearrange, repeat
+from .sd3_dit import TimestepEmbeddings
+from .attention import Attention
+from .utils import load_state_dict_from_folder
+from .tiler import TileWorker2Dto3D
+import numpy as np
+
+
+
+class CogPatchify(torch.nn.Module):
+ def __init__(self, dim_in, dim_out, patch_size) -> None:
+ super().__init__()
+ self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
+ return hidden_states
+
+
+
+class CogAdaLayerNorm(torch.nn.Module):
+ def __init__(self, dim, dim_cond, single=False):
+ super().__init__()
+ self.single = single
+ self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
+
+
+ def forward(self, hidden_states, prompt_emb, emb):
+ emb = self.linear(torch.nn.functional.silu(emb))
+ if self.single:
+ shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
+ hidden_states = self.norm(hidden_states) * (1 + scale) + shift
+ return hidden_states
+ else:
+ shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
+ hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
+ prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
+ return hidden_states, prompt_emb, gate_a, gate_b
+
+
+
+class CogDiTBlock(torch.nn.Module):
+ def __init__(self, dim, dim_cond, num_heads):
+ super().__init__()
+ self.norm1 = CogAdaLayerNorm(dim, dim_cond)
+ self.attn1 = Attention(q_dim=dim, num_heads=num_heads, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
+ self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
+ self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
+
+ self.norm2 = CogAdaLayerNorm(dim, dim_cond)
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+
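+ # Apply rotary position embeddings: each channel pair of the tokens is rotated by the precomputed (cos, sin) frequencies.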
+ def apply_rotary_emb(self, x, freqs_cis):
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+
+ def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
+ q = self.norm_q(q)
+ k = self.norm_k(k)
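+ # RoPE is applied only to the vision tokens; the leading text tokens (the first text_seq_length positions) are left untouched.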
+ q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
+ k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
+ return q, k, v
+
+
+ def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
+ # Attention
+ norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
+ hidden_states, prompt_emb, time_emb
+ )
+ attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ attention_io = self.attn1(
+ attention_io,
+ qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
+ )
+
+ hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
+ prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
+
+ # Feed forward
+ norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
+ hidden_states, prompt_emb, time_emb
+ )
+ ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_io = self.ff(ff_io)
+
+ hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
+ prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
+
+ return hidden_states, prompt_emb
+
+
+
+class CogDiT(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.patchify = CogPatchify(16, 3072, 2)
+ self.time_embedder = TimestepEmbeddings(3072, 512)
+ self.context_embedder = torch.nn.Linear(4096, 3072)
+ self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
+ self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
+ self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
+ self.proj_out = torch.nn.Linear(3072, 64, bias=True)
+
+
+ def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+ def get_3d_rotary_pos_embed(
+ self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+ ):
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+
+ # Temporal frequencies
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+ grid_t = torch.from_numpy(grid_t).float()
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+ # Spatial frequencies for height and width
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+ grid_h = torch.from_numpy(grid_h).float()
+ grid_w = torch.from_numpy(grid_w).float()
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+ # Broadcast and concatenate tensors along specified dimension
+ def broadcast(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = {len(t.shape) for t in tensors}
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*(list(t.shape) for t in tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+ ), "invalid dimensions for broadcastable concatenation"
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+ return torch.cat(tensors, dim=dim)
+
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+ t, h, w, d = freqs.shape
+ freqs = freqs.view(t * h * w, d)
+
+ # Generate sine and cosine components
+ sin = freqs.sin()
+ cos = freqs.cos()
+
+ if use_real:
+ return cos, sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs_cis
+
+
+ def prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ):
+ grid_height = height // 2
+ grid_width = width // 2
+ base_size_width = 720 // (8 * 2)
+ base_size_height = 480 // (8 * 2)
+
+ grid_crops_coords = self.get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
+ embed_dim=64,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ use_real=True,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+
+ def unpatchify(self, hidden_states, height, width):
+ hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
+ return hidden_states
+
+
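+ # Per-voxel blending weights for overlapping tiles: full weight in the tile interior and at true frame boundaries, ramping down toward edges shared with neighbouring tiles.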
+ def build_mask(self, T, H, W, dtype, device, is_bound):
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
+ border_width = (H + W) // 4
+ pad = torch.ones_like(h) * border_width
+ mask = torch.stack([
+ pad if is_bound[0] else t + 1,
+ pad if is_bound[1] else T - t,
+ pad if is_bound[2] else h + 1,
+ pad if is_bound[3] else H - h,
+ pad if is_bound[4] else w + 1,
+ pad if is_bound[5] else W - w
+ ]).min(dim=0).values
+ mask = mask.clip(1, border_width)
+ mask = (mask / border_width).to(dtype=dtype, device=device)
+ mask = rearrange(mask, "T H W -> 1 1 T H W")
+ return mask
+
+
+ def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=90, tile_stride=30):
+ B, C, T, H, W = hidden_states.shape
+ value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
+ weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
+
+ # Split tasks
+ tasks = []
+ for h in range(0, H, tile_stride):
+ for w in range(0, W, tile_stride):
+ if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
+ continue
+ h_, w_ = h + tile_size, w + tile_size
+ if h_ > H: h, h_ = max(H - tile_size, 0), H
+ if w_ > W: w, w_ = max(W - tile_size, 0), W
+ tasks.append((h, h_, w, w_))
+
+ # Run
+ for hl, hr, wl, wr in tasks:
+ mask = self.build_mask(
+ value.shape[2], (hr-hl), (wr-wl),
+ hidden_states.dtype, hidden_states.device,
+ is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
+ )
+ model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
+ value[:, :, :, hl:hr, wl:wr] += model_output * mask
+ weight[:, :, :, hl:hr, wl:wr] += mask
+ value = value / weight
+
+ return value
+
+
+ def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
+ if tiled:
+ return TileWorker2Dto3D().tiled_forward(
+ forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
+ model_input=hidden_states,
+ tile_size=tile_size, tile_stride=tile_stride,
+ tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
+ computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
+ )
+ num_frames, height, width = hidden_states.shape[-3:]
+ if image_rotary_emb is None:
+ image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
+ hidden_states = self.patchify(hidden_states)
+ time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
+ prompt_emb = self.context_embedder(prompt_emb)
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ for block in self.blocks:
+ if self.training and use_gradient_checkpointing:
+ hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states, prompt_emb, time_emb, image_rotary_emb,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
+
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, prompt_emb.shape[1]:]
+ hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = self.unpatchify(hidden_states, height, width)
+
+ return hidden_states
+
+
+ @staticmethod
+ def state_dict_converter():
+ return CogDiTStateDictConverter()
+
+
+ @staticmethod
+ def from_pretrained(file_path, torch_dtype=torch.bfloat16):
+ model = CogDiT().to(torch_dtype)
+ state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
+ state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
+ model.load_state_dict(state_dict)
+ return model
+
+
+
+class CogDiTStateDictConverter:
+ def __init__(self):
+ pass
+
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "patch_embed.proj.weight": "patchify.proj.weight",
+ "patch_embed.proj.bias": "patchify.proj.bias",
+ "patch_embed.text_proj.weight": "context_embedder.weight",
+ "patch_embed.text_proj.bias": "context_embedder.bias",
+ "time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
+ "time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
+ "time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
+ "time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
+
+ "norm_final.weight": "norm_final.weight",
+ "norm_final.bias": "norm_final.bias",
+ "norm_out.linear.weight": "norm_out.linear.weight",
+ "norm_out.linear.bias": "norm_out.linear.bias",
+ "norm_out.norm.weight": "norm_out.norm.weight",
+ "norm_out.norm.bias": "norm_out.norm.bias",
+ "proj_out.weight": "proj_out.weight",
+ "proj_out.bias": "proj_out.bias",
+ }
+ suffix_dict = {
+ "norm1.linear.weight": "norm1.linear.weight",
+ "norm1.linear.bias": "norm1.linear.bias",
+ "norm1.norm.weight": "norm1.norm.weight",
+ "norm1.norm.bias": "norm1.norm.bias",
+ "attn1.norm_q.weight": "norm_q.weight",
+ "attn1.norm_q.bias": "norm_q.bias",
+ "attn1.norm_k.weight": "norm_k.weight",
+ "attn1.norm_k.bias": "norm_k.bias",
+ "attn1.to_q.weight": "attn1.to_q.weight",
+ "attn1.to_q.bias": "attn1.to_q.bias",
+ "attn1.to_k.weight": "attn1.to_k.weight",
+ "attn1.to_k.bias": "attn1.to_k.bias",
+ "attn1.to_v.weight": "attn1.to_v.weight",
+ "attn1.to_v.bias": "attn1.to_v.bias",
+ "attn1.to_out.0.weight": "attn1.to_out.weight",
+ "attn1.to_out.0.bias": "attn1.to_out.bias",
+ "norm2.linear.weight": "norm2.linear.weight",
+ "norm2.linear.bias": "norm2.linear.bias",
+ "norm2.norm.weight": "norm2.norm.weight",
+ "norm2.norm.bias": "norm2.norm.bias",
+ "ff.net.0.proj.weight": "ff.0.weight",
+ "ff.net.0.proj.bias": "ff.0.bias",
+ "ff.net.2.weight": "ff.2.weight",
+ "ff.net.2.bias": "ff.2.bias",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ if name == "patch_embed.proj.weight":
+ param = param.unsqueeze(2)
+ state_dict_[rename_dict[name]] = param
+ else:
+ names = name.split(".")
+ if names[0] == "transformer_blocks":
+ suffix = ".".join(names[2:])
+ state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
+ return state_dict_
+
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
diff --git a/PusaV1/diffsynth/models/cog_vae.py b/PusaV1/diffsynth/models/cog_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..24ab3b3f37e111ffabf7d56c582637c6dc1c80b8
--- /dev/null
+++ b/PusaV1/diffsynth/models/cog_vae.py
@@ -0,0 +1,518 @@
+import torch
+from einops import rearrange, repeat
+from .tiler import TileWorker2Dto3D
+
+
+
+class Downsample3D(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 2,
+ padding: int = 0,
+ compress_time: bool = False,
+ ):
+ super().__init__()
+
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
+ if self.compress_time:
+ batch_size, channels, frames, height, width = x.shape
+
+ # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
+ x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
+
+ if x.shape[-1] % 2 == 1:
+ x_first, x_rest = x[..., 0], x[..., 1:]
+ if x_rest.shape[-1] > 0:
+ # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
+ x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
+ # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+ else:
+ # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
+ x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+ # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+ # Pad the tensor
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ batch_size, channels, frames, height, width = x.shape
+ # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
+ x = self.conv(x)
+ # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
+ x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+ return x
+
+
+
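+# Upsamples spatially by 2x via nearest-neighbour interpolation; when compress_time is set and there is more than one frame, the temporal axis is doubled as well (the first frame is handled separately for odd lengths).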
+class Upsample3D(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ compress_time: bool = False,
+ ) -> None:
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
+ if self.compress_time:
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
+ # split first frame
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
+
+ x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
+ x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
+ x_first = x_first[:, :, None, :, :]
+ inputs = torch.cat([x_first, x_rest], dim=2)
+ elif inputs.shape[2] > 1:
+ inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
+ else:
+ inputs = inputs.squeeze(2)
+ inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs[:, :, None, :, :]
+ else:
+ # only interpolate 2D
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
+
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ inputs = self.conv(inputs)
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ return inputs
+
+
+
+class CogVideoXSpatialNorm3D(torch.nn.Module):
+ def __init__(self, f_channels, zq_channels, groups):
+ super().__init__()
+ self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
+ self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+ self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
+ z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
+ z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
+ zq = torch.cat([z_first, z_rest], dim=2)
+ else:
+ zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
+
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+
+class Resnet3DBlock(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
+ super().__init__()
+ self.nonlinearity = torch.nn.SiLU()
+ if spatial_norm_dim is None:
+ self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
+ self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
+ else:
+ self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
+ self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
+
+ self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
+
+ self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
+
+ if in_channels != out_channels:
+ if use_conv_shortcut:
+ self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
+ else:
+ self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
+ else:
+ self.conv_shortcut = lambda x: x
+
+
+ def forward(self, hidden_states, zq):
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ hidden_states = hidden_states + self.conv_shortcut(residual)
+
+ return hidden_states
+
+
+
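+# Conv3d that caches the last two temporal slices of its input, so consecutive chunks see the preceding frames; the very first chunk uses replicate padding of its first frame.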
+class CachedConv3d(torch.nn.Conv3d):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+ super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.cached_tensor = None
+
+
+ def clear_cache(self):
+ self.cached_tensor = None
+
+
+ def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
+ if use_cache:
+ if self.cached_tensor is None:
+ self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
+ input = torch.concat([self.cached_tensor, input], dim=2)
+ self.cached_tensor = input[:, :, -2:]
+ return super().forward(input)
+
+
+
+class CogVAEDecoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.7
+ self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
+
+ self.blocks = torch.nn.ModuleList([
+ Resnet3DBlock(512, 512, 16, 32),
+ Resnet3DBlock(512, 512, 16, 32),
+ Resnet3DBlock(512, 512, 16, 32),
+ Resnet3DBlock(512, 512, 16, 32),
+ Resnet3DBlock(512, 512, 16, 32),
+ Resnet3DBlock(512, 512, 16, 32),
+ Upsample3D(512, 512, compress_time=True),
+ Resnet3DBlock(512, 256, 16, 32),
+ Resnet3DBlock(256, 256, 16, 32),
+ Resnet3DBlock(256, 256, 16, 32),
+ Resnet3DBlock(256, 256, 16, 32),
+ Upsample3D(256, 256, compress_time=True),
+ Resnet3DBlock(256, 256, 16, 32),
+ Resnet3DBlock(256, 256, 16, 32),
+ Resnet3DBlock(256, 256, 16, 32),
+ Resnet3DBlock(256, 256, 16, 32),
+ Upsample3D(256, 256, compress_time=False),
+ Resnet3DBlock(256, 128, 16, 32),
+ Resnet3DBlock(128, 128, 16, 32),
+ Resnet3DBlock(128, 128, 16, 32),
+ Resnet3DBlock(128, 128, 16, 32),
+ ])
+
+ self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
+
+
+ def forward(self, sample):
+ sample = sample / self.scaling_factor
+ hidden_states = self.conv_in(sample)
+
+ for block in self.blocks:
+ hidden_states = block(hidden_states, sample)
+
+ hidden_states = self.norm_out(hidden_states, sample)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+ def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
+ if tiled:
+ B, C, T, H, W = sample.shape
+ return TileWorker2Dto3D().tiled_forward(
+ forward_fn=lambda x: self.decode_small_video(x),
+ model_input=sample,
+ tile_size=tile_size, tile_stride=tile_stride,
+ tile_device=sample.device, tile_dtype=sample.dtype,
+ computation_device=sample.device, computation_dtype=sample.dtype,
+ scales=(3/16, (T//2*8+T%2)/T, 8, 8),
+ progress_bar=progress_bar
+ )
+ else:
+ return self.decode_small_video(sample)
+
+
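+ # Decode the latent video two latent frames at a time (three for the first chunk when T is odd); CachedConv3d carries temporal context across chunks, and the caches are cleared afterwards.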
+ def decode_small_video(self, sample):
+ B, C, T, H, W = sample.shape
+ computation_device = self.conv_in.weight.device
+ computation_dtype = self.conv_in.weight.dtype
+ value = []
+ for i in range(T//2):
+ tl = i*2 + T%2 - (T%2 and i==0)
+ tr = i*2 + 2 + T%2
+ model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
+ model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
+ value.append(model_output)
+ value = torch.concat(value, dim=2)
+ for name, module in self.named_modules():
+ if isinstance(module, CachedConv3d):
+ module.clear_cache()
+ return value
+
+
+ @staticmethod
+ def state_dict_converter():
+ return CogVAEDecoderStateDictConverter()
+
+
+
+class CogVAEEncoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.7
+ self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
+
+ self.blocks = torch.nn.ModuleList([
+ Resnet3DBlock(128, 128, None, 32),
+ Resnet3DBlock(128, 128, None, 32),
+ Resnet3DBlock(128, 128, None, 32),
+ Downsample3D(128, 128, compress_time=True),
+ Resnet3DBlock(128, 256, None, 32),
+ Resnet3DBlock(256, 256, None, 32),
+ Resnet3DBlock(256, 256, None, 32),
+ Downsample3D(256, 256, compress_time=True),
+ Resnet3DBlock(256, 256, None, 32),
+ Resnet3DBlock(256, 256, None, 32),
+ Resnet3DBlock(256, 256, None, 32),
+ Downsample3D(256, 256, compress_time=False),
+ Resnet3DBlock(256, 512, None, 32),
+ Resnet3DBlock(512, 512, None, 32),
+ Resnet3DBlock(512, 512, None, 32),
+ Resnet3DBlock(512, 512, None, 32),
+ Resnet3DBlock(512, 512, None, 32),
+ ])
+
+ self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
+
+
+ def forward(self, sample):
+ hidden_states = self.conv_in(sample)
+
+ for block in self.blocks:
+ hidden_states = block(hidden_states, sample)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)[:, :16]
+ hidden_states = hidden_states * self.scaling_factor
+
+ return hidden_states
+
+
+ def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
+ if tiled:
+ B, C, T, H, W = sample.shape
+ return TileWorker2Dto3D().tiled_forward(
+ forward_fn=lambda x: self.encode_small_video(x),
+ model_input=sample,
+ tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
+ tile_device=sample.device, tile_dtype=sample.dtype,
+ computation_device=sample.device, computation_dtype=sample.dtype,
+ scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
+ progress_bar=progress_bar
+ )
+ else:
+ return self.encode_small_video(sample)
+
+
+ def encode_small_video(self, sample):
+ B, C, T, H, W = sample.shape
+ computation_device = self.conv_in.weight.device
+ computation_dtype = self.conv_in.weight.dtype
+ value = []
+ for i in range(T//8):
+ t = i*8 + T%2 - (T%2 and i==0)
+ t_ = i*8 + 8 + T%2
+ model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
+ model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
+ value.append(model_output)
+ value = torch.concat(value, dim=2)
+ for name, module in self.named_modules():
+ if isinstance(module, CachedConv3d):
+ module.clear_cache()
+ return value
+
+
+ @staticmethod
+ def state_dict_converter():
+ return CogVAEEncoderStateDictConverter()
+
+
+
+class CogVAEEncoderStateDictConverter:
+ def __init__(self):
+ pass
+
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "encoder.conv_in.conv.weight": "conv_in.weight",
+ "encoder.conv_in.conv.bias": "conv_in.bias",
+ "encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
+ "encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
+ "encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
+ "encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
+ "encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
+ "encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
+ "encoder.norm_out.weight": "norm_out.weight",
+ "encoder.norm_out.bias": "norm_out.bias",
+ "encoder.conv_out.conv.weight": "conv_out.weight",
+ "encoder.conv_out.conv.bias": "conv_out.bias",
+ }
+ prefix_dict = {
+ "encoder.down_blocks.0.resnets.0.": "blocks.0.",
+ "encoder.down_blocks.0.resnets.1.": "blocks.1.",
+ "encoder.down_blocks.0.resnets.2.": "blocks.2.",
+ "encoder.down_blocks.1.resnets.0.": "blocks.4.",
+ "encoder.down_blocks.1.resnets.1.": "blocks.5.",
+ "encoder.down_blocks.1.resnets.2.": "blocks.6.",
+ "encoder.down_blocks.2.resnets.0.": "blocks.8.",
+ "encoder.down_blocks.2.resnets.1.": "blocks.9.",
+ "encoder.down_blocks.2.resnets.2.": "blocks.10.",
+ "encoder.down_blocks.3.resnets.0.": "blocks.12.",
+ "encoder.down_blocks.3.resnets.1.": "blocks.13.",
+ "encoder.down_blocks.3.resnets.2.": "blocks.14.",
+ "encoder.mid_block.resnets.0.": "blocks.15.",
+ "encoder.mid_block.resnets.1.": "blocks.16.",
+ }
+ suffix_dict = {
+ "norm1.norm_layer.weight": "norm1.norm_layer.weight",
+ "norm1.norm_layer.bias": "norm1.norm_layer.bias",
+ "norm1.conv_y.conv.weight": "norm1.conv_y.weight",
+ "norm1.conv_y.conv.bias": "norm1.conv_y.bias",
+ "norm1.conv_b.conv.weight": "norm1.conv_b.weight",
+ "norm1.conv_b.conv.bias": "norm1.conv_b.bias",
+ "norm2.norm_layer.weight": "norm2.norm_layer.weight",
+ "norm2.norm_layer.bias": "norm2.norm_layer.bias",
+ "norm2.conv_y.conv.weight": "norm2.conv_y.weight",
+ "norm2.conv_y.conv.bias": "norm2.conv_y.bias",
+ "norm2.conv_b.conv.weight": "norm2.conv_b.weight",
+ "norm2.conv_b.conv.bias": "norm2.conv_b.bias",
+ "conv1.conv.weight": "conv1.weight",
+ "conv1.conv.bias": "conv1.bias",
+ "conv2.conv.weight": "conv2.weight",
+ "conv2.conv.bias": "conv2.bias",
+ "conv_shortcut.weight": "conv_shortcut.weight",
+ "conv_shortcut.bias": "conv_shortcut.bias",
+ "norm1.weight": "norm1.weight",
+ "norm1.bias": "norm1.bias",
+ "norm2.weight": "norm2.weight",
+ "norm2.bias": "norm2.bias",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ state_dict_[rename_dict[name]] = param
+ else:
+ for prefix in prefix_dict:
+ if name.startswith(prefix):
+ suffix = name[len(prefix):]
+ state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
+ return state_dict_
+
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
+
+
+
+class CogVAEDecoderStateDictConverter:
+ def __init__(self):
+ pass
+
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "decoder.conv_in.conv.weight": "conv_in.weight",
+ "decoder.conv_in.conv.bias": "conv_in.bias",
+ "decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
+ "decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
+ "decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
+ "decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
+ "decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
+ "decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
+ "decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
+ "decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
+ "decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
+ "decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
+ "decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
+ "decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
+ "decoder.conv_out.conv.weight": "conv_out.weight",
+ "decoder.conv_out.conv.bias": "conv_out.bias"
+ }
+ prefix_dict = {
+ "decoder.mid_block.resnets.0.": "blocks.0.",
+ "decoder.mid_block.resnets.1.": "blocks.1.",
+ "decoder.up_blocks.0.resnets.0.": "blocks.2.",
+ "decoder.up_blocks.0.resnets.1.": "blocks.3.",
+ "decoder.up_blocks.0.resnets.2.": "blocks.4.",
+ "decoder.up_blocks.0.resnets.3.": "blocks.5.",
+ "decoder.up_blocks.1.resnets.0.": "blocks.7.",
+ "decoder.up_blocks.1.resnets.1.": "blocks.8.",
+ "decoder.up_blocks.1.resnets.2.": "blocks.9.",
+ "decoder.up_blocks.1.resnets.3.": "blocks.10.",
+ "decoder.up_blocks.2.resnets.0.": "blocks.12.",
+ "decoder.up_blocks.2.resnets.1.": "blocks.13.",
+ "decoder.up_blocks.2.resnets.2.": "blocks.14.",
+ "decoder.up_blocks.2.resnets.3.": "blocks.15.",
+ "decoder.up_blocks.3.resnets.0.": "blocks.17.",
+ "decoder.up_blocks.3.resnets.1.": "blocks.18.",
+ "decoder.up_blocks.3.resnets.2.": "blocks.19.",
+ "decoder.up_blocks.3.resnets.3.": "blocks.20.",
+ }
+ suffix_dict = {
+ "norm1.norm_layer.weight": "norm1.norm_layer.weight",
+ "norm1.norm_layer.bias": "norm1.norm_layer.bias",
+ "norm1.conv_y.conv.weight": "norm1.conv_y.weight",
+ "norm1.conv_y.conv.bias": "norm1.conv_y.bias",
+ "norm1.conv_b.conv.weight": "norm1.conv_b.weight",
+ "norm1.conv_b.conv.bias": "norm1.conv_b.bias",
+ "norm2.norm_layer.weight": "norm2.norm_layer.weight",
+ "norm2.norm_layer.bias": "norm2.norm_layer.bias",
+ "norm2.conv_y.conv.weight": "norm2.conv_y.weight",
+ "norm2.conv_y.conv.bias": "norm2.conv_y.bias",
+ "norm2.conv_b.conv.weight": "norm2.conv_b.weight",
+ "norm2.conv_b.conv.bias": "norm2.conv_b.bias",
+ "conv1.conv.weight": "conv1.weight",
+ "conv1.conv.bias": "conv1.bias",
+ "conv2.conv.weight": "conv2.weight",
+ "conv2.conv.bias": "conv2.bias",
+ "conv_shortcut.weight": "conv_shortcut.weight",
+ "conv_shortcut.bias": "conv_shortcut.bias",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ state_dict_[rename_dict[name]] = param
+ else:
+ for prefix in prefix_dict:
+ if name.startswith(prefix):
+ suffix = name[len(prefix):]
+ state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
+ return state_dict_
+
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
+
diff --git a/PusaV1/diffsynth/models/downloader.py b/PusaV1/diffsynth/models/downloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c726f628fdbdac4cba79cb4c62475506df76b20
--- /dev/null
+++ b/PusaV1/diffsynth/models/downloader.py
@@ -0,0 +1,111 @@
+from huggingface_hub import hf_hub_download
+from modelscope import snapshot_download
+import os, shutil
+from typing_extensions import Literal, TypeAlias
+from typing import List
+from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
+
+
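+# Download a single file from ModelScope into local_dir, move it out of the repo-style subdirectory, and skip it if it is already present.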
+def download_from_modelscope(model_id, origin_file_path, local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+ file_name = os.path.basename(origin_file_path)
+ if file_name in os.listdir(local_dir):
+ print(f" {file_name} has been already in {local_dir}.")
+ else:
+ print(f" Start downloading {os.path.join(local_dir, file_name)}")
+ snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
+ target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
+ if downloaded_file_path != target_file_path:
+ shutil.move(downloaded_file_path, target_file_path)
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
+
+
+def download_from_huggingface(model_id, origin_file_path, local_dir):
+ os.makedirs(local_dir, exist_ok=True)
+ file_name = os.path.basename(origin_file_path)
+ if file_name in os.listdir(local_dir):
+ print(f" {file_name} has been already in {local_dir}.")
+ else:
+ print(f" Start downloading {os.path.join(local_dir, file_name)}")
+ hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
+ target_file_path = os.path.join(local_dir, file_name)
+ if downloaded_file_path != target_file_path:
+ shutil.move(downloaded_file_path, target_file_path)
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
+
+
+Preset_model_website: TypeAlias = Literal[
+ "HuggingFace",
+ "ModelScope",
+]
+website_to_preset_models = {
+ "HuggingFace": preset_models_on_huggingface,
+ "ModelScope": preset_models_on_modelscope,
+}
+website_to_download_fn = {
+ "HuggingFace": download_from_huggingface,
+ "ModelScope": download_from_modelscope,
+}
+
+
+def download_customized_models(
+ model_id,
+ origin_file_path,
+ local_dir,
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+):
+ downloaded_files = []
+ for website in downloading_priority:
+ # Check if the file is downloaded.
+ file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
+ if file_to_download in downloaded_files:
+ continue
+ # Download
+ website_to_download_fn[website](model_id, origin_file_path, local_dir)
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
+ downloaded_files.append(file_to_download)
+ return downloaded_files
+
+
+def download_models(
+ model_id_list: List[Preset_model_id] = [],
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+):
+ print(f"Downloading models: {model_id_list}")
+ downloaded_files = []
+ load_files = []
+
+ for model_id in model_id_list:
+ for website in downloading_priority:
+ if model_id in website_to_preset_models[website]:
+
+ # Parse model metadata
+ model_metadata = website_to_preset_models[website][model_id]
+ if isinstance(model_metadata, list):
+ file_data = model_metadata
+ else:
+ file_data = model_metadata.get("file_list", [])
+
+ # Try downloading the model from this website.
+ model_files = []
+ for model_id, origin_file_path, local_dir in file_data:
+ # Check if the file is downloaded.
+ file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
+ if file_to_download in downloaded_files:
+ continue
+ # Download
+ website_to_download_fn[website](model_id, origin_file_path, local_dir)
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
+ downloaded_files.append(file_to_download)
+ model_files.append(file_to_download)
+
+ # If the model is successfully downloaded, break.
+ if len(model_files) > 0:
+ if isinstance(model_metadata, dict) and "load_path" in model_metadata:
+ model_files = model_metadata["load_path"]
+ load_files.extend(model_files)
+ break
+
+ return load_files
diff --git a/PusaV1/diffsynth/models/flux_controlnet.py b/PusaV1/diffsynth/models/flux_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc3dc000066af351da775131f07214411c87b1a
--- /dev/null
+++ b/PusaV1/diffsynth/models/flux_controlnet.py
@@ -0,0 +1,329 @@
+import torch
+from einops import rearrange, repeat
+from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
+from .utils import hash_state_dict_keys, init_weights_on_device
+
+
+
+class FluxControlNet(torch.nn.Module):
+ def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
+ super().__init__()
+ self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
+ self.time_embedder = TimestepEmbeddings(256, 3072)
+ self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
+ self.context_embedder = torch.nn.Linear(4096, 3072)
+ self.x_embedder = torch.nn.Linear(64, 3072)
+
+ self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
+ self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
+
+ self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
+ self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
+
+ self.mode_dict = mode_dict
+ self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
+ self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
+
+
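+ # Build 3-channel position ids for the 2x2-patchified latent grid: channel 0 stays zero, channels 1 and 2 hold the row and column indices.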
+ def prepare_image_ids(self, latents):
+ batch_size, _, height, width = latents.shape
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
+ latent_image_ids = latent_image_ids.reshape(
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+ latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
+
+ return latent_image_ids
+
+
+ def patchify(self, hidden_states):
+ hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
+ return hidden_states
+
+
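+ # The ControlNet has fewer blocks than the base DiT (19 joint / 38 single blocks), so each ControlNet residual is repeated across a contiguous run of base blocks.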
+ def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
+ if len(res_stack) == 0:
+ return [torch.zeros_like(hidden_states)] * num_blocks
+ interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
+ aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
+ return aligned_res_stack
+
+
+ def forward(
+ self,
+ hidden_states,
+ controlnet_conditioning,
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
+ processor_id=None,
+ tiled=False, tile_size=128, tile_stride=64,
+ **kwargs
+ ):
+ if image_ids is None:
+ image_ids = self.prepare_image_ids(hidden_states)
+
+ conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
+ if self.guidance_embedder is not None:
+ guidance = guidance * 1000
+ conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
+ prompt_emb = self.context_embedder(prompt_emb)
+ if self.controlnet_mode_embedder is not None: # Different from FluxDiT
+ processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
+ processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
+ prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
+ text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+
+ hidden_states = self.patchify(hidden_states)
+ hidden_states = self.x_embedder(hidden_states)
+ controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
+
+ controlnet_res_stack = []
+ for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+ controlnet_res_stack.append(controlnet_block(hidden_states))
+
+ controlnet_single_res_stack = []
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+ for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+ controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
+
+ controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
+ controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
+
+ return controlnet_res_stack, controlnet_single_res_stack
+
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxControlNetStateDictConverter()
+
+ def quantize(self):
+ def cast_to(weight, dtype=None, device=None, copy=False):
+ if device is None or weight.device == device:
+ if not copy:
+ if dtype is None or weight.dtype == dtype:
+ return weight
+ return weight.to(dtype=dtype, copy=copy)
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
+ r.copy_(weight)
+ return r
+
+ def cast_weight(s, input=None, dtype=None, device=None):
+ if input is not None:
+ if dtype is None:
+ dtype = input.dtype
+ if device is None:
+ device = input.device
+ weight = cast_to(s.weight, dtype, device)
+ return weight
+
+ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+ if input is not None:
+ if dtype is None:
+ dtype = input.dtype
+ if bias_dtype is None:
+ bias_dtype = dtype
+ if device is None:
+ device = input.device
+ bias = None
+ weight = cast_to(s.weight, dtype, device)
+ bias = cast_to(s.bias, bias_dtype, device)
+ return weight, bias
+
+ class quantized_layer:
+ class QLinear(torch.nn.Linear):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input, **kwargs):
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.linear(input, weight, bias)
+
+ class QRMSNorm(torch.nn.Module):
+ def __init__(self, module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, hidden_states, **kwargs):
+ weight = cast_weight(self.module, hidden_states)
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
+ hidden_states = hidden_states.to(input_dtype) * weight
+ return hidden_states
+
+ class QEmbedding(torch.nn.Embedding):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input, **kwargs):
+ weight = cast_weight(self, input)
+ return torch.nn.functional.embedding(
+ input, weight, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+ def replace_layer(model):
+ for name, module in model.named_children():
+ if isinstance(module, quantized_layer.QRMSNorm):
+ continue
+ if isinstance(module, torch.nn.Linear):
+ with init_weights_on_device():
+ new_layer = quantized_layer.QLinear(module.in_features, module.out_features)
+ new_layer.weight = module.weight
+ if module.bias is not None:
+ new_layer.bias = module.bias
+ setattr(model, name, new_layer)
+ elif isinstance(module, RMSNorm):
+ if hasattr(module,"quantized"):
+ continue
+ module.quantized = True
+ new_layer = quantized_layer.QRMSNorm(module)
+ setattr(model, name, new_layer)
+ elif isinstance(module, torch.nn.Embedding):
+ rows, cols = module.weight.shape
+ new_layer = quantized_layer.QEmbedding(
+ num_embeddings=rows,
+ embedding_dim=cols,
+ _weight=module.weight,
+ # _freeze=module.freeze,
+ padding_idx=module.padding_idx,
+ max_norm=module.max_norm,
+ norm_type=module.norm_type,
+ scale_grad_by_freq=module.scale_grad_by_freq,
+ sparse=module.sparse)
+ setattr(model, name, new_layer)
+ else:
+ replace_layer(module)
+
+ replace_layer(self)
+
+
+
+class FluxControlNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ hash_value = hash_state_dict_keys(state_dict)
+ global_rename_dict = {
+ "context_embedder": "context_embedder",
+ "x_embedder": "x_embedder",
+ "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
+ "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
+ "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
+ "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
+ "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
+ "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
+ "norm_out.linear": "final_norm_out.linear",
+ "proj_out": "final_proj_out",
+ }
+ rename_dict = {
+ "proj_out": "proj_out",
+ "norm1.linear": "norm1_a.linear",
+ "norm1_context.linear": "norm1_b.linear",
+ "attn.to_q": "attn.a_to_q",
+ "attn.to_k": "attn.a_to_k",
+ "attn.to_v": "attn.a_to_v",
+ "attn.to_out.0": "attn.a_to_out",
+ "attn.add_q_proj": "attn.b_to_q",
+ "attn.add_k_proj": "attn.b_to_k",
+ "attn.add_v_proj": "attn.b_to_v",
+ "attn.to_add_out": "attn.b_to_out",
+ "ff.net.0.proj": "ff_a.0",
+ "ff.net.2": "ff_a.2",
+ "ff_context.net.0.proj": "ff_b.0",
+ "ff_context.net.2": "ff_b.2",
+ "attn.norm_q": "attn.norm_q_a",
+ "attn.norm_k": "attn.norm_k_a",
+ "attn.norm_added_q": "attn.norm_q_b",
+ "attn.norm_added_k": "attn.norm_k_b",
+ }
+ rename_dict_single = {
+ "attn.to_q": "a_to_q",
+ "attn.to_k": "a_to_k",
+ "attn.to_v": "a_to_v",
+ "attn.norm_q": "norm_q_a",
+ "attn.norm_k": "norm_k_a",
+ "norm.linear": "norm.linear",
+ "proj_mlp": "proj_in_besides_attn",
+ "proj_out": "proj_out",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name.endswith(".weight") or name.endswith(".bias"):
+ suffix = ".weight" if name.endswith(".weight") else ".bias"
+ prefix = name[:-len(suffix)]
+ if prefix in global_rename_dict:
+ state_dict_[global_rename_dict[prefix] + suffix] = param
+ elif prefix.startswith("transformer_blocks."):
+ names = prefix.split(".")
+ names[0] = "blocks"
+ middle = ".".join(names[2:])
+ if middle in rename_dict:
+ name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
+ state_dict_[name_] = param
+ elif prefix.startswith("single_transformer_blocks."):
+ names = prefix.split(".")
+ names[0] = "single_blocks"
+ middle = ".".join(names[2:])
+ if middle in rename_dict_single:
+ name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
+ state_dict_[name_] = param
+ else:
+ state_dict_[name] = param
+ else:
+ state_dict_[name] = param
+ for name in list(state_dict_.keys()):
+ if ".proj_in_besides_attn." in name:
+ name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
+ param = torch.concat([
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
+ state_dict_[name],
+ ], dim=0)
+ state_dict_[name_] = param
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
+ state_dict_.pop(name)
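+ # Fuse the separate q/k/v projections of each stream ("a" = image, "b" = text)
+ # into the single *_to_qkv parameters this implementation expects.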
+ for name in list(state_dict_.keys()):
+ for component in ["a", "b"]:
+ if f".{component}_to_q." in name:
+ name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
+ param = torch.concat([
+ state_dict_[name],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+ ], dim=0)
+ state_dict_[name_] = param
+ state_dict_.pop(name)
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
+ if hash_value == "78d18b9101345ff695f312e7e62538c0":
+ extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
+ elif hash_value == "b001c89139b5f053c715fe772362dd2a":
+ extra_kwargs = {"num_single_blocks": 0}
+ elif hash_value == "52357cb26250681367488a8954c271e8":
+ extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
+ elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
+ extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
+ elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
+ extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
+ else:
+ extra_kwargs = {}
+ return state_dict_, extra_kwargs
+
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
diff --git a/PusaV1/diffsynth/models/flux_dit.py b/PusaV1/diffsynth/models/flux_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d3100d672da6b5e71b73409480a38c36600f0a9
--- /dev/null
+++ b/PusaV1/diffsynth/models/flux_dit.py
@@ -0,0 +1,742 @@
+import torch
+from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
+from einops import rearrange
+from .tiler import TileWorker
+from .utils import init_weights_on_device
+
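+# IP-Adapter hook: runs attention of the image queries over the adapter-provided
+# key/value tokens and adds the result to the regular attention output, scaled by `scale`.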
+def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
+ batch_size, num_tokens = hidden_states.shape[0:2]
+ ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
+ hidden_states = hidden_states + scale * ip_hidden_states
+ return hidden_states
+
+
+class RoPEEmbedding(torch.nn.Module):
+ def __init__(self, dim, theta, axes_dim):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+
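+ # Per-axis rotary table: for positions of shape (B, N) and an even `dim`, returns
+ # (B, N, dim // 2, 2, 2), i.e. one [[cos, -sin], [sin, cos]] rotation matrix per
+ # frequency. `forward` concatenates the per-axis tables, so axes_dim must sum to
+ # the attention head dimension (16 + 56 + 56 = 128 here).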
+ def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ assert dim % 2 == 0, "The dimension must be even."
+
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+
+ batch_size, seq_length = pos.shape
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ cos_out = torch.cos(out)
+ sin_out = torch.sin(out)
+
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+ return out.float()
+
+
+ def forward(self, ids):
+ n_axes = ids.shape[-1]
+ emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
+ return emb.unsqueeze(1)
+
+
+
+class FluxJointAttention(torch.nn.Module):
+ def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.only_out_a = only_out_a
+
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
+ self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
+
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
+ self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
+ self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
+
+ self.a_to_out = torch.nn.Linear(dim_a, dim_a)
+ if not only_out_a:
+ self.b_to_out = torch.nn.Linear(dim_b, dim_b)
+
+
+ def apply_rope(self, xq, xk, freqs_cis):
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+ batch_size = hidden_states_a.shape[0]
+
+ # Part A
+ qkv_a = self.a_to_qkv(hidden_states_a)
+ qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+ q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
+ q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
+
+ # Part B
+ qkv_b = self.b_to_qkv(hidden_states_b)
+ qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+ q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
+ q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
+
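+ # Joint attention: text-stream (B) tokens come first, image-stream (A) tokens second,
+ # matching the [text_ids, image_ids] layout used to build image_rotary_emb; the two
+ # streams are split apart again after attention.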
+ q = torch.concat([q_b, q_a], dim=2)
+ k = torch.concat([k_b, k_a], dim=2)
+ v = torch.concat([v_b, v_a], dim=2)
+
+ q, k = self.apply_rope(q, k, image_rotary_emb)
+
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ hidden_states = hidden_states.to(q.dtype)
+ hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
+ if ipadapter_kwargs_list is not None:
+ hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
+ hidden_states_a = self.a_to_out(hidden_states_a)
+ if self.only_out_a:
+ return hidden_states_a
+ else:
+ hidden_states_b = self.b_to_out(hidden_states_b)
+ return hidden_states_a, hidden_states_b
+
+
+
+class FluxJointTransformerBlock(torch.nn.Module):
+ def __init__(self, dim, num_attention_heads):
+ super().__init__()
+ self.norm1_a = AdaLayerNorm(dim)
+ self.norm1_b = AdaLayerNorm(dim)
+
+ self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
+
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_a = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+ self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_b = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+ norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
+
+ # Attention
+ attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
+
+ # Part A
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+ # Part B
+ hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
+ norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
+ hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
+
+ return hidden_states_a, hidden_states_b
+
+
+
+class FluxSingleAttention(torch.nn.Module):
+ def __init__(self, dim_a, dim_b, num_heads, head_dim):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
+
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
+
+
+ def apply_rope(self, xq, xk, freqs_cis):
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+ def forward(self, hidden_states, image_rotary_emb):
+ batch_size = hidden_states.shape[0]
+
+ qkv_a = self.a_to_qkv(hidden_states)
+ qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+ q_a, k_a, v = qkv_a.chunk(3, dim=1)
+ q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
+
+ q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
+
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ hidden_states = hidden_states.to(q.dtype)
+ return hidden_states
+
+
+
+class AdaLayerNormSingle(torch.nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.silu = torch.nn.SiLU()
+ self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+
+ def forward(self, x, emb):
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa
+
+
+
+class FluxSingleTransformerBlock(torch.nn.Module):
+ def __init__(self, dim, num_attention_heads):
+ super().__init__()
+ self.num_heads = num_attention_heads
+ self.head_dim = dim // num_attention_heads
+ self.dim = dim
+
+ self.norm = AdaLayerNormSingle(dim)
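+ # Fused projection: 3 * dim for Q/K/V plus 4 * dim for the MLP hidden states;
+ # proj_out maps the concatenated (attention, MLP) features (5 * dim) back to dim.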
+ self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
+ self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
+ self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
+
+ self.proj_out = torch.nn.Linear(dim * 5, dim)
+
+
+ def apply_rope(self, xq, xk, freqs_cis):
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+ def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+ batch_size = hidden_states.shape[0]
+
+ qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+ q, k, v = qkv.chunk(3, dim=1)
+ q, k = self.norm_q_a(q), self.norm_k_a(k)
+
+ q, k = self.apply_rope(q, k, image_rotary_emb)
+
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ hidden_states = hidden_states.to(q.dtype)
+ if ipadapter_kwargs_list is not None:
+ hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
+ return hidden_states
+
+
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
+ residual = hidden_states_a
+ norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
+ hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
+ attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
+
+ attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
+ mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
+
+ hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
+ hidden_states_a = residual + hidden_states_a
+
+ return hidden_states_a, hidden_states_b
+
+
+
+class AdaLayerNormContinuous(torch.nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.silu = torch.nn.SiLU()
+ self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
+ self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
+
+ def forward(self, x, conditioning):
+ emb = self.linear(self.silu(conditioning))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
+ return x
+
+
+
+class FluxDiT(torch.nn.Module):
+ def __init__(self, disable_guidance_embedder=False):
+ super().__init__()
+ self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
+ self.time_embedder = TimestepEmbeddings(256, 3072)
+ self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
+ self.context_embedder = torch.nn.Linear(4096, 3072)
+ self.x_embedder = torch.nn.Linear(64, 3072)
+
+ self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
+ self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
+
+ self.final_norm_out = AdaLayerNormContinuous(3072)
+ self.final_proj_out = torch.nn.Linear(3072, 64)
+
+
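+ # 2x2 patchification: (B, C, H, W) -> (B, H//2 * W//2, C * 4). With the usual
+ # 16-channel Flux latents this gives 64 features per token, matching x_embedder.
+ # Shape example (16-channel latent assumed): (1, 16, 64, 64) -> (1, 1024, 64).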
+ def patchify(self, hidden_states):
+ hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
+ return hidden_states
+
+
+ def unpatchify(self, hidden_states, height, width):
+ hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
+ return hidden_states
+
+
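+ # 3-channel position ids for the image tokens: channel 0 stays zero (shared with the
+ # text axis), channels 1 and 2 hold the patch row and column indices. Output shape
+ # is (B, H//2 * W//2, 3), one id per RoPE axis.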
+ def prepare_image_ids(self, latents):
+ batch_size, _, height, width = latents.shape
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
+ latent_image_ids = latent_image_ids.reshape(
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+ latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
+
+ return latent_image_ids
+
+
+ def tiled_forward(
+ self,
+ hidden_states,
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
+ tile_size=128, tile_stride=64,
+ **kwargs
+ ):
+ # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
+ hidden_states,
+ tile_size,
+ tile_stride,
+ tile_device=hidden_states.device,
+ tile_dtype=hidden_states.dtype
+ )
+ return hidden_states
+
+
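+ # Entity-control mask: each entity prompt may only exchange attention with the image
+ # tokens covered by its mask, and distinct entity prompts are masked from each other.
+ # Returned as an additive bias (0 where allowed, -inf where blocked).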
+ def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
+ N = len(entity_masks)
+ batch_size = entity_masks[0].shape[0]
+ total_seq_len = N * prompt_seq_len + image_seq_len
+ patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
+ attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
+
+ image_start = N * prompt_seq_len
+ image_end = N * prompt_seq_len + image_seq_len
+ # prompt-image mask
+ for i in range(N):
+ prompt_start = i * prompt_seq_len
+ prompt_end = (i + 1) * prompt_seq_len
+ image_mask = torch.sum(patched_masks[i], dim=-1) > 0
+ image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
+ # prompt update with image
+ attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
+ # image update with prompt
+ attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
+ # prompt-prompt mask
+ for i in range(N):
+ for j in range(N):
+ if i != j:
+ prompt_start_i = i * prompt_seq_len
+ prompt_end_i = (i + 1) * prompt_seq_len
+ prompt_start_j = j * prompt_seq_len
+ prompt_end_j = (j + 1) * prompt_seq_len
+ attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
+
+ attention_mask = attention_mask.float()
+ attention_mask[attention_mask == 0] = float('-inf')
+ attention_mask[attention_mask == 1] = 0
+ return attention_mask
+
+
+ def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
+ repeat_dim = hidden_states.shape[1]
+ max_masks = 0
+ attention_mask = None
+ prompt_embs = [prompt_emb]
+ if entity_masks is not None:
+ # entity_masks
+ batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
+ entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
+ entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
+ # global mask
+ global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
+ entity_masks = entity_masks + [global_mask] # append global to last
+ # attention mask
+ attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
+ attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
+ attention_mask = attention_mask.unsqueeze(1)
+ # entity prompt embeddings: n_masks tensors of shape (b, seq, d)
+ local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
+ prompt_embs = local_embs + prompt_embs # append global to last
+ prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
+ prompt_emb = torch.cat(prompt_embs, dim=1)
+
+ # positional embedding
+ text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+ return prompt_emb, image_rotary_emb, attention_mask
+
+
+ def forward(
+ self,
+ hidden_states,
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
+ tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
+ use_gradient_checkpointing=False,
+ **kwargs
+ ):
+ if tiled:
+ return self.tiled_forward(
+ hidden_states,
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
+ tile_size=tile_size, tile_stride=tile_stride,
+ **kwargs
+ )
+
+ if image_ids is None:
+ image_ids = self.prepare_image_ids(hidden_states)
+
+ conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
+ if self.guidance_embedder is not None:
+ guidance = guidance * 1000
+ conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
+
+ height, width = hidden_states.shape[-2:]
+ hidden_states = self.patchify(hidden_states)
+ hidden_states = self.x_embedder(hidden_states)
+
+ if entity_prompt_emb is not None and entity_masks is not None:
+ prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+ else:
+ prompt_emb = self.context_embedder(prompt_emb)
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+ attention_mask = None
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
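+ # 19 dual-stream (joint) blocks first, then 38 single-stream blocks over the
+ # concatenated [text, image] sequence; gradient checkpointing optionally trades
+ # recomputation for activation memory during training.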
+ for block in self.blocks:
+ if self.training and use_gradient_checkpointing:
+ hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
+
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+ for block in self.single_blocks:
+ if self.training and use_gradient_checkpointing:
+ hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
+ hidden_states = hidden_states[:, prompt_emb.shape[1]:]
+
+ hidden_states = self.final_norm_out(hidden_states, conditioning)
+ hidden_states = self.final_proj_out(hidden_states)
+ hidden_states = self.unpatchify(hidden_states, height, width)
+
+ return hidden_states
+
+
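+ # Replaces Linear / RMSNorm layers with thin wrappers that cast their (possibly
+ # lower-precision) stored weights to the input's dtype and device at call time,
+ # instead of converting the whole model up front.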
+ def quantize(self):
+ def cast_to(weight, dtype=None, device=None, copy=False):
+ if device is None or weight.device == device:
+ if not copy:
+ if dtype is None or weight.dtype == dtype:
+ return weight
+ return weight.to(dtype=dtype, copy=copy)
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
+ r.copy_(weight)
+ return r
+
+ def cast_weight(s, input=None, dtype=None, device=None):
+ if input is not None:
+ if dtype is None:
+ dtype = input.dtype
+ if device is None:
+ device = input.device
+ weight = cast_to(s.weight, dtype, device)
+ return weight
+
+ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+ if input is not None:
+ if dtype is None:
+ dtype = input.dtype
+ if bias_dtype is None:
+ bias_dtype = dtype
+ if device is None:
+ device = input.device
+ bias = None
+ weight = cast_to(s.weight, dtype, device)
+ bias = cast_to(s.bias, bias_dtype, device)
+ return weight, bias
+
+ class quantized_layer:
+ class Linear(torch.nn.Linear):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input, **kwargs):
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.linear(input, weight, bias)
+
+ class RMSNorm(torch.nn.Module):
+ def __init__(self, module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, hidden_states, **kwargs):
+ weight = cast_weight(self.module, hidden_states)
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
+ hidden_states = hidden_states.to(input_dtype) * weight
+ return hidden_states
+
+ def replace_layer(model):
+ for name, module in model.named_children():
+ if isinstance(module, torch.nn.Linear):
+ with init_weights_on_device():
+ new_layer = quantized_layer.Linear(module.in_features, module.out_features)
+ new_layer.weight = module.weight
+ if module.bias is not None:
+ new_layer.bias = module.bias
+ # del module
+ setattr(model, name, new_layer)
+ elif isinstance(module, RMSNorm):
+ if hasattr(module,"quantized"):
+ continue
+ module.quantized= True
+ new_layer = quantized_layer.RMSNorm(module)
+ setattr(model, name, new_layer)
+ else:
+ replace_layer(module)
+
+ replace_layer(self)
+
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxDiTStateDictConverter()
+
+
+class FluxDiTStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ global_rename_dict = {
+ "context_embedder": "context_embedder",
+ "x_embedder": "x_embedder",
+ "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
+ "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
+ "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
+ "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
+ "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
+ "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
+ "norm_out.linear": "final_norm_out.linear",
+ "proj_out": "final_proj_out",
+ }
+ rename_dict = {
+ "proj_out": "proj_out",
+ "norm1.linear": "norm1_a.linear",
+ "norm1_context.linear": "norm1_b.linear",
+ "attn.to_q": "attn.a_to_q",
+ "attn.to_k": "attn.a_to_k",
+ "attn.to_v": "attn.a_to_v",
+ "attn.to_out.0": "attn.a_to_out",
+ "attn.add_q_proj": "attn.b_to_q",
+ "attn.add_k_proj": "attn.b_to_k",
+ "attn.add_v_proj": "attn.b_to_v",
+ "attn.to_add_out": "attn.b_to_out",
+ "ff.net.0.proj": "ff_a.0",
+ "ff.net.2": "ff_a.2",
+ "ff_context.net.0.proj": "ff_b.0",
+ "ff_context.net.2": "ff_b.2",
+ "attn.norm_q": "attn.norm_q_a",
+ "attn.norm_k": "attn.norm_k_a",
+ "attn.norm_added_q": "attn.norm_q_b",
+ "attn.norm_added_k": "attn.norm_k_b",
+ }
+ rename_dict_single = {
+ "attn.to_q": "a_to_q",
+ "attn.to_k": "a_to_k",
+ "attn.to_v": "a_to_v",
+ "attn.norm_q": "norm_q_a",
+ "attn.norm_k": "norm_k_a",
+ "norm.linear": "norm.linear",
+ "proj_mlp": "proj_in_besides_attn",
+ "proj_out": "proj_out",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name.endswith(".weight") or name.endswith(".bias"):
+ suffix = ".weight" if name.endswith(".weight") else ".bias"
+ prefix = name[:-len(suffix)]
+ if prefix in global_rename_dict:
+ state_dict_[global_rename_dict[prefix] + suffix] = param
+ elif prefix.startswith("transformer_blocks."):
+ names = prefix.split(".")
+ names[0] = "blocks"
+ middle = ".".join(names[2:])
+ if middle in rename_dict:
+ name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
+ state_dict_[name_] = param
+ elif prefix.startswith("single_transformer_blocks."):
+ names = prefix.split(".")
+ names[0] = "single_blocks"
+ middle = ".".join(names[2:])
+ if middle in rename_dict_single:
+ name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
+ state_dict_[name_] = param
+ else:
+ pass
+ else:
+ pass
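+ # Merge q/k/v and the MLP input projection of every single block into to_qkv_mlp;
+ # if the checkpoint carries no proj_mlp weights, a zero-filled MLP projection is used.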
+ for name in list(state_dict_.keys()):
+ if "single_blocks." in name and ".a_to_q." in name:
+ mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
+ if mlp is None:
+ mlp = torch.zeros(4 * state_dict_[name].shape[0],
+ *state_dict_[name].shape[1:],
+ dtype=state_dict_[name].dtype)
+ else:
+ state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
+ param = torch.concat([
+ state_dict_.pop(name),
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
+ state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
+ mlp,
+ ], dim=0)
+ name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
+ state_dict_[name_] = param
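+ # Fuse the separate q/k/v projections of each stream into single *_to_qkv parameters.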
+ for name in list(state_dict_.keys()):
+ for component in ["a", "b"]:
+ if f".{component}_to_q." in name:
+ name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
+ param = torch.concat([
+ state_dict_[name],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
+ ], dim=0)
+ state_dict_[name_] = param
+ state_dict_.pop(name)
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
+ "time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
+ "time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
+ "time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
+ "txt_in.bias": "context_embedder.bias",
+ "txt_in.weight": "context_embedder.weight",
+ "vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
+ "vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
+ "vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
+ "vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
+ "final_layer.linear.bias": "final_proj_out.bias",
+ "final_layer.linear.weight": "final_proj_out.weight",
+ "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
+ "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
+ "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
+ "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
+ "img_in.bias": "x_embedder.bias",
+ "img_in.weight": "x_embedder.weight",
+ "final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
+ "final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
+ }
+ suffix_rename_dict = {
+ "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
+ "img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
+ "img_attn.proj.bias": "attn.a_to_out.bias",
+ "img_attn.proj.weight": "attn.a_to_out.weight",
+ "img_attn.qkv.bias": "attn.a_to_qkv.bias",
+ "img_attn.qkv.weight": "attn.a_to_qkv.weight",
+ "img_mlp.0.bias": "ff_a.0.bias",
+ "img_mlp.0.weight": "ff_a.0.weight",
+ "img_mlp.2.bias": "ff_a.2.bias",
+ "img_mlp.2.weight": "ff_a.2.weight",
+ "img_mod.lin.bias": "norm1_a.linear.bias",
+ "img_mod.lin.weight": "norm1_a.linear.weight",
+ "txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
+ "txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
+ "txt_attn.proj.bias": "attn.b_to_out.bias",
+ "txt_attn.proj.weight": "attn.b_to_out.weight",
+ "txt_attn.qkv.bias": "attn.b_to_qkv.bias",
+ "txt_attn.qkv.weight": "attn.b_to_qkv.weight",
+ "txt_mlp.0.bias": "ff_b.0.bias",
+ "txt_mlp.0.weight": "ff_b.0.weight",
+ "txt_mlp.2.bias": "ff_b.2.bias",
+ "txt_mlp.2.weight": "ff_b.2.weight",
+ "txt_mod.lin.bias": "norm1_b.linear.bias",
+ "txt_mod.lin.weight": "norm1_b.linear.weight",
+
+ "linear1.bias": "to_qkv_mlp.bias",
+ "linear1.weight": "to_qkv_mlp.weight",
+ "linear2.bias": "proj_out.bias",
+ "linear2.weight": "proj_out.weight",
+ "modulation.lin.bias": "norm.linear.bias",
+ "modulation.lin.weight": "norm.linear.weight",
+ "norm.key_norm.scale": "norm_k_a.weight",
+ "norm.query_norm.scale": "norm_q_a.weight",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name.startswith("model.diffusion_model."):
+ name = name[len("model.diffusion_model."):]
+ names = name.split(".")
+ if name in rename_dict:
+ rename = rename_dict[name]
+ if name.startswith("final_layer.adaLN_modulation.1."):
+ param = torch.concat([param[3072:], param[:3072]], dim=0)
+ state_dict_[rename] = param
+ elif names[0] == "double_blocks":
+ rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
+ state_dict_[rename] = param
+ elif names[0] == "single_blocks":
+ if ".".join(names[2:]) in suffix_rename_dict:
+ rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
+ state_dict_[rename] = param
+ else:
+ pass
+ if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
+ return state_dict_, {"disable_guidance_embedder": True}
+ else:
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/flux_infiniteyou.py b/PusaV1/diffsynth/models/flux_infiniteyou.py
new file mode 100644
index 0000000000000000000000000000000000000000..2015de4a6c6ccae0922136622a973e3cc0e39652
--- /dev/null
+++ b/PusaV1/diffsynth/models/flux_infiniteyou.py
@@ -0,0 +1,128 @@
+import math
+import torch
+import torch.nn as nn
+
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latents (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
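+# Perceiver-style resampler: a set of learned latent queries cross-attends to the input
+# image embeddings for `depth` rounds, then is projected to `output_dim` (4096 by default,
+# the same width as the T5 prompt embeddings fed to the DiT's context_embedder).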
+class InfiniteYouImageProjector(nn.Module):
+
+ def __init__(
+ self,
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=8,
+ embedding_dim=512,
+ output_dim=4096,
+ ff_mult=4,
+ ):
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList([
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]))
+
+ def forward(self, x):
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxInfiniteYouImageProjectorStateDictConverter()
+
+
+class FluxInfiniteYouImageProjectorStateDictConverter:
+
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict['image_proj']
diff --git a/PusaV1/diffsynth/models/flux_ipadapter.py b/PusaV1/diffsynth/models/flux_ipadapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..575c75268c30f9c1d6e6b35d11b93bc57d80cb3f
--- /dev/null
+++ b/PusaV1/diffsynth/models/flux_ipadapter.py
@@ -0,0 +1,94 @@
+from .svd_image_encoder import SVDImageEncoder
+from .sd3_dit import RMSNorm
+from transformers import CLIPImageProcessor
+import torch
+
+
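+# Projects a single ID/image embedding into `num_tokens` cross-attention tokens of width
+# `cross_attention_dim`.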
+class MLPProjModel(torch.nn.Module):
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+ super().__init__()
+
+ self.cross_attention_dim = cross_attention_dim
+ self.num_tokens = num_tokens
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+ torch.nn.GELU(),
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+ )
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, id_embeds):
+ x = self.proj(id_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ x = self.norm(x)
+ return x
+
+class IpAdapterModule(torch.nn.Module):
+ def __init__(self, num_attention_heads, attention_head_dim, input_dim):
+ super().__init__()
+ self.num_heads = num_attention_heads
+ self.head_dim = attention_head_dim
+ output_dim = num_attention_heads * attention_head_dim
+ self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+ self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+ self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
+
+
+ def forward(self, hidden_states):
+ batch_size = hidden_states.shape[0]
+ # ip_k
+ ip_k = self.to_k_ip(hidden_states)
+ ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ ip_k = self.norm_added_k(ip_k)
+ # ip_v
+ ip_v = self.to_v_ip(hidden_states)
+ ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ return ip_k, ip_v
+
+
+class FluxIpAdapter(torch.nn.Module):
+ def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
+ super().__init__()
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
+ self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
+ self.set_adapter()
+
+ def set_adapter(self):
+ self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
+
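+ # Returns one {"ip_k", "ip_v", "scale"} entry per DiT block index (57 by default,
+ # presumably covering the 19 joint + 38 single FLUX blocks).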
+ def forward(self, hidden_states, scale=1.0):
+ hidden_states = self.image_proj(hidden_states)
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
+ ip_kv_dict = {}
+ for block_id in self.call_block_id:
+ ipadapter_id = self.call_block_id[block_id]
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
+ ip_kv_dict[block_id] = {
+ "ip_k": ip_k,
+ "ip_v": ip_v,
+ "scale": scale
+ }
+ return ip_kv_dict
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxIpAdapterStateDictConverter()
+
+
+class FluxIpAdapterStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {}
+ for name in state_dict["ip_adapter"]:
+ name_ = 'ipadapter_modules.' + name
+ state_dict_[name_] = state_dict["ip_adapter"][name]
+ for name in state_dict["image_proj"]:
+ name_ = "image_proj." + name
+ state_dict_[name_] = state_dict["image_proj"][name]
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
diff --git a/PusaV1/diffsynth/models/flux_text_encoder.py b/PusaV1/diffsynth/models/flux_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff9d2944e50e42ef9136269b3e0b1c1ea508a79
--- /dev/null
+++ b/PusaV1/diffsynth/models/flux_text_encoder.py
@@ -0,0 +1,32 @@
+import torch
+from transformers import T5EncoderModel, T5Config
+from .sd_text_encoder import SDTextEncoder
+
+
+
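+# Thin wrapper over the T5 encoder that returns only last_hidden_state as the prompt embedding.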
+class FluxTextEncoder2(T5EncoderModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.eval()
+
+ def forward(self, input_ids):
+ outputs = super().forward(input_ids=input_ids)
+ prompt_emb = outputs.last_hidden_state
+ return prompt_emb
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxTextEncoder2StateDictConverter()
+
+
+
+class FluxTextEncoder2StateDictConverter():
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = state_dict
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
diff --git a/PusaV1/diffsynth/models/flux_vae.py b/PusaV1/diffsynth/models/flux_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99c65522c4d0339adcc9aa734445377b9fdd5b5
--- /dev/null
+++ b/PusaV1/diffsynth/models/flux_vae.py
@@ -0,0 +1,303 @@
+from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
+from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
+
+
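+# Same architecture as the SD3 VAE; only the latent normalization constants differ.
+# The latents are presumably mapped as (z - shift_factor) * scaling_factor before the DiT
+# and inverted on decode (applied by the shared encoder/decoder base classes).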
+class FluxVAEEncoder(SD3VAEEncoder):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.3611
+ self.shift_factor = 0.1159
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxVAEEncoderStateDictConverter()
+
+
+class FluxVAEDecoder(SD3VAEDecoder):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.3611
+ self.shift_factor = 0.1159
+
+ @staticmethod
+ def state_dict_converter():
+ return FluxVAEDecoderStateDictConverter()
+
+
+class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
+ def __init__(self):
+ pass
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "encoder.conv_in.bias": "conv_in.bias",
+ "encoder.conv_in.weight": "conv_in.weight",
+ "encoder.conv_out.bias": "conv_out.bias",
+ "encoder.conv_out.weight": "conv_out.weight",
+ "encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
+ "encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
+ "encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
+ "encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
+ "encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
+ "encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
+ "encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
+ "encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
+ "encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
+ "encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
+ "encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
+ "encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
+ "encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
+ "encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
+ "encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
+ "encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
+ "encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
+ "encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
+ "encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
+ "encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
+ "encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
+ "encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
+ "encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
+ "encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
+ "encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
+ "encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
+ "encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
+ "encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
+ "encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
+ "encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
+ "encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
+ "encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
+ "encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
+ "encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
+ "encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
+ "encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
+ "encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
+ "encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
+ "encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
+ "encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
+ "encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
+ "encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
+ "encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
+ "encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
+ "encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
+ "encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
+ "encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
+ "encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
+ "encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
+ "encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
+ "encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
+ "encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
+ "encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
+ "encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
+ "encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
+ "encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
+ "encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
+ "encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
+ "encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
+ "encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
+ "encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
+ "encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
+ "encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
+ "encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
+ "encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
+ "encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
+ "encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
+ "encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
+ "encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
+ "encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
+ "encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
+ "encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
+ "encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
+ "encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
+ "encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
+ "encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
+ "encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
+ "encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
+ "encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
+ "encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
+ "encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
+ "encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
+ "encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
+ "encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
+ "encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
+ "encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
+ "encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
+ "encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
+ "encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
+ "encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
+ "encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
+ "encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
+ "encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
+ "encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
+ "encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
+ "encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
+ "encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
+ "encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
+ "encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
+ "encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
+ "encoder.norm_out.bias": "conv_norm_out.bias",
+ "encoder.norm_out.weight": "conv_norm_out.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if "transformer_blocks" in rename_dict[name]:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
+
+
+
+class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
+ def __init__(self):
+ pass
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "decoder.conv_in.bias": "conv_in.bias",
+ "decoder.conv_in.weight": "conv_in.weight",
+ "decoder.conv_out.bias": "conv_out.bias",
+ "decoder.conv_out.weight": "conv_out.weight",
+ "decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
+ "decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
+ "decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
+ "decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
+ "decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
+ "decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
+ "decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
+ "decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
+ "decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
+ "decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
+ "decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
+ "decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
+ "decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
+ "decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
+ "decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
+ "decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
+ "decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
+ "decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
+ "decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
+ "decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
+ "decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
+ "decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
+ "decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
+ "decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
+ "decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
+ "decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
+ "decoder.norm_out.bias": "conv_norm_out.bias",
+ "decoder.norm_out.weight": "conv_norm_out.weight",
+ "decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
+ "decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
+ "decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
+ "decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
+ "decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
+ "decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
+ "decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
+ "decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
+ "decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
+ "decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
+ "decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
+ "decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
+ "decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
+ "decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
+ "decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
+ "decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
+ "decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
+ "decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
+ "decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
+ "decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
+ "decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
+ "decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
+ "decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
+ "decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
+ "decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
+ "decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
+ "decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
+ "decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
+ "decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
+ "decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
+ "decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
+ "decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
+ "decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
+ "decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
+ "decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
+ "decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
+ "decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
+ "decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
+ "decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
+ "decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
+ "decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
+ "decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
+ "decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
+ "decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
+ "decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
+ "decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
+ "decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
+ "decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
+ "decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
+ "decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
+ "decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
+ "decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
+ "decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
+ "decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
+ "decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
+ "decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
+ "decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
+ "decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
+ "decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
+ "decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
+ "decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
+ "decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
+ "decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
+ "decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
+ "decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
+ "decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
+ "decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
+ "decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
+ "decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
+ "decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
+ "decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
+ "decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
+ "decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
+ "decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
+ "decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
+ "decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
+ "decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
+ "decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
+ "decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
+ "decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
+ "decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
+ "decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
+ "decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
+ "decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
+ "decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
+ "decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
+ "decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
+ "decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
+ "decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
+ "decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
+ "decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
+ "decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
+ "decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
+ "decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
+ "decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
+ "decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
+ "decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
+ "decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
+ "decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
+ "decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
+ "decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
+ "decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
+ "decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
+ "decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
+ "decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
+ "decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if "transformer_blocks" in rename_dict[name]:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
\ No newline at end of file
diff --git a/PusaV1/diffsynth/models/hunyuan_dit.py b/PusaV1/diffsynth/models/hunyuan_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e27183d6993e7f64eee6b9b231dcb4d8d1a6bc0
--- /dev/null
+++ b/PusaV1/diffsynth/models/hunyuan_dit.py
@@ -0,0 +1,451 @@
+from .attention import Attention
+from einops import repeat, rearrange
+import math
+import torch
+
+
+class HunyuanDiTRotaryEmbedding(torch.nn.Module):
+
+ def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
+ super().__init__()
+ self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
+ self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
+ self.rotary_emb_on_k = rotary_emb_on_k
+ self.k_cache, self.v_cache = [], []
+
+ def reshape_for_broadcast(self, freqs_cis, x):
+ ndim = x.ndim
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+
+ def rotate_half(self, x):
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+ def apply_rotary_emb(self, xq, xk, freqs_cis):
+ xk_out = None
+ cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
+ xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
+ if xk is not None:
+ xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
+ return xq_out, xk_out
+
+ def forward(self, q, k, v, freqs_cis_img, to_cache=False):
+ # norm
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ # RoPE
+ if self.rotary_emb_on_k:
+ q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
+ else:
+ q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
+
+ if to_cache:
+ self.k_cache.append(k)
+ self.v_cache.append(v)
+ elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
+ k = torch.concat([k] + self.k_cache, dim=2)
+ v = torch.concat([v] + self.v_cache, dim=2)
+ self.k_cache, self.v_cache = [], []
+ return q, k, v
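+
+
+# Illustrative usage sketch (added for clarity; not part of the upstream model code).
+# It shows the shapes the rotary-embedding module above expects: q/k/v are
+# [batch, heads, length, head_dim] and freqs_cis_img is a (cos, sin) pair covering
+# [length, head_dim]; a zero-angle table leaves q/k unrotated (up to the LayerNorms).
+def _rotary_embedding_sketch():
+    rope = HunyuanDiTRotaryEmbedding(q_norm_shape=88, k_norm_shape=88)
+    q, k, v = (torch.randn(1, 16, 32, 88) for _ in range(3))
+    cos, sin = torch.ones(32, 88), torch.zeros(32, 88)  # assumed precomputed RoPE tables
+    q, k, v = rope(q, k, v, (cos, sin))
+    return q.shape, k.shape, v.shape  # each torch.Size([1, 16, 32, 88])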
+
+
+class FP32_Layernorm(torch.nn.LayerNorm):
+ def forward(self, inputs):
+ origin_dtype = inputs.dtype
+ return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
+
+
+class FP32_SiLU(torch.nn.SiLU):
+ def forward(self, inputs):
+ origin_dtype = inputs.dtype
+ return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
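+
+
+# Brief sketch (added for clarity, not upstream code): the FP32_* wrappers run the
+# operation in float32 and cast the result back, so half-precision activations keep
+# their dtype while the normalization/activation itself avoids fp16/bf16 precision issues.
+def _fp32_layernorm_sketch():
+    norm = FP32_Layernorm((8,))
+    x = torch.randn(2, 8, dtype=torch.float16)
+    return norm(x).dtype  # torch.float16, although the LayerNorm ran in float32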
+
+
+class HunyuanDiTFinalLayer(torch.nn.Module):
+ def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
+ super().__init__()
+ self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = torch.nn.Sequential(
+ FP32_SiLU(),
+ torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
+ )
+
+ def modulate(self, x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+ def forward(self, hidden_states, condition_emb):
+ shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
+ hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
+ hidden_states = self.linear(hidden_states)
+ return hidden_states
+
+
+class HunyuanDiTBlock(torch.nn.Module):
+
+ def __init__(
+ self,
+ hidden_dim=1408,
+ condition_dim=1408,
+ num_heads=16,
+ mlp_ratio=4.3637,
+ text_dim=1024,
+ skip_connection=False
+ ):
+ super().__init__()
+ self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
+ self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
+ self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
+ self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
+ self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
+ self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
+ self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
+ self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
+ self.mlp = torch.nn.Sequential(
+ torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
+ )
+ if skip_connection:
+ self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
+ self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
+ else:
+ self.skip_norm, self.skip_linear = None, None
+
+ def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
+ # Long Skip Connection
+ if self.skip_norm is not None and self.skip_linear is not None:
+ hidden_states = torch.cat([hidden_states, residual], dim=-1)
+ hidden_states = self.skip_norm(hidden_states)
+ hidden_states = self.skip_linear(hidden_states)
+
+ # Self-Attention
+ shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
+ attn_input = self.norm1(hidden_states) + shift_msa
+ hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
+
+ # Cross-Attention
+ attn_input = self.norm3(hidden_states)
+ hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
+
+ # FFN Layer
+ mlp_input = self.norm2(hidden_states)
+ hidden_states = hidden_states + self.mlp(mlp_input)
+ return hidden_states
+
+
+class AttentionPool(torch.nn.Module):
+ def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
+ super().__init__()
+ self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
+ self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
+ self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
+ self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.permute(1, 0, 2) # NLC -> LNC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
+ x, _ = torch.nn.functional.multi_head_attention_forward(
+ query=x[:1], key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+ return x.squeeze(0)
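+
+
+# Hypothetical usage sketch (added for clarity): AttentionPool collapses a
+# [batch, length, embed_dim] token sequence into a single [batch, output_dim] vector
+# by attending from a learned mean-pooled query; it is instantiated as `t5_pooler`
+# in HunyuanDiT below to pool the mT5 text embeddings.
+def _attention_pool_sketch():
+    pool = AttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
+    tokens = torch.randn(2, 256, 2048)  # e.g. a batch of T5 token embeddings
+    return pool(tokens).shape  # torch.Size([2, 1024])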
+
+
+class PatchEmbed(torch.nn.Module):
+ def __init__(
+ self,
+ patch_size=(2, 2),
+ in_chans=4,
+ embed_dim=1408,
+ bias=True,
+ ):
+ super().__init__()
+ self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ return x
+
+
+def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+        ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(t, "b -> b d", d=dim)
+ return embedding
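+
+
+# Minimal sketch (added for clarity, not upstream code): the sinusoidal embedding maps
+# a batch of scalar timesteps to [batch, dim] by concatenating cos and sin halves.
+def _timestep_embedding_sketch():
+    t = torch.tensor([0, 500, 999])
+    return timestep_embedding(t, dim=256).shape  # torch.Size([3, 256])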
+
+
+class TimestepEmbedder(torch.nn.Module):
+ def __init__(self, hidden_size=1408, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = torch.nn.Sequential(
+ torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ torch.nn.SiLU(),
+ torch.nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class HunyuanDiT(torch.nn.Module):
+ def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
+ super().__init__()
+
+ # Embedders
+ self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
+ self.t5_embedder = torch.nn.Sequential(
+ torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
+ FP32_SiLU(),
+ torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
+ )
+ self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
+ self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
+ self.patch_embedder = PatchEmbed(in_chans=in_channels)
+ self.timestep_embedder = TimestepEmbedder()
+ self.extra_embedder = torch.nn.Sequential(
+ torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
+ FP32_SiLU(),
+ torch.nn.Linear(hidden_dim * 4, hidden_dim),
+ )
+
+ # Transformer blocks
+ self.num_layers_down = num_layers_down
+ self.num_layers_up = num_layers_up
+ self.blocks = torch.nn.ModuleList(
+ [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
+ [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
+ )
+
+ # Output layers
+ self.final_layer = HunyuanDiTFinalLayer()
+ self.out_channels = out_channels
+
+ def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
+ text_emb_mask = text_emb_mask.bool()
+ text_emb_mask_t5 = text_emb_mask_t5.bool()
+ text_emb_t5 = self.t5_embedder(text_emb_t5)
+ text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
+ text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
+ text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
+ return text_emb
+
+ def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
+ # Text embedding
+ pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
+
+ # Timestep embedding
+ timestep_emb = self.timestep_embedder(timestep)
+
+ # Size embedding
+ size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
+ size_emb = size_emb.view(-1, 6 * 256)
+
+ # Style embedding
+ style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
+
+ # Concatenate all extra vectors
+ extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
+ condition_emb = timestep_emb + self.extra_embedder(extra_emb)
+
+ return condition_emb
+
+ def unpatchify(self, x, h, w):
+ return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
+
+ def build_mask(self, data, is_bound):
+ _, _, H, W = data.shape
+ h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
+ w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
+ border_width = (H + W) // 4
+ pad = torch.ones_like(h) * border_width
+ mask = torch.stack([
+ pad if is_bound[0] else h + 1,
+ pad if is_bound[1] else H - h,
+ pad if is_bound[2] else w + 1,
+ pad if is_bound[3] else W - w
+ ]).min(dim=0).values
+ mask = mask.clip(1, border_width)
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
+ mask = rearrange(mask, "H W -> 1 H W")
+ return mask
+
+ def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
+ B, C, H, W = hidden_states.shape
+
+ weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
+ values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
+
+ # Split tasks
+ tasks = []
+ for h in range(0, H, tile_stride):
+ for w in range(0, W, tile_stride):
+ if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
+ continue
+ h_, w_ = h + tile_size, w + tile_size
+ if h_ > H: h, h_ = H - tile_size, H
+ if w_ > W: w, w_ = W - tile_size, W
+ tasks.append((h, h_, w, w_))
+
+ # Run
+ for hl, hr, wl, wr in tasks:
+ hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
+ hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
+ if residual is not None:
+ residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
+ residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
+ else:
+ residual_batch = None
+
+ # Forward
+ hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
+ hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
+
+ mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
+ values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
+ weight[:, :, hl:hr, wl:wr] += mask
+ values /= weight
+ return values
+
+ def forward(
+ self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
+ tiled=False, tile_size=64, tile_stride=32,
+ to_cache=False,
+ use_gradient_checkpointing=False,
+ ):
+ # Embeddings
+ text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
+ condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
+
+ # Input
+ height, width = hidden_states.shape[-2], hidden_states.shape[-1]
+ hidden_states = self.patch_embedder(hidden_states)
+
+ # Blocks
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ if tiled:
+ hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
+ residuals = []
+ for block_id, block in enumerate(self.blocks):
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
+ hidden_states = self.tiled_block_forward(
+ block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
+ torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
+ tile_size=tile_size, tile_stride=tile_stride
+ )
+ if block_id < self.num_layers_down - 2:
+ residuals.append(hidden_states)
+ hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
+ else:
+ residuals = []
+ for block_id, block in enumerate(self.blocks):
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
+ if self.training and use_gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states, condition_emb, text_emb, freq_cis_img, residual,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
+ if block_id < self.num_layers_down - 2:
+ residuals.append(hidden_states)
+
+ # Output
+ hidden_states = self.final_layer(hidden_states, condition_emb)
+ hidden_states = self.unpatchify(hidden_states, height//2, width//2)
+ hidden_states, _ = hidden_states.chunk(2, dim=1)
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return HunyuanDiTStateDictConverter()
+
+
+
+class HunyuanDiTStateDictConverter():
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ name_ = name
+ name_ = name_.replace(".default_modulation.", ".modulation.")
+ name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
+ name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
+ name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
+ name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
+ name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
+ name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
+ name_ = name_.replace(".q_proj.", ".to_q.")
+ name_ = name_.replace(".out_proj.", ".to_out.")
+ name_ = name_.replace("text_embedding_padding", "text_emb_padding")
+ name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
+ name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
+ name_ = name_.replace("pooler.", "t5_pooler.")
+ name_ = name_.replace("x_embedder.", "patch_embedder.")
+ name_ = name_.replace("t_embedder.", "timestep_embedder.")
+ name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
+ name_ = name_.replace("style_embedder.weight", "style_embedder")
+ if ".kv_proj." in name_:
+ param_k = param[:param.shape[0]//2]
+ param_v = param[param.shape[0]//2:]
+ state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
+ state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
+ elif ".Wqkv." in name_:
+ param_q = param[:param.shape[0]//3]
+ param_k = param[param.shape[0]//3:param.shape[0]//3*2]
+ param_v = param[param.shape[0]//3*2:]
+ state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
+ state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
+ state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
+ elif "style_embedder" in name_:
+ state_dict_[name_] = param.squeeze()
+ else:
+ state_dict_[name_] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
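+
+
+# Illustrative sketch (the key name below is made up): the converter splits fused
+# attention weights, e.g. a ".Wqkv." matrix is cut into equal to_q / to_k / to_v thirds.
+def _qkv_split_sketch():
+    fused = {"blocks.0.attn1.Wqkv.weight": torch.randn(3 * 1408, 1408)}
+    converted = HunyuanDiTStateDictConverter().from_diffusers(fused)
+    return sorted(converted.keys())  # ...to_k.weight, ...to_q.weight, ...to_v.weight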
diff --git a/PusaV1/diffsynth/models/hunyuan_dit_text_encoder.py b/PusaV1/diffsynth/models/hunyuan_dit_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..33999a8b10a319b736981dd8f3a911bbe9568e8d
--- /dev/null
+++ b/PusaV1/diffsynth/models/hunyuan_dit_text_encoder.py
@@ -0,0 +1,163 @@
+from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
+import torch
+
+
+
+class HunyuanDiTCLIPTextEncoder(BertModel):
+ def __init__(self):
+ config = BertConfig(
+ _name_or_path = "",
+ architectures = ["BertModel"],
+ attention_probs_dropout_prob = 0.1,
+ bos_token_id = 0,
+ classifier_dropout = None,
+ directionality = "bidi",
+ eos_token_id = 2,
+ hidden_act = "gelu",
+ hidden_dropout_prob = 0.1,
+ hidden_size = 1024,
+ initializer_range = 0.02,
+ intermediate_size = 4096,
+ layer_norm_eps = 1e-12,
+ max_position_embeddings = 512,
+ model_type = "bert",
+ num_attention_heads = 16,
+ num_hidden_layers = 24,
+ output_past = True,
+ pad_token_id = 0,
+ pooler_fc_size = 768,
+ pooler_num_attention_heads = 12,
+ pooler_num_fc_layers = 3,
+ pooler_size_per_head = 128,
+ pooler_type = "first_token_transform",
+ position_embedding_type = "absolute",
+ torch_dtype = "float32",
+ transformers_version = "4.37.2",
+ type_vocab_size = 2,
+ use_cache = True,
+ vocab_size = 47020
+ )
+ super().__init__(config, add_pooling_layer=False)
+ self.eval()
+
+ def forward(self, input_ids, attention_mask, clip_skip=1):
+ input_shape = input_ids.size()
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+
+ past_key_values_length = 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=None,
+ token_type_ids=None,
+ inputs_embeds=None,
+ past_key_values_length=0,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=False,
+ output_attentions=False,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ all_hidden_states = encoder_outputs.hidden_states
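+        # Note (added for clarity): clip_skip > 1 selects an earlier hidden layer and then rescales
+        # it to the final layer's mean/std, so downstream projections see a comparable distribution.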
+ prompt_emb = all_hidden_states[-clip_skip]
+ if clip_skip > 1:
+ mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
+ return prompt_emb
+
+ @staticmethod
+ def state_dict_converter():
+ return HunyuanDiTCLIPTextEncoderStateDictConverter()
+
+
+
+class HunyuanDiTT5TextEncoder(T5EncoderModel):
+ def __init__(self):
+ config = T5Config(
+ _name_or_path = "../HunyuanDiT/t2i/mt5",
+ architectures = ["MT5ForConditionalGeneration"],
+ classifier_dropout = 0.0,
+ d_ff = 5120,
+ d_kv = 64,
+ d_model = 2048,
+ decoder_start_token_id = 0,
+ dense_act_fn = "gelu_new",
+ dropout_rate = 0.1,
+ eos_token_id = 1,
+ feed_forward_proj = "gated-gelu",
+ initializer_factor = 1.0,
+ is_encoder_decoder = True,
+ is_gated_act = True,
+ layer_norm_epsilon = 1e-06,
+ model_type = "t5",
+ num_decoder_layers = 24,
+ num_heads = 32,
+ num_layers = 24,
+ output_past = True,
+ pad_token_id = 0,
+ relative_attention_max_distance = 128,
+ relative_attention_num_buckets = 32,
+ tie_word_embeddings = False,
+ tokenizer_class = "T5Tokenizer",
+ transformers_version = "4.37.2",
+ use_cache = True,
+ vocab_size = 250112
+ )
+ super().__init__(config)
+ self.eval()
+
+ def forward(self, input_ids, attention_mask, clip_skip=1):
+ outputs = super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ prompt_emb = outputs.hidden_states[-clip_skip]
+ if clip_skip > 1:
+ mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
+ return prompt_emb
+
+ @staticmethod
+ def state_dict_converter():
+ return HunyuanDiTT5TextEncoderStateDictConverter()
+
+
+
+class HunyuanDiTCLIPTextEncoderStateDictConverter():
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
+
+
+class HunyuanDiTT5TextEncoderStateDictConverter():
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
+ state_dict_["shared.weight"] = state_dict["shared.weight"]
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
diff --git a/PusaV1/diffsynth/models/hunyuan_video_dit.py b/PusaV1/diffsynth/models/hunyuan_video_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..13155361734f8ec4cb7947941a1824f5064f1642
--- /dev/null
+++ b/PusaV1/diffsynth/models/hunyuan_video_dit.py
@@ -0,0 +1,920 @@
+import torch
+from .sd3_dit import TimestepEmbeddings, RMSNorm
+from .utils import init_weights_on_device
+from einops import rearrange, repeat
+from tqdm import tqdm
+from typing import Union, Tuple, List
+from .utils import hash_state_dict_keys
+
+
+def HunyuanVideoRope(latents):
+ def _to_tuple(x, dim=2):
+ if isinstance(x, int):
+ return (x,) * dim
+ elif len(x) == dim:
+ return x
+ else:
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
+
+
+ def get_meshgrid_nd(start, *args, dim=2):
+ """
+ Get n-D meshgrid with start, stop and num.
+
+ Args:
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
+ n-tuples.
+ *args: See above.
+ dim (int): Dimension of the meshgrid. Defaults to 2.
+
+ Returns:
+            grid (torch.Tensor): [dim, ...]
+ """
+ if len(args) == 0:
+ # start is grid_size
+ num = _to_tuple(start, dim=dim)
+ start = (0,) * dim
+ stop = num
+ elif len(args) == 1:
+ # start is start, args[0] is stop, step is 1
+ start = _to_tuple(start, dim=dim)
+ stop = _to_tuple(args[0], dim=dim)
+ num = [stop[i] - start[i] for i in range(dim)]
+ elif len(args) == 2:
+ # start is start, args[0] is stop, args[1] is num
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
+ else:
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
+ axis_grid = []
+ for i in range(dim):
+ a, b, n = start[i], stop[i], num[i]
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
+ axis_grid.append(g)
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
+
+ return grid
+
+
+ def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[torch.FloatTensor, int],
+ theta: float = 10000.0,
+ use_real: bool = False,
+ theta_rescale_factor: float = 1.0,
+ interpolation_factor: float = 1.0,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
+
+        This function calculates a frequency tensor of complex exponentials for the given dimension 'dim'
+        and the position indices 'pos'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool, optional): If True, return real part and imaginary part separately.
+ Otherwise, return complex numbers.
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
+
+ Returns:
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
+ """
+ if isinstance(pos, int):
+ pos = torch.arange(pos).float()
+
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+ # has some connection to NTK literature
+ if theta_rescale_factor != 1.0:
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ ) # [D/2]
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(
+ torch.ones_like(freqs), freqs
+ ) # complex64 # [S, D/2]
+ return freqs_cis
+
+
+ def get_nd_rotary_pos_embed(
+ rope_dim_list,
+ start,
+ *args,
+ theta=10000.0,
+ use_real=False,
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
+ interpolation_factor: Union[float, List[float]] = 1.0,
+ ):
+ """
+        This is an n-d version of precompute_freqs_cis, i.e. a RoPE for tokens with an n-d structure.
+
+ Args:
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
+ sum(rope_dim_list) should equal to head_dim of attention layer.
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+ *args: See above.
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+                Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide
+                the real and imaginary parts separately.
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
+
+ Returns:
+            pos_embed: a complex tensor of shape [WHD, D/2], or a (cos, sin) pair of shape [WHD, D] if use_real is True.
+ """
+
+ grid = get_meshgrid_nd(
+ start, *args, dim=len(rope_dim_list)
+ ) # [3, W, H, D] / [2, W, H]
+
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
+ assert len(theta_rescale_factor) == len(
+ rope_dim_list
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
+
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
+ assert len(interpolation_factor) == len(
+ rope_dim_list
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
+
+ # use 1/ndim of dimensions to encode grid_axis
+ embs = []
+ for i in range(len(rope_dim_list)):
+ emb = get_1d_rotary_pos_embed(
+ rope_dim_list[i],
+ grid[i].reshape(-1),
+ theta,
+ use_real=use_real,
+ theta_rescale_factor=theta_rescale_factor[i],
+ interpolation_factor=interpolation_factor[i],
+ ) # 2 x [WHD, rope_dim_list[i]]
+ embs.append(emb)
+
+ if use_real:
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
+ return cos, sin
+ else:
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
+ return emb
+
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+ [16, 56, 56],
+ [latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
+ theta=256,
+ use_real=True,
+ theta_rescale_factor=1,
+ )
+ return freqs_cos, freqs_sin
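+
+
+# Hypothetical sketch (added for clarity): for a latent of shape [B, C, T, H, W] the
+# 3D RoPE tables cover T * (H // 2) * (W // 2) patch tokens, with the 128-dim per-head
+# channel budget split as 16 (time) + 56 (height) + 56 (width).
+def _video_rope_sketch():
+    latents = torch.randn(1, 16, 5, 34, 60)
+    freqs_cos, freqs_sin = HunyuanVideoRope(latents)
+    return freqs_cos.shape  # torch.Size([2550, 128]), i.e. [5 * 17 * 30, 128]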
+
+
+class PatchEmbed(torch.nn.Module):
+ def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
+ super().__init__()
+ self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class IndividualTokenRefinerBlock(torch.nn.Module):
+ def __init__(self, hidden_size=3072, num_heads=24):
+ super().__init__()
+ self.num_heads = num_heads
+ self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
+ self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
+
+ self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.mlp = torch.nn.Sequential(
+ torch.nn.Linear(hidden_size, hidden_size * 4),
+ torch.nn.SiLU(),
+ torch.nn.Linear(hidden_size * 4, hidden_size)
+ )
+ self.adaLN_modulation = torch.nn.Sequential(
+ torch.nn.SiLU(),
+ torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
+ )
+
+ def forward(self, x, c, attn_mask=None):
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
+
+ norm_x = self.norm1(x)
+ qkv = self.self_attn_qkv(norm_x)
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+
+ attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+ attn = rearrange(attn, "B H L D -> B L (H D)")
+
+ x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
+ x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
+
+ return x
+
+
+class SingleTokenRefiner(torch.nn.Module):
+ def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
+ super().__init__()
+ self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
+ self.c_embedder = torch.nn.Sequential(
+ torch.nn.Linear(in_channels, hidden_size),
+ torch.nn.SiLU(),
+ torch.nn.Linear(hidden_size, hidden_size)
+ )
+ self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
+
+ def forward(self, x, t, mask=None):
+ timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
+
+ mask_float = mask.float().unsqueeze(-1)
+ context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ context_aware_representations = self.c_embedder(context_aware_representations)
+ c = timestep_aware_representations + context_aware_representations
+
+ x = self.input_embedder(x)
+
+ mask = mask.to(device=x.device, dtype=torch.bool)
+ mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
+ mask = mask & mask.transpose(2, 3)
+ mask[:, :, :, 0] = True
+
+ for block in self.blocks:
+ x = block(x, c, mask)
+
+ return x
+
+
+class ModulateDiT(torch.nn.Module):
+ def __init__(self, hidden_size, factor=6):
+ super().__init__()
+ self.act = torch.nn.SiLU()
+ self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
+
+ def forward(self, x):
+ return self.linear(self.act(x))
+
+
+def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
+ if tr_shift is not None:
+ x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
+ x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ x = torch.concat((x_zero, x_orig), dim=1)
+ return x
+ if scale is None and shift is None:
+ return x
+ elif shift is None:
+ return x * (1 + scale.unsqueeze(1))
+ elif scale is None:
+ return x + shift.unsqueeze(1)
+ else:
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
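+
+
+# Minimal numeric sketch (added for clarity): with token replacement the first `tr_token`
+# tokens receive the (tr_shift, tr_scale) modulation and the rest receive (shift, scale);
+# otherwise the usual adaLN form x * (1 + scale) + shift is applied.
+def _modulate_sketch():
+    x = torch.ones(1, 4, 8)
+    shift, scale = torch.zeros(1, 8), torch.zeros(1, 8)
+    tr_shift, tr_scale = torch.ones(1, 8), torch.zeros(1, 8)
+    out = modulate(x, shift=shift, scale=scale, tr_shift=tr_shift, tr_scale=tr_scale, tr_token=2)
+    return out[0, 0, 0].item(), out[0, 3, 0].item()  # (2.0, 1.0)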
+
+
+def reshape_for_broadcast(
+ freqs_cis,
+ x: torch.Tensor,
+ head_first=False,
+):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+
+ if isinstance(freqs_cis, tuple):
+ # freqs_cis: (cos, sin) in real space
+ if head_first:
+ assert freqs_cis[0].shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [
+ d if i == ndim - 2 or i == ndim - 1 else 1
+ for i, d in enumerate(x.shape)
+ ]
+ else:
+ assert freqs_cis[0].shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+ else:
+ # freqs_cis: values in complex space
+ if head_first:
+ assert freqs_cis.shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [
+ d if i == ndim - 2 or i == ndim - 1 else 1
+ for i, d in enumerate(x.shape)
+ ]
+ else:
+ assert freqs_cis.shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def rotate_half(x):
+ x_real, x_imag = (
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
+ ) # [B, S, H, D//2]
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis,
+ head_first: bool = False,
+):
+ xk_out = None
+ if isinstance(freqs_cis, tuple):
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
+ # real * cos - imag * sin
+ # imag * cos + real * sin
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+ else:
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
+ xq_ = torch.view_as_complex(
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
+ ) # [B, S, H, D//2]
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
+ xq.device
+ ) # [S, D//2] --> [1, S, 1, D//2]
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
+ xk_ = torch.view_as_complex(
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
+ ) # [B, S, H, D//2]
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
+
+ return xq_out, xk_out
+
+
+def attention(q, k, v):
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = x.transpose(1, 2).flatten(2, 3)
+ return x
+
+
+def apply_gate(x, gate, tr_gate=None, tr_token=None):
+ if tr_gate is not None:
+ x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
+ x_orig = x[:, tr_token:] * gate.unsqueeze(1)
+ return torch.concat((x_zero, x_orig), dim=1)
+ else:
+ return x * gate.unsqueeze(1)
+
+
+class MMDoubleStreamBlockComponent(torch.nn.Module):
+ def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
+ super().__init__()
+ self.heads_num = heads_num
+
+ self.mod = ModulateDiT(hidden_size)
+ self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
+ self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
+ self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
+ self.to_out = torch.nn.Linear(hidden_size, hidden_size)
+
+ self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
+ )
+
+ def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
+ mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
+ if token_replace_vec is not None:
+ assert tr_token is not None
+ tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
+ else:
+ tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
+
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
+ tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
+ qkv = self.to_qkv(norm_hidden_states)
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+
+ if freqs_cis is not None:
+ q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
+ return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
+
+ def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
+ mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
+ if mod_tr is not None:
+ tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
+ else:
+ tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
+ hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
+ x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
+ hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
+ return hidden_states
+
+
+class MMDoubleStreamBlock(torch.nn.Module):
+ def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
+ super().__init__()
+ self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
+ self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
+
+ def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
+ (q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
+ (q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
+
+ q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
+ k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
+ v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
+ attn_output_a = attention(q_a, k_a, v_a)
+ attn_output_b = attention(q_b, k_b, v_b)
+ attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
+
+ hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
+ hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
+ return hidden_states_a, hidden_states_b
+
+
+class MMSingleStreamBlockOriginal(torch.nn.Module):
+ def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.heads_num = heads_num
+ self.mlp_hidden_dim = hidden_size * mlp_width_ratio
+
+ self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+ self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+ self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
+ self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
+
+ self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.mlp_act = torch.nn.GELU(approximate="tanh")
+ self.modulation = ModulateDiT(hidden_size, factor=3)
+
+ def forward(self, x, vec, freqs_cis=None, txt_len=256):
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
+ k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
+ q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
+ q = torch.cat((q_a, q_b), dim=1)
+ k = torch.cat((k_a, k_b), dim=1)
+
+ attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
+ attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
+ attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
+
+ output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
+ return x + output * mod_gate.unsqueeze(1)
+
+
+class MMSingleStreamBlock(torch.nn.Module):
+ def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
+ super().__init__()
+ self.heads_num = heads_num
+
+ self.mod = ModulateDiT(hidden_size, factor=3)
+ self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
+ self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
+ self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
+ self.to_out = torch.nn.Linear(hidden_size, hidden_size)
+
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
+ )
+
+ def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
+ mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
+ if token_replace_vec is not None:
+ assert tr_token is not None
+ tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
+ else:
+ tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
+
+ norm_hidden_states = self.norm(hidden_states)
+ norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
+ tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
+ qkv = self.to_qkv(norm_hidden_states)
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
+
+ q = self.norm_q(q)
+ k = self.norm_k(k)
+
+ q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
+ k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
+ q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
+
+ v_len = txt_len - split_token
+ q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
+ k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
+ v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
+
+ attn_output_a = attention(q_a, k_a, v_a)
+ attn_output_b = attention(q_b, k_b, v_b)
+ attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
+
+ hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
+ hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
+ return hidden_states
+
+
+class FinalLayer(torch.nn.Module):
+ def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
+ super().__init__()
+
+ self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
+
+ self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
+ x = self.linear(x)
+ return x
+
+
+class HunyuanVideoDiT(torch.nn.Module):
+ def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
+ super().__init__()
+ self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
+ self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
+ self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
+ self.vector_in = torch.nn.Sequential(
+ torch.nn.Linear(768, hidden_size),
+ torch.nn.SiLU(),
+ torch.nn.Linear(hidden_size, hidden_size)
+ )
+ self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
+ self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
+ self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
+ self.final_layer = FinalLayer(hidden_size)
+
+ # TODO: remove these parameters
+ self.dtype = torch.bfloat16
+ self.patch_size = [1, 2, 2]
+ self.hidden_size = 3072
+ self.heads_num = 24
+ self.rope_dim_list = [16, 56, 56]
+
+ def unpatchify(self, x, T, H, W):
+ x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
+ return x
+
+ def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
+ self.warm_device = warm_device
+ self.cold_device = cold_device
+ self.to(self.cold_device)
+
+ def load_models_to_device(self, loadmodel_names=[], device="cpu"):
+ for model_name in loadmodel_names:
+ model = getattr(self, model_name)
+ if model is not None:
+ model.to(device)
+ torch.cuda.empty_cache()
+
+ def prepare_freqs(self, latents):
+ return HunyuanVideoRope(latents)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ prompt_emb: torch.Tensor = None,
+ text_mask: torch.Tensor = None,
+ pooled_prompt_emb: torch.Tensor = None,
+ freqs_cos: torch.Tensor = None,
+ freqs_sin: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ **kwargs
+ ):
+ B, C, T, H, W = x.shape
+
+ vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
+ if self.guidance_in is not None:
+ vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
+ img = self.img_in(x)
+ txt = self.txt_in(prompt_emb, t, text_mask)
+
+ for block in tqdm(self.double_blocks, desc="Double stream blocks"):
+ img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
+
+ x = torch.concat([img, txt], dim=1)
+ for block in tqdm(self.single_blocks, desc="Single stream blocks"):
+ x = block(x, vec, (freqs_cos, freqs_sin))
+
+ img = x[:, :-256]
+ img = self.final_layer(img, vec)
+ img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
+ return img
+
+
+ def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
+ def cast_to(weight, dtype=None, device=None, copy=False):
+ if device is None or weight.device == device:
+ if not copy:
+ if dtype is None or weight.dtype == dtype:
+ return weight
+ return weight.to(dtype=dtype, copy=copy)
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
+ r.copy_(weight)
+ return r
+
+ def cast_weight(s, input=None, dtype=None, device=None):
+ if input is not None:
+ if dtype is None:
+ dtype = input.dtype
+ if device is None:
+ device = input.device
+ weight = cast_to(s.weight, dtype, device)
+ return weight
+
+ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+ if input is not None:
+ if dtype is None:
+ dtype = input.dtype
+ if bias_dtype is None:
+ bias_dtype = dtype
+ if device is None:
+ device = input.device
+ weight = cast_to(s.weight, dtype, device)
+ bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
+ return weight, bias
+
+ class quantized_layer:
+ class Linear(torch.nn.Linear):
+ def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
+ super().__init__(*args, **kwargs)
+ self.dtype = dtype
+ self.device = device
+
+ def block_forward_(self, x, i, j, dtype, device):
+ weight_ = cast_to(
+ self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
+ dtype=dtype, device=device
+ )
+ if self.bias is None or i > 0:
+ bias_ = None
+ else:
+ bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
+ x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
+ y_ = torch.nn.functional.linear(x_, weight_, bias_)
+ del x_, weight_, bias_
+ torch.cuda.empty_cache()
+ return y_
+
+ def block_forward(self, x, **kwargs):
+                    # Note: this block-wise path only saves about 2 GB of VRAM, so it is not used by default.
+ y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
+ for i in range((self.in_features + self.block_size - 1) // self.block_size):
+ for j in range((self.out_features + self.block_size - 1) // self.block_size):
+ y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
+ return y
+
+ def forward(self, x, **kwargs):
+ weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
+ return torch.nn.functional.linear(x, weight, bias)
+
+
+ class RMSNorm(torch.nn.Module):
+ def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
+ super().__init__()
+ self.module = module
+ self.dtype = dtype
+ self.device = device
+
+ def forward(self, hidden_states, **kwargs):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
+ hidden_states = hidden_states.to(input_dtype)
+ if self.module.weight is not None:
+ weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
+ hidden_states = hidden_states * weight
+ return hidden_states
+
+ class Conv3d(torch.nn.Conv3d):
+ def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
+ super().__init__(*args, **kwargs)
+ self.dtype = dtype
+ self.device = device
+
+ def forward(self, x):
+ weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
+ return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+
+ class LayerNorm(torch.nn.LayerNorm):
+ def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
+ super().__init__(*args, **kwargs)
+ self.dtype = dtype
+ self.device = device
+
+ def forward(self, x):
+ if self.weight is not None and self.bias is not None:
+ weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
+ else:
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+
+ def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
+ for name, module in model.named_children():
+ if isinstance(module, torch.nn.Linear):
+ with init_weights_on_device():
+ new_layer = quantized_layer.Linear(
+ module.in_features, module.out_features, bias=module.bias is not None,
+ dtype=dtype, device=device
+ )
+ new_layer.load_state_dict(module.state_dict(), assign=True)
+ setattr(model, name, new_layer)
+ elif isinstance(module, torch.nn.Conv3d):
+ with init_weights_on_device():
+ new_layer = quantized_layer.Conv3d(
+ module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
+ dtype=dtype, device=device
+ )
+ new_layer.load_state_dict(module.state_dict(), assign=True)
+ setattr(model, name, new_layer)
+ elif isinstance(module, RMSNorm):
+ new_layer = quantized_layer.RMSNorm(
+ module,
+ dtype=dtype, device=device
+ )
+ setattr(model, name, new_layer)
+ elif isinstance(module, torch.nn.LayerNorm):
+ with init_weights_on_device():
+ new_layer = quantized_layer.LayerNorm(
+ module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
+ dtype=dtype, device=device
+ )
+ new_layer.load_state_dict(module.state_dict(), assign=True)
+ setattr(model, name, new_layer)
+ else:
+ replace_layer(module, dtype=dtype, device=device)
+
+ replace_layer(self, dtype=dtype, device=device)
+
+ @staticmethod
+ def state_dict_converter():
+ return HunyuanVideoDiTStateDictConverter()
+
+
+class HunyuanVideoDiTStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_civitai(self, state_dict):
+ origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
+ if "module" in state_dict:
+ state_dict = state_dict["module"]
+ direct_dict = {
+ "img_in.proj": "img_in.proj",
+ "time_in.mlp.0": "time_in.timestep_embedder.0",
+ "time_in.mlp.2": "time_in.timestep_embedder.2",
+ "vector_in.in_layer": "vector_in.0",
+ "vector_in.out_layer": "vector_in.2",
+ "guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
+ "guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
+ "txt_in.input_embedder": "txt_in.input_embedder",
+ "txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
+ "txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
+ "txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
+ "txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
+ "final_layer.linear": "final_layer.linear",
+ "final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
+ }
+ txt_suffix_dict = {
+ "norm1": "norm1",
+ "self_attn_qkv": "self_attn_qkv",
+ "self_attn_proj": "self_attn_proj",
+ "norm2": "norm2",
+ "mlp.fc1": "mlp.0",
+ "mlp.fc2": "mlp.2",
+ "adaLN_modulation.1": "adaLN_modulation.1",
+ }
+ double_suffix_dict = {
+ "img_mod.linear": "component_a.mod.linear",
+ "img_attn_qkv": "component_a.to_qkv",
+ "img_attn_q_norm": "component_a.norm_q",
+ "img_attn_k_norm": "component_a.norm_k",
+ "img_attn_proj": "component_a.to_out",
+ "img_mlp.fc1": "component_a.ff.0",
+ "img_mlp.fc2": "component_a.ff.2",
+ "txt_mod.linear": "component_b.mod.linear",
+ "txt_attn_qkv": "component_b.to_qkv",
+ "txt_attn_q_norm": "component_b.norm_q",
+ "txt_attn_k_norm": "component_b.norm_k",
+ "txt_attn_proj": "component_b.to_out",
+ "txt_mlp.fc1": "component_b.ff.0",
+ "txt_mlp.fc2": "component_b.ff.2",
+ }
+ single_suffix_dict = {
+ "linear1": ["to_qkv", "ff.0"],
+ "linear2": ["to_out", "ff.2"],
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ "modulation.linear": "mod.linear",
+ }
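+        # List-valued entries mark fused parameters that must be split: linear1 packs
+        # the QKV projection (3*3072 rows) together with the MLP up-projection
+        # (4*3072 rows), and linear2 packs the attention output (3072 columns)
+        # together with the MLP down-projection (4*3072 columns).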
+ # single_suffix_dict = {
+ # "linear1": "linear1",
+ # "linear2": "linear2",
+ # "q_norm": "q_norm",
+ # "k_norm": "k_norm",
+ # "modulation.linear": "modulation.linear",
+ # }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ names = name.split(".")
+ direct_name = ".".join(names[:-1])
+ if direct_name in direct_dict:
+ name_ = direct_dict[direct_name] + "." + names[-1]
+ state_dict_[name_] = param
+ elif names[0] == "double_blocks":
+ prefix = ".".join(names[:2])
+ suffix = ".".join(names[2:-1])
+ name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
+ state_dict_[name_] = param
+ elif names[0] == "single_blocks":
+ prefix = ".".join(names[:2])
+ suffix = ".".join(names[2:-1])
+ if isinstance(single_suffix_dict[suffix], list):
+ if suffix == "linear1":
+ name_a, name_b = single_suffix_dict[suffix]
+ param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
+ state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
+ state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
+ elif suffix == "linear2":
+ if names[-1] == "weight":
+ name_a, name_b = single_suffix_dict[suffix]
+ param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
+ state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
+ state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
+ else:
+ name_a, name_b = single_suffix_dict[suffix]
+ state_dict_[prefix + "." + name_a + "." + names[-1]] = param
+ else:
+ pass
+ else:
+ name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
+ state_dict_[name_] = param
+ elif names[0] == "txt_in":
+ prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
+ suffix = ".".join(names[4:-1])
+ name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
+ state_dict_[name_] = param
+ else:
+ pass
+
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/hunyuan_video_text_encoder.py b/PusaV1/diffsynth/models/hunyuan_video_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce7a6805a163c709a4f3de82784538b300541119
--- /dev/null
+++ b/PusaV1/diffsynth/models/hunyuan_video_text_encoder.py
@@ -0,0 +1,68 @@
+from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
+from copy import deepcopy
+import torch
+
+
+class HunyuanVideoLLMEncoder(LlamaModel):
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.auto_offload = False
+
+ def enable_auto_offload(self, **kwargs):
+ self.auto_offload = True
+
+ def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
+ embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
+ inputs_embeds = embed_tokens(input_ids)
+
+ past_key_values = DynamicCache()
+
+ cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
+ position_embeddings = rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ for layer_id, decoder_layer in enumerate(self.layers):
+ if self.auto_offload:
+ decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=False,
+ use_cache=True,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = layer_outputs[0]
+ if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
+ break
+
+ return hidden_states
+
+
+class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.auto_offload = False
+
+ def enable_auto_offload(self, **kwargs):
+ self.auto_offload = True
+
+ # TODO: implement the low VRAM inference for MLLM.
+ def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
+ outputs = super().forward(input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ pixel_values=pixel_values)
+ hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
+ return hidden_state
diff --git a/PusaV1/diffsynth/models/hunyuan_video_vae_decoder.py b/PusaV1/diffsynth/models/hunyuan_video_vae_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae09ff85a9149edd46a36e7bfeb29a372a81e12c
--- /dev/null
+++ b/PusaV1/diffsynth/models/hunyuan_video_vae_decoder.py
@@ -0,0 +1,507 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import numpy as np
+from tqdm import tqdm
+from einops import repeat
+
+
+class CausalConv3d(nn.Module):
+
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
+ super().__init__()
+ self.pad_mode = pad_mode
+        self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0)  # (W, W, H, H, T_front, T_back): pad only toward the past in time
+ self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+
+class UpsampleCausal3D(nn.Module):
+
+ def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.upsample_factor = upsample_factor
+ self.conv = None
+ if use_conv:
+ kernel_size = 3 if kernel_size is None else kernel_size
+ self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
+
+ def forward(self, hidden_states):
+        # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+        # interpolate: the first frame is upsampled only spatially, the remaining frames
+        # in time and space, so the temporal upsampling stays causal
+ B, C, T, H, W = hidden_states.shape
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
+ if T > 1:
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
+ first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
+ hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.conv:
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlockCausal3D(nn.Module):
+
+ def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
+ super().__init__()
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
+
+ self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+ self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
+
+ self.dropout = nn.Dropout(dropout)
+ self.nonlinearity = nn.SiLU()
+
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
+
+ def forward(self, input_tensor):
+ hidden_states = input_tensor
+ # conv1
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ # conv2
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ # shortcut
+ if self.conv_shortcut is not None:
+            input_tensor = self.conv_shortcut(input_tensor)
+        # residual connection
+ output_tensor = input_tensor + hidden_states
+
+ return output_tensor
+
+
+def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
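+    # Block-causal mask over tokens flattened as (frame, height*width): every token may
+    # attend to all tokens of its own frame and of earlier frames, but not to later frames.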
+ seq_len = n_frame * n_hw
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+ for i in range(seq_len):
+ i_frame = i // n_hw
+ mask[i, :(i_frame + 1) * n_hw] = 0
+ if batch_size is not None:
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+ return mask
+
+
+class Attention(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ num_heads,
+ head_dim,
+ num_groups=32,
+ dropout=0.0,
+ eps=1e-6,
+ bias=True,
+ residual_connection=True):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.residual_connection = residual_connection
+ dim_inner = head_dim * num_heads
+ self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
+ self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
+ self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
+ self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
+ self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
+
+ def forward(self, input_tensor, attn_mask=None):
+ hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
+ batch_size = hidden_states.shape[0]
+
+ q = self.to_q(hidden_states)
+ k = self.to_k(hidden_states)
+ v = self.to_v(hidden_states)
+
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
+ hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ hidden_states = self.to_out(hidden_states)
+ if self.residual_connection:
+ output_tensor = input_tensor + hidden_states
+ return output_tensor
+
+
+class UNetMidBlockCausal3D(nn.Module):
+
+ def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
+ super().__init__()
+ resnets = [
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ groups=num_groups,
+ eps=eps,
+ )
+ ]
+ attentions = []
+ attention_head_dim = attention_head_dim or in_channels
+
+ for _ in range(num_layers):
+ attentions.append(
+ Attention(
+ in_channels,
+ num_heads=in_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ num_groups=num_groups,
+ dropout=dropout,
+ eps=eps,
+ bias=True,
+ residual_connection=True,
+ ))
+
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ groups=num_groups,
+ eps=eps,
+ ))
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states):
+ hidden_states = self.resnets[0](hidden_states)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ B, C, T, H, W = hidden_states.shape
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
+ attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
+ hidden_states = attn(hidden_states, attn_mask=attn_mask)
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+class UpDecoderBlockCausal3D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ dropout=0.0,
+ num_layers=1,
+ eps=1e-6,
+ num_groups=32,
+ add_upsample=True,
+ upsample_scale_factor=(2, 2, 2),
+ ):
+ super().__init__()
+ resnets = []
+ for i in range(num_layers):
+ cur_in_channel = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=cur_in_channel,
+ out_channels=out_channels,
+ groups=num_groups,
+ dropout=dropout,
+ eps=eps,
+ ))
+ self.resnets = nn.ModuleList(resnets)
+
+ self.upsamplers = None
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([
+ UpsampleCausal3D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ upsample_factor=upsample_scale_factor,
+ )
+ ])
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+ return hidden_states
+
+
+class DecoderCausal3D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels=16,
+ out_channels=3,
+ eps=1e-6,
+ dropout=0.0,
+ block_out_channels=[128, 256, 512, 512],
+ layers_per_block=2,
+ num_groups=32,
+ time_compression_ratio=4,
+ spatial_compression_ratio=8,
+ gradient_checkpointing=False,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = UNetMidBlockCausal3D(
+ in_channels=block_out_channels[-1],
+ dropout=dropout,
+ eps=eps,
+ num_groups=num_groups,
+ attention_head_dim=block_out_channels[-1],
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
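+        # With the default spatial_compression_ratio=8 and time_compression_ratio=4,
+        # the first three up blocks upsample H and W (2**3 = 8x) and blocks 1-2 also
+        # upsample T (2**2 = 4x); the final block performs no upsampling.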
+ for i in range(len(block_out_channels)):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
+
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
+ add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
+
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
+
+ up_block = UpDecoderBlockCausal3D(
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ dropout=dropout,
+ num_layers=layers_per_block + 1,
+ eps=eps,
+ num_groups=num_groups,
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
+ upsample_scale_factor=upsample_scale_factor,
+ )
+
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = gradient_checkpointing
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv_in(hidden_states)
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # middle
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ hidden_states,
+ use_reentrant=False,
+ )
+ # up
+ for up_block in self.up_blocks:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ hidden_states,
+ use_reentrant=False,
+ )
+ else:
+ # middle
+ hidden_states = self.mid_block(hidden_states)
+ # up
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states)
+ # post-process
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideoVAEDecoder(nn.Module):
+
+ def __init__(
+ self,
+ in_channels=16,
+ out_channels=3,
+ eps=1e-6,
+ dropout=0.0,
+ block_out_channels=[128, 256, 512, 512],
+ layers_per_block=2,
+ num_groups=32,
+ time_compression_ratio=4,
+ spatial_compression_ratio=8,
+ gradient_checkpointing=False,
+ ):
+ super().__init__()
+ self.decoder = DecoderCausal3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ eps=eps,
+ dropout=dropout,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ num_groups=num_groups,
+ time_compression_ratio=time_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ gradient_checkpointing=gradient_checkpointing,
+ )
+ self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.scaling_factor = 0.476986
+
+
+ def forward(self, latents):
+ latents = latents / self.scaling_factor
+ latents = self.post_quant_conv(latents)
+ dec = self.decoder(latents)
+ return dec
+
+
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
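+        # 1-D blending weight of length `length`: it ramps linearly over `border_width`
+        # samples at edges that touch another tile, and stays 1 at true video boundaries.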
+ x = torch.ones((length,))
+ if not left_bound:
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
+ if not right_bound:
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
+ return x
+
+
+ def build_mask(self, data, is_bound, border_width):
+ _, _, T, H, W = data.shape
+ t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
+ h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
+ w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
+
+ t = repeat(t, "T -> T H W", T=T, H=H, W=W)
+ h = repeat(h, "H -> T H W", T=T, H=H, W=W)
+ w = repeat(w, "W -> T H W", T=T, H=H, W=W)
+
+ mask = torch.stack([t, h, w]).min(dim=0).values
+ mask = rearrange(mask, "T H W -> 1 1 T H W")
+ return mask
+
+
+ def tile_forward(self, hidden_states, tile_size, tile_stride):
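+        # Tiled decoding: decode overlapping latent tiles independently, blend them with
+        # the feathering masks from build_mask, and normalize by the accumulated weights.
+        # Latent coordinates map to pixel coordinates via *4 in time (the first latent
+        # frame decodes to a single video frame) and *8 in space.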
+ B, C, T, H, W = hidden_states.shape
+ size_t, size_h, size_w = tile_size
+ stride_t, stride_h, stride_w = tile_stride
+
+ # Split tasks
+ tasks = []
+ for t in range(0, T, stride_t):
+ if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
+ for h in range(0, H, stride_h):
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+ for w in range(0, W, stride_w):
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+ t_, h_, w_ = t + size_t, h + size_h, w + size_w
+ tasks.append((t, t_, h, h_, w, w_))
+
+ # Run
+ torch_dtype = self.post_quant_conv.weight.dtype
+ data_device = hidden_states.device
+ computation_device = self.post_quant_conv.weight.device
+
+ weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
+ values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
+
+ for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
+ hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
+ hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
+ if t > 0:
+ hidden_states_batch = hidden_states_batch[:, :, 1:]
+
+ mask = self.build_mask(
+ hidden_states_batch,
+ is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
+ border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
+ ).to(dtype=torch_dtype, device=data_device)
+
+ target_t = 0 if t==0 else t * 4 + 1
+ target_h = h * 8
+ target_w = w * 8
+ values[
+ :,
+ :,
+ target_t: target_t + hidden_states_batch.shape[2],
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += hidden_states_batch * mask
+ weight[
+ :,
+ :,
+ target_t: target_t + hidden_states_batch.shape[2],
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += mask
+ return values / weight
+
+
+ def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
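+        # Usage sketch: `latents` has shape (B, 16, T, H, W); the tiled pass returns a
+        # pixel-space video of shape (B, 3, (T - 1) * 4 + 1, H * 8, W * 8).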
+ latents = latents.to(self.post_quant_conv.weight.dtype)
+ return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
+
+ @staticmethod
+ def state_dict_converter():
+ return HunyuanVideoVAEDecoderStateDictConverter()
+
+
+class HunyuanVideoVAEDecoderStateDictConverter:
+
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {}
+ for name in state_dict:
+ if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
+ state_dict_[name] = state_dict[name]
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/hunyuan_video_vae_encoder.py b/PusaV1/diffsynth/models/hunyuan_video_vae_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..faaaeb95f7d688d57fb61a707c7658a89bb2c92a
--- /dev/null
+++ b/PusaV1/diffsynth/models/hunyuan_video_vae_encoder.py
@@ -0,0 +1,307 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+import numpy as np
+from tqdm import tqdm
+from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
+
+
+class DownsampleCausal3D(nn.Module):
+
+ def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
+ super().__init__()
+ self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class DownEncoderBlockCausal3D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ dropout=0.0,
+ num_layers=1,
+ eps=1e-6,
+ num_groups=32,
+ add_downsample=True,
+ downsample_stride=2,
+ ):
+
+ super().__init__()
+ resnets = []
+ for i in range(num_layers):
+ cur_in_channel = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlockCausal3D(
+ in_channels=cur_in_channel,
+ out_channels=out_channels,
+ groups=num_groups,
+ dropout=dropout,
+ eps=eps,
+ ))
+ self.resnets = nn.ModuleList(resnets)
+
+ self.downsamplers = None
+ if add_downsample:
+ self.downsamplers = nn.ModuleList([DownsampleCausal3D(
+ out_channels,
+ out_channels,
+ stride=downsample_stride,
+ )])
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class EncoderCausal3D(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 16,
+ eps=1e-6,
+ dropout=0.0,
+ block_out_channels=[128, 256, 512, 512],
+ layers_per_block=2,
+ num_groups=32,
+ time_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 8,
+ gradient_checkpointing=False,
+ ):
+ super().__init__()
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
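+        # Mirror of the decoder: with the default compression ratios, the first three
+        # down blocks stride H and W by 2 (8x spatial) and blocks 1-2 also stride T
+        # (4x temporal); the final block keeps the resolution.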
+ for i in range(len(block_out_channels)):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
+
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
+ add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
+
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
+ down_block = DownEncoderBlockCausal3D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ dropout=dropout,
+ num_layers=layers_per_block,
+ eps=eps,
+ num_groups=num_groups,
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
+ downsample_stride=downsample_stride,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockCausal3D(
+ in_channels=block_out_channels[-1],
+ dropout=dropout,
+ eps=eps,
+ num_groups=num_groups,
+ attention_head_dim=block_out_channels[-1],
+ )
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = gradient_checkpointing
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv_in(hidden_states)
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # down
+ for down_block in self.down_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(down_block),
+ hidden_states,
+ use_reentrant=False,
+ )
+ # middle
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ hidden_states,
+ use_reentrant=False,
+ )
+ else:
+ # down
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+ # middle
+ hidden_states = self.mid_block(hidden_states)
+ # post-process
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideoVAEEncoder(nn.Module):
+
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=16,
+ eps=1e-6,
+ dropout=0.0,
+ block_out_channels=[128, 256, 512, 512],
+ layers_per_block=2,
+ num_groups=32,
+ time_compression_ratio=4,
+ spatial_compression_ratio=8,
+ gradient_checkpointing=False,
+ ):
+ super().__init__()
+ self.encoder = EncoderCausal3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ eps=eps,
+ dropout=dropout,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ num_groups=num_groups,
+ time_compression_ratio=time_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ gradient_checkpointing=gradient_checkpointing,
+ )
+ self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
+ self.scaling_factor = 0.476986
+
+
+ def forward(self, images):
+ latents = self.encoder(images)
+ latents = self.quant_conv(latents)
+ latents = latents[:, :16]
+ latents = latents * self.scaling_factor
+ return latents
+
+
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
+ x = torch.ones((length,))
+ if not left_bound:
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
+ if not right_bound:
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
+ return x
+
+
+ def build_mask(self, data, is_bound, border_width):
+ _, _, T, H, W = data.shape
+ t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
+ h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
+ w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
+
+ t = repeat(t, "T -> T H W", T=T, H=H, W=W)
+ h = repeat(h, "H -> T H W", T=T, H=H, W=W)
+ w = repeat(w, "W -> T H W", T=T, H=H, W=W)
+
+ mask = torch.stack([t, h, w]).min(dim=0).values
+ mask = rearrange(mask, "T H W -> 1 1 T H W")
+ return mask
+
+
+ def tile_forward(self, hidden_states, tile_size, tile_stride):
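+        # Tiled encoding: the same overlap-and-blend scheme as the decoder's tile_forward,
+        # but pixel coordinates map to latent coordinates via //4 in time and //8 in space.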
+ B, C, T, H, W = hidden_states.shape
+ size_t, size_h, size_w = tile_size
+ stride_t, stride_h, stride_w = tile_stride
+
+ # Split tasks
+ tasks = []
+ for t in range(0, T, stride_t):
+ if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
+ for h in range(0, H, stride_h):
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+ for w in range(0, W, stride_w):
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+ t_, h_, w_ = t + size_t, h + size_h, w + size_w
+ tasks.append((t, t_, h, h_, w, w_))
+
+ # Run
+ torch_dtype = self.quant_conv.weight.dtype
+ data_device = hidden_states.device
+ computation_device = self.quant_conv.weight.device
+
+ weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
+ values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
+
+ for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
+ hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
+ hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
+ if t > 0:
+ hidden_states_batch = hidden_states_batch[:, :, 1:]
+
+ mask = self.build_mask(
+ hidden_states_batch,
+ is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
+ border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
+ ).to(dtype=torch_dtype, device=data_device)
+
+ target_t = 0 if t==0 else t // 4 + 1
+ target_h = h // 8
+ target_w = w // 8
+ values[
+ :,
+ :,
+ target_t: target_t + hidden_states_batch.shape[2],
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += hidden_states_batch * mask
+ weight[
+ :,
+ :,
+ target_t: target_t + hidden_states_batch.shape[2],
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += mask
+ return values / weight
+
+
+ def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
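+        # Usage sketch: the input is a pixel-space video of shape (B, 3, T, H, W); the
+        # tiled pass returns latents of shape (B, 16, (T - 1) // 4 + 1, H // 8, W // 8).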
+ latents = latents.to(self.quant_conv.weight.dtype)
+ return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
+
+
+ @staticmethod
+ def state_dict_converter():
+ return HunyuanVideoVAEEncoderStateDictConverter()
+
+
+class HunyuanVideoVAEEncoderStateDictConverter:
+
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {}
+ for name in state_dict:
+ if name.startswith('encoder.') or name.startswith('quant_conv.'):
+ state_dict_[name] = state_dict[name]
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/kolors_text_encoder.py b/PusaV1/diffsynth/models/kolors_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee785e373567734d0a9ff413c72b67f45c7d6b1d
--- /dev/null
+++ b/PusaV1/diffsynth/models/kolors_text_encoder.py
@@ -0,0 +1,1551 @@
+"""
+This model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models.
+We didn't modify this model.
+The tensor operation is performed in the prompter.
+"""
+
+
+""" PyTorch ChatGLM model. """
+
+import math
+import copy
+import warnings
+import re
+import sys
+
+import torch
+import torch.utils.checkpoint
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
+from torch.nn.utils import skip_init
+from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+from copy import deepcopy
+
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
+from transformers import PretrainedConfig
+from torch.nn.parameter import Parameter
+import bz2
+import torch
+import base64
+import ctypes
+from transformers.utils import logging
+from typing import List
+
+
+
+logger = logging.get_logger(__name__)
+
+try:
+ from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
+
+
+ class Kernel:
+ def __init__(self, code: bytes, function_names: List[str]):
+ self.code = code
+ self._function_names = function_names
+ self._cmodule = LazyKernelCModule(self.code)
+
+ for name in self._function_names:
+ setattr(self, name, KernelFunction(self._cmodule, name))
+
+
+ quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+t
e+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyH
pix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
+
+ kernels = Kernel(
+ bz2.decompress(base64.b64decode(quantization_code)),
+ [
+ "int4WeightCompression",
+ "int4WeightExtractionFloat",
+ "int4WeightExtractionHalf",
+ "int8WeightExtractionFloat",
+ "int8WeightExtractionHalf",
+ ],
+ )
+except Exception as exception:
+ kernels = None
+
+
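+# Weight-only quantization: weights are stored as int8/int4 and dequantized to half
+# precision on the fly, while activations stay in half precision (W8A16).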
+class W8A16Linear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
+ ctx.inp_shape = inp.size()
+ ctx.weight_bit_width = weight_bit_width
+ out_features = quant_w.size(0)
+ inp = inp.contiguous().view(-1, inp.size(-1))
+ weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
+ ctx.weight_shape = weight.size()
+ output = inp.mm(weight.t())
+ ctx.save_for_backward(inp, quant_w, scale_w)
+ return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor):
+ inp, quant_w, scale_w = ctx.saved_tensors
+ weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
+ grad_output = grad_output.contiguous().view(-1, weight.size(0))
+ grad_input = grad_output.mm(weight)
+ grad_weight = grad_output.t().mm(inp)
+ return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
+
+
+def compress_int4_weight(weight: torch.Tensor): # (n, m)
+ with torch.cuda.device(weight.device):
+ n, m = weight.size(0), weight.size(1)
+ assert m % 2 == 0
+ m = m // 2
+ out = torch.empty(n, m, dtype=torch.int8, device="cuda")
+ stream = torch.cuda.current_stream()
+
+ gridDim = (n, 1, 1)
+ blockDim = (min(round_up(m, 32), 1024), 1, 1)
+
+ kernels.int4WeightCompression(
+ gridDim,
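+        # Run the Llama decoder stack with an early exit: the returned hidden state is
+        # the output of the layer `hidden_state_skip_layer` layers before the last.
+        # With auto_offload enabled, the embedding table, rotary embedding and each
+        # decoder layer are deep-copied to the input's device on demand (low-VRAM mode).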
+ blockDim,
+ 0,
+ stream,
+ [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
+ )
+ return out
+
+
+def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
+ assert scale_list.dtype in [torch.half, torch.bfloat16]
+ assert weight.dtype in [torch.int8]
+ if source_bit_width == 8:
+ return weight.to(scale_list.dtype) * scale_list[:, None]
+ elif source_bit_width == 4:
+ func = (
+ kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16
+ )
+ else:
+ assert False, "Unsupported bit-width"
+
+ with torch.cuda.device(weight.device):
+ n, m = weight.size(0), weight.size(1)
+ out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda")
+ stream = torch.cuda.current_stream()
+
+ gridDim = (n, 1, 1)
+ blockDim = (min(round_up(m, 32), 1024), 1, 1)
+
+ func(
+ gridDim,
+ blockDim,
+ 0,
+ stream,
+ [
+ ctypes.c_void_p(weight.data_ptr()),
+ ctypes.c_void_p(scale_list.data_ptr()),
+ ctypes.c_void_p(out.data_ptr()),
+ ctypes.c_int32(n),
+ ctypes.c_int32(m),
+ ],
+ )
+ return out
+
+
+class QuantizedLinear(torch.nn.Module):
+ def __init__(self, weight_bit_width: int, weight, bias=None, device="cuda", dtype=None, empty_init=False):
+ super().__init__()
+ weight = weight.to(device) # ensure the weight is on the cuda device
+ assert str(weight.device).startswith(
+ 'cuda'), 'The weights that need to be quantified should be on the CUDA device'
+ self.weight_bit_width = weight_bit_width
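+        # hidden_states[-(hidden_state_skip_layer + 1)] selects the output from
+        # `hidden_state_skip_layer` layers before the final layer, mirroring the
+        # early-exit convention of the LLM encoder above.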
+ shape = weight.shape
+
+ if weight is None or empty_init:
+ self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
+ self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
+ else:
+ self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
+ self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
+ if weight_bit_width == 4:
+ self.weight = compress_int4_weight(self.weight)
+
+ self.weight = Parameter(self.weight.to(device), requires_grad=False)
+ self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
+ self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
+
+ def forward(self, input):
+ output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
+ if self.bias is not None:
+ output = output + self.bias
+ return output
+
+
+def quantize(model, weight_bit_width, empty_init=False, device=None):
+ """Replace fp16 linear with quantized linear"""
+ for layer in model.layers:
+ layer.self_attention.query_key_value = QuantizedLinear(
+ weight_bit_width=weight_bit_width,
+ weight=layer.self_attention.query_key_value.weight,
+ bias=layer.self_attention.query_key_value.bias,
+ dtype=layer.self_attention.query_key_value.weight.dtype,
+ device=layer.self_attention.query_key_value.weight.device if device is None else device,
+ empty_init=empty_init
+ )
+ layer.self_attention.dense = QuantizedLinear(
+ weight_bit_width=weight_bit_width,
+ weight=layer.self_attention.dense.weight,
+ bias=layer.self_attention.dense.bias,
+ dtype=layer.self_attention.dense.weight.dtype,
+ device=layer.self_attention.dense.weight.device if device is None else device,
+ empty_init=empty_init
+ )
+ layer.mlp.dense_h_to_4h = QuantizedLinear(
+ weight_bit_width=weight_bit_width,
+ weight=layer.mlp.dense_h_to_4h.weight,
+ bias=layer.mlp.dense_h_to_4h.bias,
+ dtype=layer.mlp.dense_h_to_4h.weight.dtype,
+ device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
+ empty_init=empty_init
+ )
+ layer.mlp.dense_4h_to_h = QuantizedLinear(
+ weight_bit_width=weight_bit_width,
+ weight=layer.mlp.dense_4h_to_h.weight,
+ bias=layer.mlp.dense_4h_to_h.bias,
+ dtype=layer.mlp.dense_4h_to_h.weight.dtype,
+ device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
+ empty_init=empty_init
+ )
+
+ return model
+
+
+
+class ChatGLMConfig(PretrainedConfig):
+ model_type = "chatglm"
+ def __init__(
+ self,
+ num_layers=28,
+ padded_vocab_size=65024,
+ hidden_size=4096,
+ ffn_hidden_size=13696,
+ kv_channels=128,
+ num_attention_heads=32,
+ seq_length=2048,
+ hidden_dropout=0.0,
+ classifier_dropout=None,
+ attention_dropout=0.0,
+ layernorm_epsilon=1e-5,
+ rmsnorm=True,
+ apply_residual_connection_post_layernorm=False,
+ post_layer_norm=True,
+ add_bias_linear=False,
+ add_qkv_bias=False,
+ bias_dropout_fusion=True,
+ multi_query_attention=False,
+ multi_query_group_num=1,
+ apply_query_key_layer_scaling=True,
+ attention_softmax_in_fp32=True,
+ fp32_residual_connection=False,
+ quantization_bit=0,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs
+ ):
+ self.num_layers = num_layers
+ self.vocab_size = padded_vocab_size
+ self.padded_vocab_size = padded_vocab_size
+ self.hidden_size = hidden_size
+ self.ffn_hidden_size = ffn_hidden_size
+ self.kv_channels = kv_channels
+ self.num_attention_heads = num_attention_heads
+ self.seq_length = seq_length
+ self.hidden_dropout = hidden_dropout
+ self.classifier_dropout = classifier_dropout
+ self.attention_dropout = attention_dropout
+ self.layernorm_epsilon = layernorm_epsilon
+ self.rmsnorm = rmsnorm
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+ self.post_layer_norm = post_layer_norm
+ self.add_bias_linear = add_bias_linear
+ self.add_qkv_bias = add_qkv_bias
+ self.bias_dropout_fusion = bias_dropout_fusion
+ self.multi_query_attention = multi_query_attention
+ self.multi_query_group_num = multi_query_group_num
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+ self.fp32_residual_connection = fp32_residual_connection
+ self.quantization_bit = quantization_bit
+ self.pre_seq_len = pre_seq_len
+ self.prefix_projection = prefix_projection
+ super().__init__(**kwargs)
+
+
+
+# flags required to enable jit fusion kernels
+
+if sys.platform != 'darwin':
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_override_can_fuse_on_cpu(True)
+ torch._C._jit_override_can_fuse_on_gpu(True)
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
+_CONFIG_FOR_DOC = "ChatGLM6BConfig"
+
+CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "THUDM/chatglm3-6b-base",
+ # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
+]
+
+
+def default_init(cls, *args, **kwargs):
+ return cls(*args, **kwargs)
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
+ scores.zero_()
+ scores[..., 5] = 5e4
+ return scores
+
+
+class PrefixEncoder(torch.nn.Module):
+ """
+ The torch.nn model to encode the prefix
+ Input shape: (batch-size, prefix-length)
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
+ """
+
+ def __init__(self, config: ChatGLMConfig):
+ super().__init__()
+ self.prefix_projection = config.prefix_projection
+ if self.prefix_projection:
+ # Use a two-layer MLP to encode the prefix
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
+ self.trans = torch.nn.Sequential(
+ torch.nn.Linear(kv_size, config.hidden_size),
+ torch.nn.Tanh(),
+ torch.nn.Linear(config.hidden_size, kv_size)
+ )
+ else:
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
+
+ def forward(self, prefix: torch.Tensor):
+ if self.prefix_projection:
+ prefix_tokens = self.embedding(prefix)
+ past_key_values = self.trans(prefix_tokens)
+ else:
+ past_key_values = self.embedding(prefix)
+ return past_key_values
+
+
+def split_tensor_along_last_dim(
+ tensor: torch.Tensor,
+ num_partitions: int,
+ contiguous_split_chunks: bool = False,
+) -> List[torch.Tensor]:
+ """Split a tensor along its last dimension.
+
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+
+ Returns:
+ A list of Tensors
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.dim = dim
+ self.original_impl = original_impl
+
+ def forward_impl(
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
+ ):
+ """Enhanced Transformer with Rotary Position Embedding.
+
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+ transformers/rope/__init__.py. MIT License:
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+ """
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
+
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
+
+ # Calculate the product of position index and $\theta_i$
+ idx_theta = torch.outer(seq_idx, theta).float()
+
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
+
+ # this is to mimic the behaviour of complex32, else we will get different results
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
+ return cache
+
+ def forward(self, max_seq_len, offset=0):
+ return self.forward_impl(
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
+ )
+
+
+@torch.jit.script
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+ # x: [sq, b, np, hn]
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
+ rot_dim = rope_cache.shape[-2] * 2
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+ # truncate to support variable sizes
+ rope_cache = rope_cache[:sq]
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+ ],
+ -1,
+ )
+ x_out2 = x_out2.flatten(3)
+ return torch.cat((x_out2, x_pass), dim=-1)
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
+ self.eps = eps
+
+ def forward(self, hidden_states: torch.Tensor):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+ return (self.weight * hidden_states).to(input_dtype)
+
+
+class CoreAttention(torch.nn.Module):
+ def __init__(self, config: ChatGLMConfig, layer_number):
+ super(CoreAttention, self).__init__()
+
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+ if self.apply_query_key_layer_scaling:
+ self.attention_softmax_in_fp32 = True
+ self.layer_number = max(1, layer_number)
+
+ projection_size = config.kv_channels * config.num_attention_heads
+
+ # Per attention head and per partition values.
+ self.hidden_size_per_partition = projection_size
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
+ self.num_attention_heads_per_partition = config.num_attention_heads
+
+ coeff = None
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+ if self.apply_query_key_layer_scaling:
+ coeff = self.layer_number
+ self.norm_factor *= coeff
+ self.coeff = coeff
+
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
+
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
+ pytorch_major_version = int(torch.__version__.split('.')[0])
+ if pytorch_major_version >= 2:
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+ is_causal=True)
+ else:
+ if attention_mask is not None:
+ attention_mask = ~attention_mask
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+ attention_mask)
+ context_layer = context_layer.permute(2, 0, 1, 3)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(*new_context_layer_shape)
+ else:
+ # Raw attention scores
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+
+ # preallocting input tensor: [b * np, sq, sk]
+ matmul_input_buffer = torch.empty(
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
+ device=query_layer.device
+ )
+
+ # Raw attention scores. [b * np, sq, sk]
+ matmul_result = torch.baddbmm(
+ matmul_input_buffer,
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
+ beta=0.0,
+ alpha=(1.0 / self.norm_factor),
+ )
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ # ===========================
+ # Attention probs and dropout
+ # ===========================
+
+ # attention scores and attention mask [b, np, sq, sk]
+ if self.attention_softmax_in_fp32:
+ attention_scores = attention_scores.float()
+ if self.coeff is not None:
+ attention_scores = attention_scores * self.coeff
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
+ device=attention_scores.device, dtype=torch.bool)
+ attention_mask.tril_()
+ attention_mask = ~attention_mask
+ if attention_mask is not None:
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
+ attention_probs = F.softmax(attention_scores, dim=-1)
+ attention_probs = attention_probs.type_as(value_layer)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.attention_dropout(attention_probs)
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ return context_layer
+
+
+class SelfAttention(torch.nn.Module):
+ """Parallel self-attention layer abstract class.
+
+ Self-attention layer takes input with size [s, b, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+ super(SelfAttention, self).__init__()
+ self.layer_number = max(1, layer_number)
+
+ self.projection_size = config.kv_channels * config.num_attention_heads
+
+ # Per attention head and per partition values.
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
+ self.num_attention_heads_per_partition = config.num_attention_heads
+
+ self.multi_query_attention = config.multi_query_attention
+ self.qkv_hidden_size = 3 * self.projection_size
+ if self.multi_query_attention:
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
+ self.qkv_hidden_size = (
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
+ )
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
+ bias=config.add_bias_linear or config.add_qkv_bias,
+ device=device, **_config_to_kwargs(config)
+ )
+
+ self.core_attention = CoreAttention(config, self.layer_number)
+
+ # Output.
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
+ device=device, **_config_to_kwargs(config)
+ )
+
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
+ if self.multi_query_attention:
+ num_attention_heads = self.num_multi_query_groups_per_partition
+ else:
+ num_attention_heads = self.num_attention_heads_per_partition
+ return torch.empty(
+ inference_max_sequence_len,
+ batch_size,
+ num_attention_heads,
+ self.hidden_size_per_attention_head,
+ dtype=dtype,
+ device=device,
+ )
+
+ def forward(
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
+ ):
+ # hidden_states: [sq, b, h]
+
+ # =================================================
+ # Pre-allocate memory for key-values for inference.
+ # =================================================
+ # =====================
+ # Query, Key, and Value
+ # =====================
+
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+ mixed_x_layer = self.query_key_value(hidden_states)
+
+ if self.multi_query_attention:
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+ [
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ ],
+ dim=-1,
+ )
+ query_layer = query_layer.view(
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
+ )
+ key_layer = key_layer.view(
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
+ )
+ value_layer = value_layer.view(
+ value_layer.size()[:-1]
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
+ )
+ else:
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
+ (self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head)
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ # apply relative positional encoding (rotary embedding)
+ if rotary_pos_emb is not None:
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
+
+ # adjust key and value for inference
+ if kv_cache is not None:
+ cache_k, cache_v = kv_cache
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
+ if use_cache:
+ kv_cache = (key_layer, value_layer)
+ else:
+ kv_cache = None
+
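+ # Multi-query attention: repeat each shared key/value head so every query head has a
+ # matching K/V head before core attention.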
+ if self.multi_query_attention:
+ key_layer = key_layer.unsqueeze(-2)
+ key_layer = key_layer.expand(
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
+ )
+ key_layer = key_layer.contiguous().view(
+ key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
+ )
+ value_layer = value_layer.unsqueeze(-2)
+ value_layer = value_layer.expand(
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
+ )
+ value_layer = value_layer.contiguous().view(
+ value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
+ )
+
+ # ==================================
+ # core attention computation
+ # ==================================
+
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output = self.dense(context_layer)
+
+ return output, kv_cache
+
+
+def _config_to_kwargs(args):
+ common_kwargs = {
+ "dtype": args.torch_dtype,
+ }
+ return common_kwargs
+
+
+class MLP(torch.nn.Module):
+ """MLP.
+
+ MLP will take the input with h hidden state, project it to 4*h
+ hidden dimension, perform nonlinear transformation, and project the
+ state back into h hidden dimension.
+ """
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(MLP, self).__init__()
+
+ self.add_bias = config.add_bias_linear
+
+ # Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
+ self.dense_h_to_4h = nn.Linear(
+ config.hidden_size,
+ config.ffn_hidden_size * 2,
+ bias=self.add_bias,
+ device=device,
+ **_config_to_kwargs(config)
+ )
+
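+ # SwiGLU: split the doubled projection into two halves and multiply the second by SiLU of the first.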
+ def swiglu(x):
+ x = torch.chunk(x, 2, dim=-1)
+ return F.silu(x[0]) * x[1]
+
+ self.activation_func = swiglu
+
+ # Project back to h.
+ self.dense_4h_to_h = nn.Linear(
+ config.ffn_hidden_size,
+ config.hidden_size,
+ bias=self.add_bias,
+ device=device,
+ **_config_to_kwargs(config)
+ )
+
+ def forward(self, hidden_states):
+ # [s, b, 4hp]
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+ # [s, b, h]
+ output = self.dense_4h_to_h(intermediate_parallel)
+ return output
+
+
+class GLMBlock(torch.nn.Module):
+ """A single transformer layer.
+
+ Transformer layer takes input with size [s, b, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+ super(GLMBlock, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+
+ self.fp32_residual_connection = config.fp32_residual_connection
+
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+ dtype=config.torch_dtype)
+
+ # Self attention.
+ self.self_attention = SelfAttention(config, layer_number, device=device)
+ self.hidden_dropout = config.hidden_dropout
+
+ # Layernorm on the attention output
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+ dtype=config.torch_dtype)
+
+ # MLP
+ self.mlp = MLP(config, device=device)
+
+ def forward(
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
+ ):
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, kv_cache = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=kv_cache,
+ use_cache=use_cache
+ )
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+ layernorm_input = residual + layernorm_input
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+ output = residual + output
+
+ return output, kv_cache
+
+
+class GLMTransformer(torch.nn.Module):
+ """Transformer class."""
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(GLMTransformer, self).__init__()
+
+ self.fp32_residual_connection = config.fp32_residual_connection
+ self.post_layer_norm = config.post_layer_norm
+
+ # Number of layers.
+ self.num_layers = config.num_layers
+
+ # Transformer layers.
+ def build_layer(layer_number):
+ return GLMBlock(config, layer_number, device=device)
+
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+
+ if self.post_layer_norm:
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+ # Final layer norm before output.
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
+ dtype=config.torch_dtype)
+
+ self.gradient_checkpointing = False
+
+ def _get_layer(self, layer_number):
+ return self.layers[layer_number]
+
+ def forward(
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
+ use_cache: Optional[bool] = True,
+ output_hidden_states: Optional[bool] = False,
+ ):
+ if not kv_caches:
+ kv_caches = [None for _ in range(self.num_layers)]
+ presents = () if use_cache else None
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ all_self_attentions = None
+ all_hidden_states = () if output_hidden_states else None
+ for index in range(self.num_layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer = self._get_layer(index)
+ if self.gradient_checkpointing and self.training:
+ layer_ret = torch.utils.checkpoint.checkpoint(
+ layer,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_caches[index],
+ use_cache
+ )
+ else:
+ layer_ret = layer(
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=kv_caches[index],
+ use_cache=use_cache
+ )
+ hidden_states, kv_cache = layer_ret
+ if use_cache:
+ presents = presents + (kv_cache,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # Final layer norm.
+ if self.post_layer_norm:
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states, presents, all_hidden_states, all_self_attentions
+
+
+class ChatGLMPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+
+ is_parallelizable = False
+ supports_gradient_checkpointing = True
+ config_class = ChatGLMConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["GLMBlock"]
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ return
+
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
+ batch_size, seq_length = input_ids.shape
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+ full_attention_mask.tril_()
+ past_length = 0
+ if past_key_values:
+ past_length = past_key_values[0][0].shape[0]
+ if past_length:
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
+ device=input_ids.device), full_attention_mask), dim=-1)
+ if padding_mask is not None:
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+ if not past_length and padding_mask is not None:
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+ full_attention_mask = (full_attention_mask < 0.5).bool()
+ full_attention_mask.unsqueeze_(1)
+ return full_attention_mask
+
+ def get_position_ids(self, input_ids, device):
+ batch_size, seq_length = input_ids.shape
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ return position_ids
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, GLMTransformer):
+ module.gradient_checkpointing = value
+
+
+class Embedding(torch.nn.Module):
+ """Language model embeddings."""
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(Embedding, self).__init__()
+
+ self.hidden_size = config.hidden_size
+ # Word embeddings (parallel).
+ self.word_embeddings = nn.Embedding(
+ config.padded_vocab_size,
+ self.hidden_size,
+ dtype=config.torch_dtype,
+ device=device
+ )
+ self.fp32_residual_connection = config.fp32_residual_connection
+
+ def forward(self, input_ids):
+ # Embeddings.
+ words_embeddings = self.word_embeddings(input_ids)
+ embeddings = words_embeddings
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+ embeddings = embeddings.transpose(0, 1).contiguous()
+ # If the input flag for fp32 residual connection is set, convert to float.
+ if self.fp32_residual_connection:
+ embeddings = embeddings.float()
+ return embeddings
+
+
+class ChatGLMModel(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ init_kwargs = {}
+ if device is not None:
+ init_kwargs["device"] = device
+ self.embedding = init_method(Embedding, config, **init_kwargs)
+ self.num_layers = config.num_layers
+ self.multi_query_group_num = config.multi_query_group_num
+ self.kv_channels = config.kv_channels
+
+ # Rotary positional embeddings
+ self.seq_length = config.seq_length
+ rotary_dim = (
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
+ )
+
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
+ dtype=config.torch_dtype)
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
+ dtype=config.torch_dtype, **init_kwargs)
+ self.pre_seq_len = config.pre_seq_len
+ self.prefix_projection = config.prefix_projection
+ if self.pre_seq_len is not None:
+ for param in self.parameters():
+ param.requires_grad = False
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+ self.dropout = torch.nn.Dropout(0.1)
+
+ def get_input_embeddings(self):
+ return self.embedding.word_embeddings
+
+ def get_prompt(self, batch_size, device, dtype=torch.half):
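+ # Prefix tuning (P-Tuning v2): encode the learnable prefix tokens into per-layer
+ # key/value tensors used as the initial KV cache.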
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.num_layers * 2,
+ self.multi_query_group_num,
+ self.kv_channels
+ )
+ # after permute: [num_layers * 2, pre_seq_len, b, num_groups, kv_channels]; split into per-layer (key, value) pairs
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, seq_length = input_ids.shape
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
+ dtype=inputs_embeds.dtype)
+ if attention_mask is not None:
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask], dim=-1)
+
+ if full_attention_mask is None:
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+ # Rotary positional embeddings
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+ if position_ids is not None:
+ rotary_pos_emb = rotary_pos_emb[position_ids]
+ else:
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+ # Run encoder.
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
+ )
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def quantize(self, weight_bit_width: int):
+ # from .quantization import quantize
+ quantize(self.encoder, weight_bit_width)
+ return self
+
+
+class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
+ super().__init__(config)
+
+ self.max_sequence_length = config.max_length
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+ self.config = config
+ self.quantized = False
+
+ if self.config.quantization_bit:
+ self.quantize(self.config.quantization_bit, empty_init=True)
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ # update past_key_values
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
+
+ # update position ids
+ if "position_ids" in model_kwargs:
+ position_ids = model_kwargs["position_ids"]
+ new_position_id = position_ids[..., -1:].clone()
+ new_position_id += 1
+ model_kwargs["position_ids"] = torch.cat(
+ [position_ids, new_position_id], dim=-1
+ )
+
+ model_kwargs["is_first_forward"] = False
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ is_first_forward: bool = True,
+ **kwargs
+ ) -> dict:
+ # only last token for input_ids if past is not None
+ if position_ids is None:
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
+ if not is_first_forward:
+ if past_key_values is not None:
+ position_ids = position_ids[..., -1:]
+ input_ids = input_ids[:, -1:]
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ "return_last_logit": True,
+ "use_cache": use_cache
+ }
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_last_logit: Optional[bool] = False,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ if return_last_logit:
+ hidden_states = hidden_states[-1:]
+ lm_logits = self.transformer.output_layer(hidden_states)
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
+
+ loss = None
+ if labels is not None:
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+
+ Output shares the same memory storage as `past`.
+ """
+ return tuple(
+ (
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
+ )
+ for layer_past in past
+ )
+
+ def process_response(self, output, history):
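+ # Split the raw generation on <|assistant|> turns; an empty metadata line means plain text,
+ # a non-empty one marks a tool call whose arguments are parsed from the content.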
+ content = ""
+ history = deepcopy(history)
+ for response in output.split("<|assistant|>"):
+ metadata, content = response.split("\n", maxsplit=1)
+ if not metadata.strip():
+ content = content.strip()
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
+ content = content.replace("[[训练时间]]", "2023年")
+ else:
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
+ if history[0]["role"] == "system" and "tools" in history[0]:
+ content = "\n".join(content.split("\n")[1:-1])
+ def tool_call(**kwargs):
+ return kwargs
+ parameters = eval(content)
+ content = {"name": metadata.strip(), "parameters": parameters}
+ else:
+ content = {"name": metadata.strip(), "content": content}
+ return content, history
+
+ @torch.inference_mode()
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
+ max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
+ **kwargs):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
+ inputs = inputs.to(self.device)
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
+ tokenizer.get_command("<|observation|>")]
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+ response = tokenizer.decode(outputs)
+ history.append({"role": role, "content": query})
+ response, history = self.process_response(response, history)
+ return response, history
+
+ @torch.inference_mode()
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
+ past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
+ logits_processor=None, return_past_key_values=False, **kwargs):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
+ tokenizer.get_command("<|observation|>")]
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+ if past_key_values is None:
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
+ else:
+ inputs = tokenizer.build_chat_input(query, role=role)
+ inputs = inputs.to(self.device)
+ if past_key_values is not None:
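+ # When resuming from a KV cache, shift position ids by the cached length (excluding any
+ # prefix-tuning tokens) and extend the attention mask over the cached positions.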
+ past_length = past_key_values[0][0].shape[0]
+ if self.transformer.pre_seq_len is not None:
+ past_length -= self.transformer.pre_seq_len
+ inputs.position_ids += past_length
+ attention_mask = inputs.attention_mask
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
+ inputs['attention_mask'] = attention_mask
+ history.append({"role": role, "content": query})
+ for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
+ eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
+ **gen_kwargs):
+ if return_past_key_values:
+ outputs, past_key_values = outputs
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+ response = tokenizer.decode(outputs)
+ if response and response[-1] != "�":
+ response, new_history = self.process_response(response, history)
+ if return_past_key_values:
+ yield response, new_history, past_key_values
+ else:
+ yield response, new_history
+
+ @torch.inference_mode()
+ def stream_generate(
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ return_past_key_values=False,
+ **kwargs,
+ ):
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+ model_kwargs["use_cache"] = generation_config.use_cache
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ if not has_default_max_length:
+ logger.warning(
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (="
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ logits_warper = self._get_logits_warper(generation_config)
+
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ scores = None
+ while True:
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ if generation_config.do_sample:
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(probs, dim=-1)
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+ unfinished_sequences = unfinished_sequences.mul(
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+ )
+ if return_past_key_values:
+ yield input_ids, outputs.past_key_values
+ else:
+ yield input_ids
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ break
+
+ def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
+ if bits == 0:
+ return
+
+ # from .quantization import quantize
+
+ if self.quantized:
+ logger.info("Already quantized.")
+ return self
+
+ self.quantized = True
+
+ self.config.quantization_bit = bits
+
+ self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
+ **kwargs)
+ return self
+
+
+class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
+ if config.classifier_dropout is not None:
+ self.dropout = nn.Dropout(config.classifier_dropout)
+ else:
+ self.dropout = None
+ self.config = config
+
+ if self.config.quantization_bit:
+ self.quantize(self.config.quantization_bit, empty_init=True)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ full_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ full_attention_mask=full_attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ pooled_hidden_states = hidden_states[-1]
+ if self.dropout is not None:
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
+ logits = self.classifier_head(pooled_hidden_states)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
+ else:
+ loss = loss_fct(logits.float(), labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/PusaV1/diffsynth/models/lora.py b/PusaV1/diffsynth/models/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..96d7ca9f1bd79f5748b91b5a5236f1b03f378ee3
--- /dev/null
+++ b/PusaV1/diffsynth/models/lora.py
@@ -0,0 +1,386 @@
+import torch
+from .sd_unet import SDUNet
+from .sdxl_unet import SDXLUNet
+from .sd_text_encoder import SDTextEncoder
+from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
+from .sd3_dit import SD3DiT
+from .flux_dit import FluxDiT
+from .hunyuan_dit import HunyuanDiT
+from .cog_dit import CogDiT
+from .hunyuan_video_dit import HunyuanVideoDiT
+from .wan_video_dit import WanModel
+from .wan_video_pusa import WanModelPusa
+
+
+class LoRAFromCivitai:
+ def __init__(self):
+ self.supported_model_classes = []
+ self.lora_prefix = []
+ self.renamed_lora_prefix = {}
+ self.special_keys = {}
+
+
+ def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
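+ # Dispatch on the checkpoint's naming scheme: Kohya-style `.lora_up`/`.lora_down` pairs
+ # vs. PEFT-style `.lora_A`/`.lora_B` pairs.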
+ for key in state_dict:
+ if ".lora_up" in key:
+ return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
+ return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
+
+
+ def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
+ renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
+ state_dict_ = {}
+ for key in state_dict:
+ if ".lora_up" not in key:
+ continue
+ if not key.startswith(lora_prefix):
+ continue
+ weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
+ weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
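+ # Conv LoRA weights come as [out, in, 1, 1]: flatten to 2-D, merge via matmul, then restore the trailing 1x1 dims.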
+ if len(weight_up.shape) == 4:
+ weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
+ weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
+ lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+ else:
+ lora_weight = alpha * torch.mm(weight_up, weight_down)
+ target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
+ for special_key in self.special_keys:
+ target_name = target_name.replace(special_key, self.special_keys[special_key])
+ state_dict_[target_name] = lora_weight.cpu()
+ return state_dict_
+
+
+ def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
+ state_dict_ = {}
+ for key in state_dict:
+ if ".lora_B." not in key:
+ continue
+ if not key.startswith(lora_prefix):
+ continue
+ weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
+ weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
+ if len(weight_up.shape) == 4:
+ weight_up = weight_up.squeeze(3).squeeze(2)
+ weight_down = weight_down.squeeze(3).squeeze(2)
+ lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+ else:
+ lora_weight = alpha * torch.mm(weight_up, weight_down)
+ keys = key.split(".")
+ keys.pop(keys.index("lora_B"))
+ target_name = ".".join(keys)
+ target_name = target_name[len(lora_prefix):]
+ state_dict_[target_name] = lora_weight.cpu()
+ return state_dict_
+
+
+ def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
+ state_dict_model = model.state_dict()
+ state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
+ if model_resource == "diffusers":
+ state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
+ elif model_resource == "civitai":
+ state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
+ if isinstance(state_dict_lora, tuple):
+ state_dict_lora = state_dict_lora[0]
+ if len(state_dict_lora) > 0:
+ print(f" {len(state_dict_lora)} tensors are updated.")
+ for name in state_dict_lora:
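+ # float8 parameters cannot accumulate a float16/32 delta directly: upcast, add the LoRA weight, then cast back.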
+ fp8 = False
+ if state_dict_model[name].dtype == torch.float8_e4m3fn:
+ state_dict_model[name] = state_dict_model[name].to(state_dict_lora[name].dtype)
+ fp8 = True
+ state_dict_model[name] += state_dict_lora[name].to(
+ dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
+ if fp8:
+ state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
+ model.load_state_dict(state_dict_model)
+
+
+ def match(self, model, state_dict_lora):
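+ # Try each supported (prefix, model class) pair under both naming conventions; a candidate
+ # matches only if every converted LoRA key exists in the model's state dict.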
+ for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
+ if not isinstance(model, model_class):
+ continue
+ state_dict_model = model.state_dict()
+ for model_resource in ["diffusers", "civitai"]:
+ try:
+ state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
+ converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
+ else model.__class__.state_dict_converter().from_civitai
+ state_dict_lora_ = converter_fn(state_dict_lora_)
+ if isinstance(state_dict_lora_, tuple):
+ state_dict_lora_ = state_dict_lora_[0]
+ if len(state_dict_lora_) == 0:
+ continue
+ for name in state_dict_lora_:
+ if name not in state_dict_model:
+ break
+ else:
+ return lora_prefix, model_resource
+ except:
+ pass
+ return None
+
+
+
+class SDLoRAFromCivitai(LoRAFromCivitai):
+ def __init__(self):
+ super().__init__()
+ self.supported_model_classes = [SDUNet, SDTextEncoder]
+ self.lora_prefix = ["lora_unet_", "lora_te_"]
+ self.special_keys = {
+ "down.blocks": "down_blocks",
+ "up.blocks": "up_blocks",
+ "mid.block": "mid_block",
+ "proj.in": "proj_in",
+ "proj.out": "proj_out",
+ "transformer.blocks": "transformer_blocks",
+ "to.q": "to_q",
+ "to.k": "to_k",
+ "to.v": "to_v",
+ "to.out": "to_out",
+ "text.model": "text_model",
+ "self.attn.q.proj": "self_attn.q_proj",
+ "self.attn.k.proj": "self_attn.k_proj",
+ "self.attn.v.proj": "self_attn.v_proj",
+ "self.attn.out.proj": "self_attn.out_proj",
+ "input.blocks": "model.diffusion_model.input_blocks",
+ "middle.block": "model.diffusion_model.middle_block",
+ "output.blocks": "model.diffusion_model.output_blocks",
+ }
+
+
+class SDXLLoRAFromCivitai(LoRAFromCivitai):
+ def __init__(self):
+ super().__init__()
+ self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
+ self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
+ self.renamed_lora_prefix = {"lora_te2_": "2"}
+ self.special_keys = {
+ "down.blocks": "down_blocks",
+ "up.blocks": "up_blocks",
+ "mid.block": "mid_block",
+ "proj.in": "proj_in",
+ "proj.out": "proj_out",
+ "transformer.blocks": "transformer_blocks",
+ "to.q": "to_q",
+ "to.k": "to_k",
+ "to.v": "to_v",
+ "to.out": "to_out",
+ "text.model": "conditioner.embedders.0.transformer.text_model",
+ "self.attn.q.proj": "self_attn.q_proj",
+ "self.attn.k.proj": "self_attn.k_proj",
+ "self.attn.v.proj": "self_attn.v_proj",
+ "self.attn.out.proj": "self_attn.out_proj",
+ "input.blocks": "model.diffusion_model.input_blocks",
+ "middle.block": "model.diffusion_model.middle_block",
+ "output.blocks": "model.diffusion_model.output_blocks",
+ "2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
+ }
+
+
+class FluxLoRAFromCivitai(LoRAFromCivitai):
+ def __init__(self):
+ super().__init__()
+ self.supported_model_classes = [FluxDiT, FluxDiT]
+ self.lora_prefix = ["lora_unet_", "transformer."]
+ self.renamed_lora_prefix = {}
+ self.special_keys = {
+ "single.blocks": "single_blocks",
+ "double.blocks": "double_blocks",
+ "img.attn": "img_attn",
+ "img.mlp": "img_mlp",
+ "img.mod": "img_mod",
+ "txt.attn": "txt_attn",
+ "txt.mlp": "txt_mlp",
+ "txt.mod": "txt_mod",
+ }
+
+
+
+class GeneralLoRAFromPeft:
+ def __init__(self):
+ self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel, WanModelPusa]
+
+
+ def get_name_dict(self, lora_state_dict):
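+ # Map PEFT LoRA parameter names to base-model parameter names by dropping the `lora_B`
+ # (and optional adapter-name) segments and the leading `diffusion_model.` prefix.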
+ lora_name_dict = {}
+ for key in lora_state_dict:
+ if ".lora_B." not in key:
+ continue
+ keys = key.split(".")
+ if len(keys) > keys.index("lora_B") + 2:
+ keys.pop(keys.index("lora_B") + 1)
+ keys.pop(keys.index("lora_B"))
+ if keys[0] == "diffusion_model":
+ keys.pop(0)
+ target_name = ".".join(keys)
+ lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
+ return lora_name_dict
+
+
+ def match(self, model: torch.nn.Module, state_dict_lora):
+ lora_name_dict = self.get_name_dict(state_dict_lora)
+ model_name_dict = {name: None for name, _ in model.named_parameters()}
+ matched_num = sum([i in model_name_dict for i in lora_name_dict])
+ if matched_num == len(lora_name_dict):
+ return "", ""
+ else:
+ return None
+
+
+ def fetch_device_and_dtype(self, state_dict):
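+ # Use the first parameter's device/dtype as the storage target; compute the merge on CUDA
+ # when weights sit on CPU, and in float32 when they are stored as float8.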
+ device, dtype = None, None
+ for name, param in state_dict.items():
+ device, dtype = param.device, param.dtype
+ break
+ computation_device = device
+ computation_dtype = dtype
+ if computation_device == torch.device("cpu"):
+ if torch.cuda.is_available():
+ computation_device = torch.device("cuda")
+ if computation_dtype == torch.float8_e4m3fn:
+ computation_dtype = torch.float32
+ return device, dtype, computation_device, computation_dtype
+
+
+ def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
+ state_dict_model = model.state_dict()
+ device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
+ lora_name_dict = self.get_name_dict(state_dict_lora)
+ for name in lora_name_dict:
+ weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
+ weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
+ if len(weight_up.shape) == 4:
+ weight_up = weight_up.squeeze(3).squeeze(2)
+ weight_down = weight_down.squeeze(3).squeeze(2)
+ weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+ else:
+ weight_lora = alpha * torch.mm(weight_up, weight_down)
+ weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
+ weight_patched = weight_model + weight_lora
+ state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
+ print(f" {len(lora_name_dict)} tensors are updated.")
+ model.load_state_dict(state_dict_model)
+
+
+
+class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
+ def __init__(self):
+ super().__init__()
+ self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
+ self.lora_prefix = ["diffusion_model.", "transformer."]
+ self.special_keys = {}
+
+
+class FluxLoRAConverter:
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def align_to_opensource_format(state_dict, alpha=1.0):
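+ # Rename DiffSynth-style LoRA keys to the Kohya layout (`lora_unet_*` with `lora_up`/`lora_down`)
+ # and emit an explicit per-module alpha tensor.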
+ prefix_rename_dict = {
+ "single_blocks": "lora_unet_single_blocks",
+ "blocks": "lora_unet_double_blocks",
+ }
+ middle_rename_dict = {
+ "norm.linear": "modulation_lin",
+ "to_qkv_mlp": "linear1",
+ "proj_out": "linear2",
+
+ "norm1_a.linear": "img_mod_lin",
+ "norm1_b.linear": "txt_mod_lin",
+ "attn.a_to_qkv": "img_attn_qkv",
+ "attn.b_to_qkv": "txt_attn_qkv",
+ "attn.a_to_out": "img_attn_proj",
+ "attn.b_to_out": "txt_attn_proj",
+ "ff_a.0": "img_mlp_0",
+ "ff_a.2": "img_mlp_2",
+ "ff_b.0": "txt_mlp_0",
+ "ff_b.2": "txt_mlp_2",
+ }
+ suffix_rename_dict = {
+ "lora_B.weight": "lora_up.weight",
+ "lora_A.weight": "lora_down.weight",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ names = name.split(".")
+ if names[-2] != "lora_A" and names[-2] != "lora_B":
+ names.pop(-2)
+ prefix = names[0]
+ middle = ".".join(names[2:-2])
+ suffix = ".".join(names[-2:])
+ block_id = names[1]
+ if middle not in middle_rename_dict:
+ continue
+ rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
+ state_dict_[rename] = param
+ if rename.endswith("lora_up.weight"):
+ state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((alpha,))[0]
+ return state_dict_
+
+ @staticmethod
+ def align_to_diffsynth_format(state_dict):
+ rename_dict = {
+ "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
+ "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
+ "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
+ "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
+ "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
+ "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
+ "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
+ "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
+ }
+ def guess_block_id(name):
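+ # The block index is the first purely numeric underscore-separated token; it is swapped
+ # for the `blockid` placeholder used in rename_dict.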
+ names = name.split("_")
+ for i in names:
+ if i.isdigit():
+ return i, name.replace(f"_{i}_", "_blockid_")
+ return None, None
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ block_id, source_name = guess_block_id(name)
+ if source_name in rename_dict:
+ target_name = rename_dict[source_name]
+ target_name = target_name.replace(".blockid.", f".{block_id}.")
+ state_dict_[target_name] = param
+ else:
+ state_dict_[name] = param
+ return state_dict_
+
+
+class WanLoRAConverter:
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def align_to_opensource_format(state_dict, **kwargs):
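+ # Drop the PEFT adapter segment (`.default.`) and prepend `diffusion_model.` so the keys
+ # match the opensource Wan checkpoint layout.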
+ state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
+ return state_dict
+
+ @staticmethod
+ def align_to_diffsynth_format(state_dict, **kwargs):
+ state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
+ return state_dict
+
+
+def get_lora_loaders():
+ return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
diff --git a/PusaV1/diffsynth/models/model_manager.py b/PusaV1/diffsynth/models/model_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae3c50ad7c6fe89515b789ef02b0bcb12d1714e
--- /dev/null
+++ b/PusaV1/diffsynth/models/model_manager.py
@@ -0,0 +1,454 @@
+import os, torch, json, importlib
+from typing import List
+
+from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
+
+from .sd_text_encoder import SDTextEncoder
+from .sd_unet import SDUNet
+from .sd_vae_encoder import SDVAEEncoder
+from .sd_vae_decoder import SDVAEDecoder
+from .lora import get_lora_loaders
+
+from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
+from .sdxl_unet import SDXLUNet
+from .sdxl_vae_decoder import SDXLVAEDecoder
+from .sdxl_vae_encoder import SDXLVAEEncoder
+
+from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from .sd3_dit import SD3DiT
+from .sd3_vae_decoder import SD3VAEDecoder
+from .sd3_vae_encoder import SD3VAEEncoder
+
+from .sd_controlnet import SDControlNet
+from .sdxl_controlnet import SDXLControlNetUnion
+
+from .sd_motion import SDMotionModel
+from .sdxl_motion import SDXLMotionModel
+
+from .svd_image_encoder import SVDImageEncoder
+from .svd_unet import SVDUNet
+from .svd_vae_decoder import SVDVAEDecoder
+from .svd_vae_encoder import SVDVAEEncoder
+
+from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
+from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+
+from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from .hunyuan_dit import HunyuanDiT
+from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
+from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
+
+from .flux_dit import FluxDiT
+from .flux_text_encoder import FluxTextEncoder2
+from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
+from .flux_ipadapter import FluxIpAdapter
+
+from .cog_vae import CogVAEEncoder, CogVAEDecoder
+from .cog_dit import CogDiT
+
+from ..extensions.RIFE import IFNet
+from ..extensions.ESRGAN import RRDBNet
+
+from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
+from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
+
+
+def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
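+ # Convert the checkpoint with the class-specific converter (civitai or diffusers layout),
+ # build the model under init_weights_on_device(), load with assign=True, then cast to the target dtype/device.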
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
+ state_dict_converter = model_class.state_dict_converter()
+ if model_resource == "civitai":
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
+ elif model_resource == "diffusers":
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
+ if isinstance(state_dict_results, tuple):
+ model_state_dict, extra_kwargs = state_dict_results
+ print(f" This model is initialized with extra kwargs: {extra_kwargs}")
+ else:
+ model_state_dict, extra_kwargs = state_dict_results, {}
+ torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
+ with init_weights_on_device():
+ model = model_class(**extra_kwargs)
+ if hasattr(model, "eval"):
+ model = model.eval()
+ model.load_state_dict(model_state_dict, assign=True)
+ model = model.to(dtype=torch_dtype, device=device)
+ loaded_model_names.append(model_name)
+ loaded_models.append(model)
+ return loaded_model_names, loaded_models
+
+
+def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+ model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
+ else:
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
+ model = model.half()
+ try:
+ model = model.to(device=device)
+        except Exception:
+            # Moving to the target device is best-effort; leave the model where it is if .to() fails.
+            pass
+ loaded_model_names.append(model_name)
+ loaded_models.append(model)
+ return loaded_model_names, loaded_models
+
+
+def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
+ print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
+ base_state_dict = base_model.state_dict()
+ base_model.to("cpu")
+ del base_model
+ model = model_class(**extra_kwargs)
+ model.load_state_dict(base_state_dict, strict=False)
+ model.load_state_dict(state_dict, strict=False)
+ model.to(dtype=torch_dtype, device=device)
+ return model
+
+
+def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
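+        # Repeatedly scan the manager for a base model with this name, patch it, and remove the
+        # original entry; the inner for/else exits the outer while-loop once no matching base model remains.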
+ while True:
+ for model_id in range(len(model_manager.model)):
+ base_model_name = model_manager.model_name[model_id]
+ if base_model_name == model_name:
+ base_model_path = model_manager.model_path[model_id]
+ base_model = model_manager.model[model_id]
+ print(f" Adding patch model to {base_model_name} ({base_model_path})")
+ patched_model = load_single_patch_model_from_single_file(
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
+ loaded_model_names.append(base_model_name)
+ loaded_models.append(patched_model)
+ model_manager.model.pop(model_id)
+ model_manager.model_path.pop(model_id)
+ model_manager.model_name.pop(model_id)
+ break
+ else:
+ break
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorTemplate:
+ def __init__(self):
+ pass
+
+ def match(self, file_path="", state_dict={}):
+ return False
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ return [], []
+
+
+
+class ModelDetectorFromSingleFile:
+ def __init__(self, model_loader_configs=[]):
+ self.keys_hash_with_shape_dict = {}
+ self.keys_hash_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
+ if keys_hash is not None:
+ self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
+
+
+ def match(self, file_path="", state_dict={}):
+ if isinstance(file_path, str) and os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ return True
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+ if keys_hash in self.keys_hash_dict:
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+
+ # Load models with strict matching
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+ return loaded_model_names, loaded_models
+
+ # Load models without strict matching
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+ if keys_hash in self.keys_hash_dict:
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+ return loaded_model_names, loaded_models
+
+        # No configuration matched this state dict; return empty results instead of raising a NameError.
+        return [], []
+
+
+
+class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
+ def __init__(self, model_loader_configs=[]):
+ super().__init__(model_loader_configs)
+
+
+ def match(self, file_path="", state_dict={}):
+ if isinstance(file_path, str) and os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ # Split the state_dict and load from each component
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
+ valid_state_dict = {}
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+ valid_state_dict.update(sub_state_dict)
+ if super().match(file_path, valid_state_dict):
+ loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
+ else:
+ loaded_model_names, loaded_models = [], []
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorFromHuggingfaceFolder:
+ def __init__(self, model_loader_configs=[]):
+ self.architecture_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
+ self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
+
+
+ def match(self, file_path="", state_dict={}):
+ if not isinstance(file_path, str) or os.path.isfile(file_path):
+ return False
+ file_list = os.listdir(file_path)
+ if "config.json" not in file_list:
+ return False
+ with open(os.path.join(file_path, "config.json"), "r") as f:
+ config = json.load(f)
+ if "architectures" not in config and "_class_name" not in config:
+ return False
+ return True
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ with open(os.path.join(file_path, "config.json"), "r") as f:
+ config = json.load(f)
+ loaded_model_names, loaded_models = [], []
+ architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
+ for architecture in architectures:
+ huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
+ if redirected_architecture is not None:
+ architecture = redirected_architecture
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorFromPatchedSingleFile:
+ def __init__(self, model_loader_configs=[]):
+ self.keys_hash_with_shape_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
+
+
+ def match(self, file_path="", state_dict={}):
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+
+ # Load models with strict matching
+ loaded_model_names, loaded_models = [], []
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
+ state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelManager:
+ def __init__(
+ self,
+ torch_dtype=torch.float16,
+ device="cuda",
+ model_id_list: List[Preset_model_id] = [],
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+ file_path_list: List[str] = [],
+ ):
+ self.torch_dtype = torch_dtype
+ self.device = device
+ self.model = []
+ self.model_path = []
+ self.model_name = []
+ downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
+ self.model_detector = [
+ ModelDetectorFromSingleFile(model_loader_configs),
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
+ ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
+ ]
+ self.load_models(downloaded_files + file_path_list)
+
+
+ def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
+ print(f"Loading models from file: {file_path}")
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+
+
+ def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
+ print(f"Loading models from folder: {file_path}")
+ model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+
+
+ def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
+ print(f"Loading patch models from file: {file_path}")
+ model_names, models = load_patch_model_from_single_file(
+ state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following patched models are loaded: {model_names}.")
+
+
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
+ if isinstance(file_path, list):
+ for file_path_ in file_path:
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
+ else:
+ print(f"Loading LoRA models from file: {file_path}")
+ is_loaded = False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
+ for lora in get_lora_loaders():
+ match_results = lora.match(model, state_dict)
+ if match_results is not None:
+ print(f" Adding LoRA to {model_name} ({model_path}).")
+ lora_prefix, model_resource = match_results
+ lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
+ is_loaded = True
+ break
+ if not is_loaded:
+ print(f" Cannot load LoRA: {file_path}")
+
+
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
+ print(f"Loading models from: {file_path}")
+ if device is None: device = self.device
+ if torch_dtype is None: torch_dtype = self.torch_dtype
+ if isinstance(file_path, list):
+ state_dict = {}
+ for path in file_path:
+ state_dict.update(load_state_dict(path))
+ elif os.path.isfile(file_path):
+ state_dict = load_state_dict(file_path)
+ else:
+ state_dict = None
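+        # Try each detector in order; the first whose match() succeeds loads the checkpoint,
+        # otherwise the for/else below reports that the model type could not be detected.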
+ for model_detector in self.model_detector:
+ if model_detector.match(file_path, state_dict):
+ model_names, models = model_detector.load(
+ file_path, state_dict,
+ device=device, torch_dtype=torch_dtype,
+ allowed_model_names=model_names, model_manager=self
+ )
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+ break
+ else:
+ print(f" We cannot detect the model type. No models are loaded.")
+
+
+ def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
+ for file_path in file_path_list:
+ self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
+
+
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
+ fetched_models = []
+ fetched_model_paths = []
+ for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
+ if file_path is not None and file_path != model_path:
+ continue
+ if model_name == model_name_:
+ fetched_models.append(model)
+ fetched_model_paths.append(model_path)
+ if len(fetched_models) == 0:
+ print(f"No {model_name} models available.")
+ return None
+ if len(fetched_models) == 1:
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
+ else:
+ print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
+ if require_model_path:
+ return fetched_models[0], fetched_model_paths[0]
+ else:
+ return fetched_models[0]
+
+
+ def to(self, device):
+ for model in self.model:
+ model.to(device)
+
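+
+# A minimal usage sketch (hypothetical: the paths and the "wan_video_dit" model name below are
+# placeholders/assumptions, not files or identifiers guaranteed by this module):
+if __name__ == "__main__":
+    # Let the detectors identify each checkpoint, then fetch one of the loaded models by name.
+    manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
+    manager.load_models([
+        "path/to/dit_checkpoint.safetensors",  # single-file checkpoint (placeholder path)
+        "path/to/text_encoder_folder",         # Hugging Face style folder (placeholder path)
+    ])
+    manager.load_lora("path/to/lora.safetensors", lora_alpha=1.0)
+    dit = manager.fetch_model("wan_video_dit")  # name as registered in model_loader_configs (assumed)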
diff --git a/PusaV1/diffsynth/models/model_manager_I2V.py b/PusaV1/diffsynth/models/model_manager_I2V.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ff01216c10249a665ba2e5b2c864c92dde827f3
--- /dev/null
+++ b/PusaV1/diffsynth/models/model_manager_I2V.py
@@ -0,0 +1,521 @@
+import os, torch, json, importlib
+from typing import List
+import torch.nn as nn
+
+from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
+
+from .sd_text_encoder import SDTextEncoder
+from .sd_unet import SDUNet
+from .sd_vae_encoder import SDVAEEncoder
+from .sd_vae_decoder import SDVAEDecoder
+from .lora import get_lora_loaders
+
+from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
+from .sdxl_unet import SDXLUNet
+from .sdxl_vae_decoder import SDXLVAEDecoder
+from .sdxl_vae_encoder import SDXLVAEEncoder
+
+from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from .sd3_dit import SD3DiT
+from .sd3_vae_decoder import SD3VAEDecoder
+from .sd3_vae_encoder import SD3VAEEncoder
+
+from .sd_controlnet import SDControlNet
+from .sdxl_controlnet import SDXLControlNetUnion
+
+from .sd_motion import SDMotionModel
+from .sdxl_motion import SDXLMotionModel
+
+from .svd_image_encoder import SVDImageEncoder
+from .svd_unet import SVDUNet
+from .svd_vae_decoder import SVDVAEDecoder
+from .svd_vae_encoder import SVDVAEEncoder
+
+from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
+from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+
+from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from .hunyuan_dit import HunyuanDiT
+from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
+from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
+
+from .flux_dit import FluxDiT
+from .flux_text_encoder import FluxTextEncoder2
+from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
+from .flux_ipadapter import FluxIpAdapter
+
+from .cog_vae import CogVAEEncoder, CogVAEDecoder
+from .cog_dit import CogDiT
+
+from ..extensions.RIFE import IFNet
+from ..extensions.ESRGAN import RRDBNet
+
+from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
+from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
+
+
+# def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
+# loaded_model_names, loaded_models = [], []
+# for model_name, model_class in zip(model_names, model_classes):
+# print(f" model_name: {model_name} model_class: {model_class.__name__}")
+# state_dict_converter = model_class.state_dict_converter()
+# if model_resource == "civitai":
+# state_dict_results = state_dict_converter.from_civitai(state_dict)
+# elif model_resource == "diffusers":
+# state_dict_results = state_dict_converter.from_diffusers(state_dict)
+# if isinstance(state_dict_results, tuple):
+# model_state_dict, extra_kwargs = state_dict_results
+# print(f" This model is initialized with extra kwargs: {extra_kwargs}")
+# else:
+# model_state_dict, extra_kwargs = state_dict_results, {}
+# torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
+# with init_weights_on_device():
+# model = model_class(**extra_kwargs)
+# if hasattr(model, "eval"):
+# model = model.eval()
+# model.load_state_dict(model_state_dict, assign=True)
+# model = model.to(dtype=torch_dtype, device=device)
+# loaded_model_names.append(model_name)
+# loaded_models.append(model)
+# return loaded_model_names, loaded_models
+
+# Load an I2V (image-to-video) model from a T2V (text-to-video) checkpoint.
+def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
+ state_dict_converter = model_class.state_dict_converter()
+ if model_resource == "civitai":
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
+ elif model_resource == "diffusers":
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
+ if isinstance(state_dict_results, tuple):
+ model_state_dict, extra_kwargs = state_dict_results
+
+ print(f" This model is initialized with extra kwargs: {extra_kwargs}")
+ else:
+ model_state_dict, extra_kwargs = state_dict_results, {}
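+        # NOTE: the kwargs below intentionally hard-code the I2V DiT configuration (e.g. in_dim=36 for the
+        # extra image-conditioning channels) over whatever the converter inferred from the T2V checkpoint;
+        # the resulting patch_embedding shape mismatch is handled further down by zero-padding the extra
+        # input channels so the pretrained T2V channels keep their weights.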
+        torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
+        extra_kwargs = {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06} #TODO I2V
+        print(f"    This model is re-initialized with hard-coded I2V kwargs: {extra_kwargs}")
+        with init_weights_on_device():
+            model = model_class(**extra_kwargs)
+            if hasattr(model, "eval"):
+                model = model.eval()
+
+        # Handle size mismatches for specific parameters
+        if 'patch_embedding.weight' in model_state_dict and hasattr(model, 'patch_embedding'):
+            checkpoint_shape = model_state_dict['patch_embedding.weight'].shape
+            model_shape = model.patch_embedding.weight.shape
+            if checkpoint_shape != model_shape:
+                print(f"Warning: Size mismatch for patch_embedding.weight. Checkpoint: {checkpoint_shape}, Model: {model_shape}")
+                # Option 1: Skip loading this parameter
+                # del model_state_dict['patch_embedding.weight']
+                # Option 2: Resize the parameter (used here): zero-pad the extra input channels
+                if checkpoint_shape[1] < model_shape[1]:
+                    # Expand the tensor to match the model's shape
+                    old_weight = model_state_dict['patch_embedding.weight']
+                    new_weight = torch.zeros(model_shape, device=old_weight.device, dtype=old_weight.dtype)
+                    new_weight[:, :checkpoint_shape[1], ...] = old_weight
+                    model_state_dict['patch_embedding.weight'] = new_weight
+                    print("Resized patch_embedding.weight to match model dimensions")
+
+ model.load_state_dict(model_state_dict, assign=True, strict=False) #TODO I2V
+
+
+ # After loading the state dict
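+        # With strict=False and weights initialized on the meta device, any parameter missing from the
+        # checkpoint is still a meta tensor here; replace those with zero-initialized parameters so the
+        # .to(dtype, device) call below does not fail.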
+ for name, param in model.named_parameters():
+ if param.is_meta:
+ print(f"Warning: Parameter {name} is still a meta tensor after loading")
+ # Create a new tensor with the same shape and dtype
+ new_param = torch.zeros(param.shape, dtype=param.dtype, device=device)
+ # Use nn.Parameter to maintain requires_grad property
+ new_param = nn.Parameter(new_param, requires_grad=param.requires_grad)
+ # Replace the parameter in the model
+ module_path, param_name = name.rsplit('.', 1)
+ module = model
+ for part in module_path.split('.'):
+ if part.isdigit():
+ module = module[int(part)]
+ else:
+ module = getattr(module, part)
+ setattr(module, param_name, new_param)
+
+ model = model.to(dtype=torch_dtype, device=device)
+ loaded_model_names.append(model_name)
+ loaded_models.append(model)
+ return loaded_model_names, loaded_models
+
+def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+ model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
+ else:
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
+ model = model.half()
+ try:
+ model = model.to(device=device)
+        except Exception:
+            # Moving to the target device is best-effort; leave the model where it is if .to() fails.
+            pass
+ loaded_model_names.append(model_name)
+ loaded_models.append(model)
+ return loaded_model_names, loaded_models
+
+
+def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
+ print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
+ base_state_dict = base_model.state_dict()
+ base_model.to("cpu")
+ del base_model
+ model = model_class(**extra_kwargs)
+ model.load_state_dict(base_state_dict, strict=False)
+ model.load_state_dict(state_dict, strict=False)
+ model.to(dtype=torch_dtype, device=device)
+ return model
+
+
+def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ while True:
+ for model_id in range(len(model_manager.model)):
+ base_model_name = model_manager.model_name[model_id]
+ if base_model_name == model_name:
+ base_model_path = model_manager.model_path[model_id]
+ base_model = model_manager.model[model_id]
+ print(f" Adding patch model to {base_model_name} ({base_model_path})")
+ patched_model = load_single_patch_model_from_single_file(
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
+ loaded_model_names.append(base_model_name)
+ loaded_models.append(patched_model)
+ model_manager.model.pop(model_id)
+ model_manager.model_path.pop(model_id)
+ model_manager.model_name.pop(model_id)
+ break
+ else:
+ break
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorTemplate:
+ def __init__(self):
+ pass
+
+ def match(self, file_path="", state_dict={}):
+ return False
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ return [], []
+
+
+
+class ModelDetectorFromSingleFile:
+ def __init__(self, model_loader_configs=[]):
+ self.keys_hash_with_shape_dict = {}
+ self.keys_hash_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
+ if keys_hash is not None:
+ self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
+
+
+ def match(self, file_path="", state_dict={}):
+ if isinstance(file_path, str) and os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ return True
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+ if keys_hash in self.keys_hash_dict:
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+
+ # Load models with strict matching
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+ return loaded_model_names, loaded_models
+
+ # Load models without strict matching
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+ if keys_hash in self.keys_hash_dict:
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+ return loaded_model_names, loaded_models
+
+        # No configuration matched this state dict; return empty results instead of raising a NameError.
+        return [], []
+
+
+
+class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
+ def __init__(self, model_loader_configs=[]):
+ super().__init__(model_loader_configs)
+
+
+ def match(self, file_path="", state_dict={}):
+ if isinstance(file_path, str) and os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ # Split the state_dict and load from each component
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
+ valid_state_dict = {}
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+ valid_state_dict.update(sub_state_dict)
+ if super().match(file_path, valid_state_dict):
+ loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
+ else:
+ loaded_model_names, loaded_models = [], []
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorFromHuggingfaceFolder:
+ def __init__(self, model_loader_configs=[]):
+ self.architecture_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
+ self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
+
+
+ def match(self, file_path="", state_dict={}):
+ if not isinstance(file_path, str) or os.path.isfile(file_path):
+ return False
+ file_list = os.listdir(file_path)
+ if "config.json" not in file_list:
+ return False
+ with open(os.path.join(file_path, "config.json"), "r") as f:
+ config = json.load(f)
+ if "architectures" not in config and "_class_name" not in config:
+ return False
+ return True
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ with open(os.path.join(file_path, "config.json"), "r") as f:
+ config = json.load(f)
+ loaded_model_names, loaded_models = [], []
+ architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
+ for architecture in architectures:
+ huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
+ if redirected_architecture is not None:
+ architecture = redirected_architecture
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorFromPatchedSingleFile:
+ def __init__(self, model_loader_configs=[]):
+ self.keys_hash_with_shape_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
+
+
+ def match(self, file_path="", state_dict={}):
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+
+ # Load models with strict matching
+ loaded_model_names, loaded_models = [], []
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
+ state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelManager:
+ def __init__(
+ self,
+ torch_dtype=torch.float16,
+ device="cuda",
+ model_id_list: List[Preset_model_id] = [],
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+ file_path_list: List[str] = [],
+ ):
+ self.torch_dtype = torch_dtype
+ self.device = device
+ self.model = []
+ self.model_path = []
+ self.model_name = []
+ downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
+ self.model_detector = [
+ ModelDetectorFromSingleFile(model_loader_configs),
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
+ ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
+ ]
+ self.load_models(downloaded_files + file_path_list)
+
+
+ def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
+ print(f"Loading models from file: {file_path}")
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+
+
+ def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
+ print(f"Loading models from folder: {file_path}")
+ model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+
+
+ def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
+ print(f"Loading patch models from file: {file_path}")
+ model_names, models = load_patch_model_from_single_file(
+ state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following patched models are loaded: {model_names}.")
+
+
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
+ if isinstance(file_path, list):
+ for file_path_ in file_path:
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
+ else:
+ print(f"Loading LoRA models from file: {file_path}")
+ is_loaded = False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
+ for lora in get_lora_loaders():
+ match_results = lora.match(model, state_dict)
+ if match_results is not None:
+ print(f" Adding LoRA to {model_name} ({model_path}).")
+ lora_prefix, model_resource = match_results
+ lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
+ is_loaded = True
+ break
+ if not is_loaded:
+ print(f" Cannot load LoRA: {file_path}")
+
+
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
+ print(f"Loading models from: {file_path}")
+ if device is None: device = self.device
+ if torch_dtype is None: torch_dtype = self.torch_dtype
+ if isinstance(file_path, list):
+ state_dict = {}
+ for path in file_path:
+ state_dict.update(load_state_dict(path))
+ elif os.path.isfile(file_path):
+ state_dict = load_state_dict(file_path)
+ else:
+ state_dict = None
+ for model_detector in self.model_detector:
+ if model_detector.match(file_path, state_dict):
+ model_names, models = model_detector.load(
+ file_path, state_dict,
+ device=device, torch_dtype=torch_dtype,
+ allowed_model_names=model_names, model_manager=self
+ )
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+ break
+ else:
+ print(f" We cannot detect the model type. No models are loaded.")
+
+
+ def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
+ for file_path in file_path_list:
+ self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
+
+
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
+ fetched_models = []
+ fetched_model_paths = []
+ for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
+ if file_path is not None and file_path != model_path:
+ continue
+ if model_name == model_name_:
+ fetched_models.append(model)
+ fetched_model_paths.append(model_path)
+ if len(fetched_models) == 0:
+ print(f"No {model_name} models available.")
+ return None
+ if len(fetched_models) == 1:
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
+ else:
+ print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
+ if require_model_path:
+ return fetched_models[0], fetched_model_paths[0]
+ else:
+ return fetched_models[0]
+
+
+ def to(self, device):
+ for model in self.model:
+ model.to(device)
+
diff --git a/PusaV1/diffsynth/models/model_manager_ori.py b/PusaV1/diffsynth/models/model_manager_ori.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae3c50ad7c6fe89515b789ef02b0bcb12d1714e
--- /dev/null
+++ b/PusaV1/diffsynth/models/model_manager_ori.py
@@ -0,0 +1,454 @@
+import os, torch, json, importlib
+from typing import List
+
+from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
+
+from .sd_text_encoder import SDTextEncoder
+from .sd_unet import SDUNet
+from .sd_vae_encoder import SDVAEEncoder
+from .sd_vae_decoder import SDVAEDecoder
+from .lora import get_lora_loaders
+
+from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
+from .sdxl_unet import SDXLUNet
+from .sdxl_vae_decoder import SDXLVAEDecoder
+from .sdxl_vae_encoder import SDXLVAEEncoder
+
+from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from .sd3_dit import SD3DiT
+from .sd3_vae_decoder import SD3VAEDecoder
+from .sd3_vae_encoder import SD3VAEEncoder
+
+from .sd_controlnet import SDControlNet
+from .sdxl_controlnet import SDXLControlNetUnion
+
+from .sd_motion import SDMotionModel
+from .sdxl_motion import SDXLMotionModel
+
+from .svd_image_encoder import SVDImageEncoder
+from .svd_unet import SVDUNet
+from .svd_vae_decoder import SVDVAEDecoder
+from .svd_vae_encoder import SVDVAEEncoder
+
+from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
+from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+
+from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from .hunyuan_dit import HunyuanDiT
+from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
+from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
+
+from .flux_dit import FluxDiT
+from .flux_text_encoder import FluxTextEncoder2
+from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
+from .flux_ipadapter import FluxIpAdapter
+
+from .cog_vae import CogVAEEncoder, CogVAEDecoder
+from .cog_dit import CogDiT
+
+from ..extensions.RIFE import IFNet
+from ..extensions.ESRGAN import RRDBNet
+
+from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
+from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
+
+
+def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ print(f" model_name: {model_name} model_class: {model_class.__name__}")
+ state_dict_converter = model_class.state_dict_converter()
+ if model_resource == "civitai":
+ state_dict_results = state_dict_converter.from_civitai(state_dict)
+ elif model_resource == "diffusers":
+ state_dict_results = state_dict_converter.from_diffusers(state_dict)
+ if isinstance(state_dict_results, tuple):
+ model_state_dict, extra_kwargs = state_dict_results
+ print(f" This model is initialized with extra kwargs: {extra_kwargs}")
+ else:
+ model_state_dict, extra_kwargs = state_dict_results, {}
+ torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
+ with init_weights_on_device():
+ model = model_class(**extra_kwargs)
+ if hasattr(model, "eval"):
+ model = model.eval()
+ model.load_state_dict(model_state_dict, assign=True)
+ model = model.to(dtype=torch_dtype, device=device)
+ loaded_model_names.append(model_name)
+ loaded_models.append(model)
+ return loaded_model_names, loaded_models
+
+
+def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+ model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
+ else:
+ model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
+ if torch_dtype == torch.float16 and hasattr(model, "half"):
+ model = model.half()
+ try:
+ model = model.to(device=device)
+        except Exception:
+            # Moving to the target device is best-effort; leave the model where it is if .to() fails.
+            pass
+ loaded_model_names.append(model_name)
+ loaded_models.append(model)
+ return loaded_model_names, loaded_models
+
+
+def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
+ print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
+ base_state_dict = base_model.state_dict()
+ base_model.to("cpu")
+ del base_model
+ model = model_class(**extra_kwargs)
+ model.load_state_dict(base_state_dict, strict=False)
+ model.load_state_dict(state_dict, strict=False)
+ model.to(dtype=torch_dtype, device=device)
+ return model
+
+
+def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
+ loaded_model_names, loaded_models = [], []
+ for model_name, model_class in zip(model_names, model_classes):
+ while True:
+ for model_id in range(len(model_manager.model)):
+ base_model_name = model_manager.model_name[model_id]
+ if base_model_name == model_name:
+ base_model_path = model_manager.model_path[model_id]
+ base_model = model_manager.model[model_id]
+ print(f" Adding patch model to {base_model_name} ({base_model_path})")
+ patched_model = load_single_patch_model_from_single_file(
+ state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
+ loaded_model_names.append(base_model_name)
+ loaded_models.append(patched_model)
+ model_manager.model.pop(model_id)
+ model_manager.model_path.pop(model_id)
+ model_manager.model_name.pop(model_id)
+ break
+ else:
+ break
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorTemplate:
+ def __init__(self):
+ pass
+
+ def match(self, file_path="", state_dict={}):
+ return False
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ return [], []
+
+
+
+class ModelDetectorFromSingleFile:
+ def __init__(self, model_loader_configs=[]):
+ self.keys_hash_with_shape_dict = {}
+ self.keys_hash_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
+ if keys_hash is not None:
+ self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
+
+
+ def match(self, file_path="", state_dict={}):
+ if isinstance(file_path, str) and os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ return True
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+ if keys_hash in self.keys_hash_dict:
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+
+ # Load models with strict matching
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+ return loaded_model_names, loaded_models
+
+ # Load models without strict matching
+ # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
+ keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
+ if keys_hash in self.keys_hash_dict:
+ model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
+ loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
+ return loaded_model_names, loaded_models
+
+        # No configuration matched this state dict; return empty results instead of raising a NameError.
+        return [], []
+
+
+
+class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
+ def __init__(self, model_loader_configs=[]):
+ super().__init__(model_loader_configs)
+
+
+ def match(self, file_path="", state_dict={}):
+ if isinstance(file_path, str) and os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ # Split the state_dict and load from each component
+ splited_state_dict = split_state_dict_with_prefix(state_dict)
+ valid_state_dict = {}
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+ valid_state_dict.update(sub_state_dict)
+ if super().match(file_path, valid_state_dict):
+ loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
+ else:
+ loaded_model_names, loaded_models = [], []
+ for sub_state_dict in splited_state_dict:
+ if super().match(file_path, sub_state_dict):
+                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorFromHuggingfaceFolder:
+ def __init__(self, model_loader_configs=[]):
+ self.architecture_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
+ self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
+
+
+ def match(self, file_path="", state_dict={}):
+ if not isinstance(file_path, str) or os.path.isfile(file_path):
+ return False
+ file_list = os.listdir(file_path)
+ if "config.json" not in file_list:
+ return False
+ with open(os.path.join(file_path, "config.json"), "r") as f:
+ config = json.load(f)
+ if "architectures" not in config and "_class_name" not in config:
+ return False
+ return True
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
+ with open(os.path.join(file_path, "config.json"), "r") as f:
+ config = json.load(f)
+ loaded_model_names, loaded_models = [], []
+ architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
+ for architecture in architectures:
+ huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
+ if redirected_architecture is not None:
+ architecture = redirected_architecture
+ model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
+ loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelDetectorFromPatchedSingleFile:
+ def __init__(self, model_loader_configs=[]):
+ self.keys_hash_with_shape_dict = {}
+ for metadata in model_loader_configs:
+ self.add_model_metadata(*metadata)
+
+
+ def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
+ self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
+
+
+ def match(self, file_path="", state_dict={}):
+ if not isinstance(file_path, str) or os.path.isdir(file_path):
+ return False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ return True
+ return False
+
+
+ def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+
+ # Load models with strict matching
+ loaded_model_names, loaded_models = [], []
+ keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
+ if keys_hash_with_shape in self.keys_hash_with_shape_dict:
+ model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
+ loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
+ state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
+ loaded_model_names += loaded_model_names_
+ loaded_models += loaded_models_
+ return loaded_model_names, loaded_models
+
+
+
+class ModelManager:
+ def __init__(
+ self,
+ torch_dtype=torch.float16,
+ device="cuda",
+ model_id_list: List[Preset_model_id] = [],
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
+ file_path_list: List[str] = [],
+ ):
+ self.torch_dtype = torch_dtype
+ self.device = device
+ self.model = []
+ self.model_path = []
+ self.model_name = []
+ downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
+ self.model_detector = [
+ ModelDetectorFromSingleFile(model_loader_configs),
+ ModelDetectorFromSplitedSingleFile(model_loader_configs),
+ ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
+ ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
+ ]
+ self.load_models(downloaded_files + file_path_list)
+
+
+ def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
+ print(f"Loading models from file: {file_path}")
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+
+
+ def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
+ print(f"Loading models from folder: {file_path}")
+ model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+
+
+ def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
+ print(f"Loading patch models from file: {file_path}")
+ model_names, models = load_patch_model_from_single_file(
+ state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following patched models are loaded: {model_names}.")
+
+
+ def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
+ if isinstance(file_path, list):
+ for file_path_ in file_path:
+ self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
+ else:
+ print(f"Loading LoRA models from file: {file_path}")
+ is_loaded = False
+ if len(state_dict) == 0:
+ state_dict = load_state_dict(file_path)
+ for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
+ for lora in get_lora_loaders():
+ match_results = lora.match(model, state_dict)
+ if match_results is not None:
+ print(f" Adding LoRA to {model_name} ({model_path}).")
+ lora_prefix, model_resource = match_results
+ lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
+ is_loaded = True
+ break
+ if not is_loaded:
+ print(f" Cannot load LoRA: {file_path}")
+
+
+ def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
+ print(f"Loading models from: {file_path}")
+ if device is None: device = self.device
+ if torch_dtype is None: torch_dtype = self.torch_dtype
+ if isinstance(file_path, list):
+ state_dict = {}
+ for path in file_path:
+ state_dict.update(load_state_dict(path))
+ elif os.path.isfile(file_path):
+ state_dict = load_state_dict(file_path)
+ else:
+ state_dict = None
+ for model_detector in self.model_detector:
+ if model_detector.match(file_path, state_dict):
+ model_names, models = model_detector.load(
+ file_path, state_dict,
+ device=device, torch_dtype=torch_dtype,
+ allowed_model_names=model_names, model_manager=self
+ )
+ for model_name, model in zip(model_names, models):
+ self.model.append(model)
+ self.model_path.append(file_path)
+ self.model_name.append(model_name)
+ print(f" The following models are loaded: {model_names}.")
+ break
+ else:
+ print(f" We cannot detect the model type. No models are loaded.")
+
+
+ def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
+ for file_path in file_path_list:
+ self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
+
+
+ def fetch_model(self, model_name, file_path=None, require_model_path=False):
+ fetched_models = []
+ fetched_model_paths = []
+ for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
+ if file_path is not None and file_path != model_path:
+ continue
+ if model_name == model_name_:
+ fetched_models.append(model)
+ fetched_model_paths.append(model_path)
+ if len(fetched_models) == 0:
+ print(f"No {model_name} models available.")
+ return None
+ if len(fetched_models) == 1:
+ print(f"Using {model_name} from {fetched_model_paths[0]}.")
+ else:
+ print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
+ if require_model_path:
+ return fetched_models[0], fetched_model_paths[0]
+ else:
+ return fetched_models[0]
+
+
+ def to(self, device):
+ for model in self.model:
+ model.to(device)
+
diff --git a/PusaV1/diffsynth/models/omnigen.py b/PusaV1/diffsynth/models/omnigen.py
new file mode 100644
index 0000000000000000000000000000000000000000..571d6c0e71e5bb28cd0c4e56a2d0437dd82be4c0
--- /dev/null
+++ b/PusaV1/diffsynth/models/omnigen.py
@@ -0,0 +1,803 @@
+# The code is revised from DiT
+import os
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+from safetensors.torch import load_file
+from typing import List, Optional, Tuple, Union
+import torch.utils.checkpoint
+from huggingface_hub import snapshot_download
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers import Phi3Config, Phi3Model
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Phi3Transformer(Phi3Model):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
+ We only modified the attention mask
+ Args:
+ config: Phi3Config
+ """
+ def prefetch_layer(self, layer_idx: int, device: torch.device):
+ "Starts prefetching the next layer cache"
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ for name, param in self.layers[layer_idx].named_parameters():
+ param.data = param.data.to(device, non_blocking=True)
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ prev_layer_idx = layer_idx - 1
+ for name, param in self.layers[prev_layer_idx].named_parameters():
+ param.data = param.data.to("cpu", non_blocking=True)
+
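+    # Layer-by-layer CPU offload: after running a layer, evict the previous layer's parameters to the
+    # CPU and prefetch the next layer's parameters on a side CUDA stream, so only a couple of decoder
+    # layers reside on the GPU at any time.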
+ def get_offlaod_layer(self, layer_idx: int, device: torch.device):
+ # init stream
+ if not hasattr(self, "prefetch_stream"):
+ self.prefetch_stream = torch.cuda.Stream()
+
+ # delete previous layer
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+
+ # make sure the current layer is ready
+        self.prefetch_stream.synchronize()
+
+ # load next layer
+ self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ offload_model: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ # if inputs_embeds is None:
+ # inputs_embeds = self.embed_tokens(input_ids)
+
+ # if cache_position is None:
+ # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ # cache_position = torch.arange(
+ # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ # )
+ # if position_ids is None:
+ # position_ids = cache_position.unsqueeze(0)
+
+ if attention_mask is not None and attention_mask.dim() == 3:
+ dtype = inputs_embeds.dtype
+ min_dtype = torch.finfo(dtype).min
+ attention_mask = (1 - attention_mask) * min_dtype
+ attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
+ else:
+            raise ValueError("attention_mask must be provided as a 3-D tensor of shape (batch, query_len, key_len)")
+ # causal_mask = self._update_causal_mask(
+ # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ # )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ layer_idx = -1
+ for decoder_layer in self.layers:
+ layer_idx += 1
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ if offload_model and not self.training:
+                    self.get_offload_layer(layer_idx, device=inputs_embeds.device)
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
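+# Hedged note on the offloading hooks above: with `offload_model=True` at
+# inference time, each decoder layer's parameters are prefetched to the GPU on
+# a side CUDA stream just before they are needed, while the previous layer is
+# moved back to the CPU, trading throughput for peak GPU memory. A minimal
+# sketch of the same pattern outside this class (illustrative only; the device
+# string and the `layer` variable are assumptions):
+#   >>> stream = torch.cuda.Stream()
+#   >>> with torch.cuda.stream(stream):
+#   ...     for p in layer.parameters():
+#   ...         p.data = p.data.to("cuda", non_blocking=True)
+#   >>> stream.synchronize()  # wait for the prefetch before running the layer
+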
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, dtype=torch.float32):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
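+# Hedged shape check for the sinusoidal embedding above (comments only; the
+# embedding width 256 is an arbitrary illustration):
+#   >>> t = torch.tensor([0.0, 500.0, 999.0])
+#   >>> TimestepEmbedder.timestep_embedding(t, dim=256).shape
+#   torch.Size([3, 256])
+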
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
+ """
+    grid_size: int of the grid height and width.
+    return: pos_embed: [grid_size*grid_size, embed_dim] or
+        [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
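+# Hedged shape check for the positional-embedding helpers above (numbers are
+# illustrative): an 8x8 grid with a 64-dimensional embedding yields one row per
+# grid position.
+#   >>> get_2d_sincos_pos_embed(embed_dim=64, grid_size=8).shape
+#   (64, 64)
+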
+
+class PatchEmbedMR(nn.Module):
+ """ 2D Image to Patch Embedding
+ """
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_chans: int = 4,
+ embed_dim: int = 768,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+ def forward(self, x):
+ x = self.proj(x)
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
+ return x
+
+
+class OmniGenOriginalModel(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+ def __init__(
+ self,
+ transformer_config: Phi3Config,
+ patch_size=2,
+ in_channels=4,
+ pe_interpolation: float = 1.0,
+ pos_embed_max_size: int = 192,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.patch_size = patch_size
+ self.pos_embed_max_size = pos_embed_max_size
+
+ hidden_size = transformer_config.hidden_size
+
+ self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
+ self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
+
+ self.time_token = TimestepEmbedder(hidden_size)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ self.pe_interpolation = pe_interpolation
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
+
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
+
+ self.initialize_weights()
+
+ self.llm = Phi3Transformer(config=transformer_config)
+ self.llm.config.use_cache = False
+
+ @classmethod
+ def from_pretrained(cls, model_name):
+ if not os.path.exists(model_name):
+ cache_folder = os.getenv('HF_HUB_CACHE')
+ model_name = snapshot_download(repo_id=model_name,
+ cache_dir=cache_folder,
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+ config = Phi3Config.from_pretrained(model_name)
+ model = cls(config)
+ if os.path.exists(os.path.join(model_name, 'model.safetensors')):
+ print("Loading safetensors")
+ ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
+ else:
+ ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
+ model.load_state_dict(ckpt)
+ return model
+
+ def initialize_weights(self):
+ assert not hasattr(self, "llama")
+
+ # Initialize transformer layers:
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ w = self.input_x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+        nn.init.constant_(self.input_x_embedder.proj.bias, 0)
+
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+ nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def unpatchify(self, x, h, w):
+ """
+ x: (N, T, patch_size**2 * C)
+        imgs: (N, C, H, W)
+ """
+ c = self.out_channels
+
+ x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h, w))
+ return imgs
+
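+    # Hedged shape note for unpatchify above (illustrative numbers): with
+    # patch_size = 2 and out_channels = 4, a token tensor of shape
+    # (N, (h/2)*(w/2), 2*2*4) is folded back to a latent of shape (N, 4, h, w);
+    # the einsum simply undoes the patch flattening performed by the embedder.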
+
+ def cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ # print(top, top + height, left, left + width, spatial_pos_embed.size())
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+
+ def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
+ if isinstance(latents, list):
+ return_list = False
+ if padding_latent is None:
+ padding_latent = [None] * len(latents)
+ return_list = True
+ patched_latents, num_tokens, shapes = [], [], []
+ for latent, padding in zip(latents, padding_latent):
+ height, width = latent.shape[-2:]
+ if is_input_images:
+ latent = self.input_x_embedder(latent)
+ else:
+ latent = self.x_embedder(latent)
+ pos_embed = self.cropped_pos_embed(height, width)
+ latent = latent + pos_embed
+ if padding is not None:
+ latent = torch.cat([latent, padding], dim=-2)
+ patched_latents.append(latent)
+
+ num_tokens.append(pos_embed.size(1))
+ shapes.append([height, width])
+ if not return_list:
+ latents = torch.cat(patched_latents, dim=0)
+ else:
+ latents = patched_latents
+ else:
+ height, width = latents.shape[-2:]
+ if is_input_images:
+ latents = self.input_x_embedder(latents)
+ else:
+ latents = self.x_embedder(latents)
+ pos_embed = self.cropped_pos_embed(height, width)
+ latents = latents + pos_embed
+ num_tokens = latents.size(1)
+ shapes = [height, width]
+ return latents, num_tokens, shapes
+
+
+ def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
+ """
+
+ """
+ input_is_list = isinstance(x, list)
+ x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
+ time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
+
+ if input_img_latents is not None:
+ input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
+ if input_ids is not None:
+ condition_embeds = self.llm.embed_tokens(input_ids).clone()
+ input_img_inx = 0
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
+ input_img_inx += 1
+ if input_img_latents is not None:
+ assert input_img_inx == len(input_latents)
+
+ input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
+ else:
+ input_emb = torch.cat([time_token, x], dim=1)
+ output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
+ output, past_key_values = output.last_hidden_state, output.past_key_values
+ if input_is_list:
+ image_embedding = output[:, -max(num_tokens):]
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
+ x = self.final_layer(image_embedding, time_emb)
+ latents = []
+ for i in range(x.size(0)):
+ latent = x[i:i+1, :num_tokens[i]]
+ latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
+ latents.append(latent)
+ else:
+ image_embedding = output[:, -num_tokens:]
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
+ x = self.final_layer(image_embedding, time_emb)
+ latents = self.unpatchify(x, shapes[0], shapes[1])
+
+ if return_past_key_values:
+ return latents, past_key_values
+ return latents
+
+ @torch.no_grad()
+ def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
+ self.llm.config.use_cache = use_kv_cache
+ model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
+ if use_img_cfg:
+ cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
+ model_out = [cond, cond, cond]
+ else:
+ cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
+ cond = uncond + cfg_scale * (cond - uncond)
+ model_out = [cond, cond]
+
+ return torch.cat(model_out, dim=0), past_key_values
+
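+    # Hedged note on the guidance combination above: with image conditioning the
+    # three branches (full condition, unconditional, image-only condition) are
+    # merged as
+    #   out = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
+    # i.e. classifier-free guidance applied twice, first toward the image
+    # condition and then toward the full condition; with both scales set to 1
+    # this reduces to the fully conditioned prediction.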
+
+ @torch.no_grad()
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
+ self.llm.config.use_cache = use_kv_cache
+ if past_key_values is None:
+ past_key_values = [None] * len(attention_mask)
+
+ x = torch.split(x, len(x) // len(attention_mask), dim=0)
+ timestep = timestep.to(x[0].dtype)
+ timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
+
+        model_out, new_past_key_values = [], []
+ for i in range(len(input_ids)):
+            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
+ model_out.append(temp_out)
+            new_past_key_values.append(temp_past_key_values)
+
+ if len(model_out) == 3:
+ cond, uncond, img_cond = model_out
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
+ model_out = [cond, cond, cond]
+ elif len(model_out) == 2:
+ cond, uncond = model_out
+ cond = uncond + cfg_scale * (cond - uncond)
+ model_out = [cond, cond]
+ else:
+ return model_out[0]
+
+        return torch.cat(model_out, dim=0), new_past_key_values
+
+
+
+class OmniGenTransformer(OmniGenOriginalModel):
+ def __init__(self):
+ config = {
+ "_name_or_path": "Phi-3-vision-128k-instruct",
+ "architectures": [
+ "Phi3ForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 3072,
+ "initializer_range": 0.02,
+ "intermediate_size": 8192,
+ "max_position_embeddings": 131072,
+ "model_type": "phi3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 32,
+ "original_max_position_embeddings": 4096,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "long_factor": [
+ 1.0299999713897705,
+ 1.0499999523162842,
+ 1.0499999523162842,
+ 1.0799999237060547,
+ 1.2299998998641968,
+ 1.2299998998641968,
+ 1.2999999523162842,
+ 1.4499999284744263,
+ 1.5999999046325684,
+ 1.6499998569488525,
+ 1.8999998569488525,
+ 2.859999895095825,
+ 3.68999981880188,
+ 5.419999599456787,
+ 5.489999771118164,
+ 5.489999771118164,
+ 9.09000015258789,
+ 11.579999923706055,
+ 15.65999984741211,
+ 15.769999504089355,
+ 15.789999961853027,
+ 18.360000610351562,
+ 21.989999771118164,
+ 23.079999923706055,
+ 30.009998321533203,
+ 32.35000228881836,
+ 32.590003967285156,
+ 35.56000518798828,
+ 39.95000457763672,
+ 53.840003967285156,
+ 56.20000457763672,
+ 57.95000457763672,
+ 59.29000473022461,
+ 59.77000427246094,
+ 59.920005798339844,
+ 61.190006256103516,
+ 61.96000671386719,
+ 62.50000762939453,
+ 63.3700065612793,
+ 63.48000717163086,
+ 63.48000717163086,
+ 63.66000747680664,
+ 63.850006103515625,
+ 64.08000946044922,
+ 64.760009765625,
+ 64.80001068115234,
+ 64.81001281738281,
+ 64.81001281738281
+ ],
+ "short_factor": [
+ 1.05,
+ 1.05,
+ 1.05,
+ 1.1,
+ 1.1,
+ 1.1,
+ 1.2500000000000002,
+ 1.2500000000000002,
+ 1.4000000000000004,
+ 1.4500000000000004,
+ 1.5500000000000005,
+ 1.8500000000000008,
+ 1.9000000000000008,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.1000000000000005,
+ 2.1000000000000005,
+ 2.2,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3999999999999995,
+ 2.3999999999999995,
+ 2.6499999999999986,
+ 2.6999999999999984,
+ 2.8999999999999977,
+ 2.9499999999999975,
+ 3.049999999999997,
+ 3.049999999999997,
+ 3.049999999999997
+ ],
+ "type": "su"
+ },
+ "rope_theta": 10000.0,
+ "sliding_window": 131072,
+ "tie_word_embeddings": False,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.38.1",
+ "use_cache": True,
+ "vocab_size": 32064,
+ "_attn_implementation": "sdpa"
+ }
+ config = Phi3Config(**config)
+ super().__init__(config)
+
+
+ def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
+ input_is_list = isinstance(x, list)
+ x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
+ time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
+
+ if input_img_latents is not None:
+ input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
+ if input_ids is not None:
+ condition_embeds = self.llm.embed_tokens(input_ids).clone()
+ input_img_inx = 0
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
+ input_img_inx += 1
+ if input_img_latents is not None:
+ assert input_img_inx == len(input_latents)
+
+ input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
+ else:
+ input_emb = torch.cat([time_token, x], dim=1)
+ output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
+ output, past_key_values = output.last_hidden_state, output.past_key_values
+ if input_is_list:
+ image_embedding = output[:, -max(num_tokens):]
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
+ x = self.final_layer(image_embedding, time_emb)
+ latents = []
+ for i in range(x.size(0)):
+ latent = x[i:i+1, :num_tokens[i]]
+ latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
+ latents.append(latent)
+ else:
+ image_embedding = output[:, -num_tokens:]
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
+ x = self.final_layer(image_embedding, time_emb)
+ latents = self.unpatchify(x, shapes[0], shapes[1])
+
+ if return_past_key_values:
+ return latents, past_key_values
+ return latents
+
+
+ @torch.no_grad()
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
+ self.llm.config.use_cache = use_kv_cache
+ if past_key_values is None:
+ past_key_values = [None] * len(attention_mask)
+
+ x = torch.split(x, len(x) // len(attention_mask), dim=0)
+ timestep = timestep.to(x[0].dtype)
+ timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
+
+        model_out, new_past_key_values = [], []
+ for i in range(len(input_ids)):
+            temp_out, temp_past_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
+ model_out.append(temp_out)
+            new_past_key_values.append(temp_past_key_values)
+
+ if len(model_out) == 3:
+ cond, uncond, img_cond = model_out
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
+ model_out = [cond, cond, cond]
+ elif len(model_out) == 2:
+ cond, uncond = model_out
+ cond = uncond + cfg_scale * (cond - uncond)
+ model_out = [cond, cond]
+ else:
+ return model_out[0]
+
+        return torch.cat(model_out, dim=0), new_past_key_values
+
+
+ @staticmethod
+ def state_dict_converter():
+ return OmniGenTransformerStateDictConverter()
+
+
+
+class OmniGenTransformerStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict
+
+ def from_civitai(self, state_dict):
+ return state_dict
diff --git a/PusaV1/diffsynth/models/sd3_dit.py b/PusaV1/diffsynth/models/sd3_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..60e6be4a805f29c9d501fdece10a805c75e4662d
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd3_dit.py
@@ -0,0 +1,551 @@
+import torch
+from einops import rearrange
+from .svd_unet import TemporalTimesteps
+from .tiler import TileWorker
+
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim, eps, elementwise_affine=True):
+ super().__init__()
+ self.eps = eps
+ if elementwise_affine:
+ self.weight = torch.nn.Parameter(torch.ones((dim,)))
+ else:
+ self.weight = None
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+ hidden_states = hidden_states.to(input_dtype)
+ if self.weight is not None:
+ hidden_states = hidden_states * self.weight
+ return hidden_states
+
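+# Hedged note: RMSNorm above rescales by the root mean square of the last
+# dimension, x / sqrt(mean(x**2) + eps), optionally multiplied by a learned
+# weight; unlike LayerNorm it does not subtract the mean. Illustrative check:
+#   >>> norm = RMSNorm(dim=64, eps=1e-6)
+#   >>> norm(torch.randn(2, 10, 64)).shape
+#   torch.Size([2, 10, 64])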
+
+
+class PatchEmbed(torch.nn.Module):
+ def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
+ super().__init__()
+ self.pos_embed_max_size = pos_embed_max_size
+ self.patch_size = patch_size
+
+ self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
+ self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
+
+ def cropped_pos_embed(self, height, width):
+ height = height // self.patch_size
+ width = width // self.patch_size
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
+ return spatial_pos_embed
+
+ def forward(self, latent):
+ height, width = latent.shape[-2:]
+ latent = self.proj(latent)
+ latent = latent.flatten(2).transpose(1, 2)
+ pos_embed = self.cropped_pos_embed(height, width)
+ return latent + pos_embed
+
+
+
+class TimestepEmbeddings(torch.nn.Module):
+ def __init__(self, dim_in, dim_out, computation_device=None):
+ super().__init__()
+ self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
+ self.timestep_embedder = torch.nn.Sequential(
+ torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
+ )
+
+ def forward(self, timestep, dtype):
+ time_emb = self.time_proj(timestep).to(dtype)
+ time_emb = self.timestep_embedder(time_emb)
+ return time_emb
+
+
+
+class AdaLayerNorm(torch.nn.Module):
+ def __init__(self, dim, single=False, dual=False):
+ super().__init__()
+ self.single = single
+ self.dual = dual
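+        # The nested indexing below selects the number of modulation tensors:
+        # 6 by default, 2 when `single`, 9 when `dual`.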
+ self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, x, emb):
+ emb = self.linear(torch.nn.functional.silu(emb))
+ if self.single:
+ scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+ elif self.dual:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
+ norm_x = self.norm(x)
+ x = norm_x * (1 + scale_msa) + shift_msa
+ norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
+ else:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+
+class JointAttention(torch.nn.Module):
+ def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.only_out_a = only_out_a
+
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
+ self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
+
+ self.a_to_out = torch.nn.Linear(dim_a, dim_a)
+ if not only_out_a:
+ self.b_to_out = torch.nn.Linear(dim_b, dim_b)
+
+ if use_rms_norm:
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
+ self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
+ self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
+ else:
+ self.norm_q_a = None
+ self.norm_k_a = None
+ self.norm_q_b = None
+ self.norm_k_b = None
+
+
+ def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
+ batch_size = hidden_states.shape[0]
+ qkv = to_qkv(hidden_states)
+ qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+ q, k, v = qkv.chunk(3, dim=1)
+ if norm_q is not None:
+ q = norm_q(q)
+ if norm_k is not None:
+ k = norm_k(k)
+ return q, k, v
+
+
+ def forward(self, hidden_states_a, hidden_states_b):
+ batch_size = hidden_states_a.shape[0]
+
+ qa, ka, va = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
+ qb, kb, vb = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
+ q = torch.concat([qa, qb], dim=2)
+ k = torch.concat([ka, kb], dim=2)
+ v = torch.concat([va, vb], dim=2)
+
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ hidden_states = hidden_states.to(q.dtype)
+ hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
+ hidden_states_a = self.a_to_out(hidden_states_a)
+ if self.only_out_a:
+ return hidden_states_a
+ else:
+ hidden_states_b = self.b_to_out(hidden_states_b)
+ return hidden_states_a, hidden_states_b
+
+
+
+class SingleAttention(torch.nn.Module):
+ def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
+ self.a_to_out = torch.nn.Linear(dim_a, dim_a)
+
+ if use_rms_norm:
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
+ else:
+ self.norm_q_a = None
+ self.norm_k_a = None
+
+
+ def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
+ batch_size = hidden_states.shape[0]
+ qkv = to_qkv(hidden_states)
+ qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
+ q, k, v = qkv.chunk(3, dim=1)
+ if norm_q is not None:
+ q = norm_q(q)
+ if norm_k is not None:
+ k = norm_k(k)
+ return q, k, v
+
+
+ def forward(self, hidden_states_a):
+ batch_size = hidden_states_a.shape[0]
+ q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
+
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
+ hidden_states = hidden_states.to(q.dtype)
+ hidden_states = self.a_to_out(hidden_states)
+ return hidden_states
+
+
+
+class DualTransformerBlock(torch.nn.Module):
+ def __init__(self, dim, num_attention_heads, use_rms_norm=False):
+ super().__init__()
+ self.norm1_a = AdaLayerNorm(dim, dual=True)
+ self.norm1_b = AdaLayerNorm(dim)
+
+ self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
+ self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
+
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_a = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+ self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_b = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+
+ def forward(self, hidden_states_a, hidden_states_b, temb):
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
+ norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
+
+ # Attention
+ attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+ # Part A
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+ hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+ # Part B
+ hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
+ norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
+ hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
+
+ return hidden_states_a, hidden_states_b
+
+
+
+class JointTransformerBlock(torch.nn.Module):
+ def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
+ super().__init__()
+ self.norm1_a = AdaLayerNorm(dim, dual=dual)
+ self.norm1_b = AdaLayerNorm(dim)
+
+ self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
+ if dual:
+ self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
+
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_a = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+ self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_b = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+
+ def forward(self, hidden_states_a, hidden_states_b, temb):
+ if self.norm1_a.dual:
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
+ else:
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+ norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
+
+ # Attention
+ attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+ # Part A
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+ if self.norm1_a.dual:
+ hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+ # Part B
+ hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
+ norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
+ hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
+
+ return hidden_states_a, hidden_states_b
+
+
+
+class JointTransformerFinalBlock(torch.nn.Module):
+ def __init__(self, dim, num_attention_heads, use_rms_norm=False):
+ super().__init__()
+ self.norm1_a = AdaLayerNorm(dim)
+ self.norm1_b = AdaLayerNorm(dim, single=True)
+
+ self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, use_rms_norm=use_rms_norm)
+
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_a = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim*4),
+ torch.nn.GELU(approximate="tanh"),
+ torch.nn.Linear(dim*4, dim)
+ )
+
+
+ def forward(self, hidden_states_a, hidden_states_b, temb):
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
+ norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
+
+ # Attention
+ attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
+
+ # Part A
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
+
+ return hidden_states_a, hidden_states_b
+
+
+
+class SD3DiT(torch.nn.Module):
+ def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
+ super().__init__()
+ self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
+ self.time_embedder = TimestepEmbeddings(256, embed_dim)
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
+ self.context_embedder = torch.nn.Linear(4096, embed_dim)
+ self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
+ + [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)]
+ + [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
+ self.norm_out = AdaLayerNorm(embed_dim, single=True)
+ self.proj_out = torch.nn.Linear(embed_dim, 64)
+
+ def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
+ # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
+ hidden_states,
+ tile_size,
+ tile_stride,
+ tile_device=hidden_states.device,
+ tile_dtype=hidden_states.dtype
+ )
+ return hidden_states
+
+ def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
+ if tiled:
+ return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
+ conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
+ prompt_emb = self.context_embedder(prompt_emb)
+
+ height, width = hidden_states.shape[-2:]
+ hidden_states = self.pos_embedder(hidden_states)
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ for block in self.blocks:
+ if self.training and use_gradient_checkpointing:
+ hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states, prompt_emb, conditioning,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
+
+ hidden_states = self.norm_out(hidden_states, conditioning)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SD3DiTStateDictConverter()
+
+
+
+class SD3DiTStateDictConverter:
+ def __init__(self):
+ pass
+
+ def infer_architecture(self, state_dict):
+ embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]
+ num_layers = 100
+ while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
+ num_layers -= 1
+ use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
+ num_dual_blocks = 0
+ while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
+ num_dual_blocks += 1
+ pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
+ return {
+ "embed_dim": embed_dim,
+ "num_layers": num_layers,
+ "use_rms_norm": use_rms_norm,
+ "num_dual_blocks": num_dual_blocks,
+ "pos_embed_max_size": pos_embed_max_size
+ }
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "context_embedder": "context_embedder",
+ "pos_embed.pos_embed": "pos_embedder.pos_embed",
+ "pos_embed.proj": "pos_embedder.proj",
+ "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
+ "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
+ "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
+ "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
+ "norm_out.linear": "norm_out.linear",
+ "proj_out": "proj_out",
+
+ "norm1.linear": "norm1_a.linear",
+ "norm1_context.linear": "norm1_b.linear",
+ "attn.to_q": "attn.a_to_q",
+ "attn.to_k": "attn.a_to_k",
+ "attn.to_v": "attn.a_to_v",
+ "attn.to_out.0": "attn.a_to_out",
+ "attn.add_q_proj": "attn.b_to_q",
+ "attn.add_k_proj": "attn.b_to_k",
+ "attn.add_v_proj": "attn.b_to_v",
+ "attn.to_add_out": "attn.b_to_out",
+ "ff.net.0.proj": "ff_a.0",
+ "ff.net.2": "ff_a.2",
+ "ff_context.net.0.proj": "ff_b.0",
+ "ff_context.net.2": "ff_b.2",
+
+ "attn.norm_q": "attn.norm_q_a",
+ "attn.norm_k": "attn.norm_k_a",
+ "attn.norm_added_q": "attn.norm_q_b",
+ "attn.norm_added_k": "attn.norm_k_b",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ if name == "pos_embed.pos_embed":
+ param = param.reshape((1, 192, 192, param.shape[-1]))
+ state_dict_[rename_dict[name]] = param
+ elif name.endswith(".weight") or name.endswith(".bias"):
+ suffix = ".weight" if name.endswith(".weight") else ".bias"
+ prefix = name[:-len(suffix)]
+ if prefix in rename_dict:
+ state_dict_[rename_dict[prefix] + suffix] = param
+ elif prefix.startswith("transformer_blocks."):
+ names = prefix.split(".")
+ names[0] = "blocks"
+ middle = ".".join(names[2:])
+ if middle in rename_dict:
+ name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
+ state_dict_[name_] = param
+ merged_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
+ for key in merged_keys:
+ param = torch.concat([
+                state_dict_[key],
+ state_dict_[key.replace("to_q", "to_k")],
+ state_dict_[key.replace("to_q", "to_v")],
+ ], dim=0)
+ name = key.replace("to_q", "to_qkv")
+            state_dict_.pop(key)
+ state_dict_.pop(key.replace("to_q", "to_k"))
+ state_dict_.pop(key.replace("to_q", "to_v"))
+ state_dict_[name] = param
+ return state_dict_, self.infer_architecture(state_dict_)
+
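+    # Hedged note on the merge step above: diffusers checkpoints store separate
+    # to_q / to_k / to_v projections, while the attention blocks in this file
+    # expect fused a_to_qkv / b_to_qkv parameters, so the three matrices are
+    # concatenated along the output dimension and the separate entries are
+    # removed from the converted state dict.
+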
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "model.diffusion_model.context_embedder.bias": "context_embedder.bias",
+ "model.diffusion_model.context_embedder.weight": "context_embedder.weight",
+ "model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
+ "model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
+
+ "model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
+ "model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
+ "model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
+ "model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
+ "model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
+ "model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
+ "model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
+ "model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
+ "model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
+ "model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
+ "model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
+
+ "model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
+ "model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
+ "model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+ "model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
+ }
+ for i in range(40):
+ rename_dict.update({
+ f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_b.linear.bias",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_b.linear.weight",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.bias": f"blocks.{i}.attn.b_to_out.bias",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.weight": f"blocks.{i}.attn.b_to_out.weight",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.bias": [f'blocks.{i}.attn.b_to_q.bias', f'blocks.{i}.attn.b_to_k.bias', f'blocks.{i}.attn.b_to_v.bias'],
+ f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.weight": [f'blocks.{i}.attn.b_to_q.weight', f'blocks.{i}.attn.b_to_k.weight', f'blocks.{i}.attn.b_to_v.weight'],
+ f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.bias": f"blocks.{i}.ff_b.0.bias",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.weight": f"blocks.{i}.ff_b.0.weight",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.bias": f"blocks.{i}.ff_b.2.bias",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.weight": f"blocks.{i}.ff_b.2.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_a.linear.bias",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_a.linear.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.bias": f"blocks.{i}.attn.a_to_out.bias",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.weight": f"blocks.{i}.attn.a_to_out.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.bias": [f'blocks.{i}.attn.a_to_q.bias', f'blocks.{i}.attn.a_to_k.bias', f'blocks.{i}.attn.a_to_v.bias'],
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.weight": [f'blocks.{i}.attn.a_to_q.weight', f'blocks.{i}.attn.a_to_k.weight', f'blocks.{i}.attn.a_to_v.weight'],
+ f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.bias": f"blocks.{i}.ff_a.0.bias",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.weight": f"blocks.{i}.ff_a.0.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.bias": f"blocks.{i}.ff_a.2.bias",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.weight": f"blocks.{i}.ff_a.2.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_a.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
+ f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
+
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
+ f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
+ })
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "model.diffusion_model.pos_embed":
+ pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
+ param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
+ if isinstance(rename_dict[name], str):
+ state_dict_[rename_dict[name]] = param
+ else:
+ name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
+ state_dict_[name_] = param
+ extra_kwargs = self.infer_architecture(state_dict_)
+ num_layers = extra_kwargs["num_layers"]
+ for name in [
+ f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias", "norm_out.linear.weight", "norm_out.linear.bias",
+ ]:
+ param = state_dict_[name]
+ dim = param.shape[0] // 2
+ param = torch.concat([param[dim:], param[:dim]], axis=0)
+ state_dict_[name] = param
+ return state_dict_, self.infer_architecture(state_dict_)
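+
+
+# Hedged usage sketch for SD3DiT (comments only; the shapes and values below
+# are assumptions for illustration, not requirements of the model):
+#   >>> dit = SD3DiT(embed_dim=1536, num_layers=24)
+#   >>> latent = torch.randn(1, 16, 64, 64)       # 16-channel image latent
+#   >>> timestep = torch.tensor([500])
+#   >>> prompt_emb = torch.randn(1, 77, 4096)     # sequence of text embeddings
+#   >>> pooled = torch.randn(1, 2048)             # pooled text embedding
+#   >>> dit(latent, timestep, prompt_emb, pooled).shape
+#   torch.Size([1, 16, 64, 64])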
diff --git a/PusaV1/diffsynth/models/sd3_text_encoder.py b/PusaV1/diffsynth/models/sd3_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe29ca8ae99586ae197ee633a9c3d7f2c074f77
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd3_text_encoder.py
@@ -0,0 +1,1120 @@
+import torch
+from transformers import T5EncoderModel, T5Config
+from .sd_text_encoder import SDTextEncoder
+from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter
+
+
+class SD3TextEncoder1(SDTextEncoder):
+ def __init__(self, vocab_size=49408):
+ super().__init__(vocab_size=vocab_size)
+
+ def forward(self, input_ids, clip_skip=2, extra_mask=None):
+ embeds = self.token_embedding(input_ids)
+ embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+ if extra_mask is not None:
+ attn_mask[:, extra_mask[0]==0] = float("-inf")
+ for encoder_id, encoder in enumerate(self.encoders):
+ embeds = encoder(embeds, attn_mask=attn_mask)
+ if encoder_id + clip_skip == len(self.encoders):
+ hidden_states = embeds
+ embeds = self.final_layer_norm(embeds)
+ pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
+ return pooled_embeds, hidden_states
+
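+    # Hedged note on `clip_skip` above: with the default clip_skip=2 the hidden
+    # states are taken from the penultimate encoder layer, while the pooled
+    # embedding is still read from the final, layer-normalized output at the
+    # position of the highest token id (the end-of-text token in the CLIP vocabulary).
+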
+ @staticmethod
+ def state_dict_converter():
+ return SD3TextEncoder1StateDictConverter()
+
+
+
+class SD3TextEncoder2(SDXLTextEncoder2):
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def state_dict_converter():
+ return SD3TextEncoder2StateDictConverter()
+
+
+class SD3TextEncoder3(T5EncoderModel):
+ def __init__(self):
+ config = T5Config(
+ _name_or_path = ".",
+ architectures = ["T5EncoderModel"],
+ classifier_dropout = 0.0,
+ d_ff = 10240,
+ d_kv = 64,
+ d_model = 4096,
+ decoder_start_token_id = 0,
+ dense_act_fn = "gelu_new",
+ dropout_rate = 0.1,
+ eos_token_id = 1,
+ feed_forward_proj = "gated-gelu",
+ initializer_factor = 1.0,
+ is_encoder_decoder = True,
+ is_gated_act = True,
+ layer_norm_epsilon = 1e-06,
+ model_type = "t5",
+ num_decoder_layers = 24,
+ num_heads = 64,
+ num_layers = 24,
+ output_past = True,
+ pad_token_id = 0,
+ relative_attention_max_distance = 128,
+ relative_attention_num_buckets = 32,
+ tie_word_embeddings = False,
+ torch_dtype = torch.float16,
+ transformers_version = "4.41.2",
+ use_cache = True,
+ vocab_size = 32128
+ )
+ super().__init__(config)
+ self.eval()
+
+ def forward(self, input_ids):
+ outputs = super().forward(input_ids=input_ids)
+ prompt_emb = outputs.last_hidden_state
+ return prompt_emb
+
+ @staticmethod
+ def state_dict_converter():
+ return SD3TextEncoder3StateDictConverter()
+
+
+
+class SD3TextEncoder1StateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias",
+ }
+ attn_rename_dict = {
+ "self_attn.q_proj": "attn.to_q",
+ "self_attn.k_proj": "attn.to_k",
+ "self_attn.v_proj": "attn.to_v",
+ "self_attn.out_proj": "attn.to_out",
+ "layer_norm1": "layer_norm1",
+ "layer_norm2": "layer_norm2",
+ "mlp.fc1": "fc1",
+ "mlp.fc2": "fc2",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ elif name.startswith("text_model.encoder.layers."):
+ param = state_dict[name]
+ names = name.split(".")
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+ state_dict_[name_] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
+ "text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
+ "text_encoders.clip_l.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
+ "text_encoders.clip_l.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
+ }
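+ # Keys may appear either with the full "text_encoders.clip_l.transformer."
+ # prefix or with that prefix already stripped; both forms are handled below,
+ # and the position embedding is reshaped to (1, sequence_length, dim).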
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ elif ("text_encoders.clip_l.transformer." + name) in rename_dict:
+ param = state_dict[name]
+ if name == "text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict["text_encoders.clip_l.transformer." + name]] = param
+ return state_dict_
+
+
+class SD3TextEncoder2StateDictConverter(SDXLTextEncoder2StateDictConverter):
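+ # Converter for the second CLIP text encoder ("clip_g"). Diffusers-format
+ # weights are handled by the inherited SDXL converter (from_diffusers simply
+ # delegates to it); only the civitai-style key mapping below is SD3-specific.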
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return super().from_diffusers(state_dict)
+
+ def from_civitai(self, state_dict):
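+ # Same renaming scheme as the "clip_l" converter above, applied to the
+ # "text_encoders.clip_g.transformer.text_model.*" keys.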
+ rename_dict = {
+ "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
+ "text_encoders.clip_g.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.layer_norm1.bias": "encoders.12.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.layer_norm1.weight": "encoders.12.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.layer_norm2.bias": "encoders.12.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.layer_norm2.weight": "encoders.12.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.mlp.fc1.bias": "encoders.12.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.mlp.fc1.weight": "encoders.12.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.mlp.fc2.bias": "encoders.12.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.mlp.fc2.weight": "encoders.12.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.k_proj.bias": "encoders.12.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.k_proj.weight": "encoders.12.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.out_proj.bias": "encoders.12.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.out_proj.weight": "encoders.12.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.q_proj.bias": "encoders.12.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.q_proj.weight": "encoders.12.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.v_proj.bias": "encoders.12.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.12.self_attn.v_proj.weight": "encoders.12.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.layer_norm1.bias": "encoders.13.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.layer_norm1.weight": "encoders.13.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.layer_norm2.bias": "encoders.13.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.layer_norm2.weight": "encoders.13.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.mlp.fc1.bias": "encoders.13.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.mlp.fc1.weight": "encoders.13.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.mlp.fc2.bias": "encoders.13.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.mlp.fc2.weight": "encoders.13.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.k_proj.bias": "encoders.13.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.k_proj.weight": "encoders.13.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.out_proj.bias": "encoders.13.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.out_proj.weight": "encoders.13.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.q_proj.bias": "encoders.13.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.q_proj.weight": "encoders.13.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.v_proj.bias": "encoders.13.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.13.self_attn.v_proj.weight": "encoders.13.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.layer_norm1.bias": "encoders.14.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.layer_norm1.weight": "encoders.14.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.layer_norm2.bias": "encoders.14.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.layer_norm2.weight": "encoders.14.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.mlp.fc1.bias": "encoders.14.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.mlp.fc1.weight": "encoders.14.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.mlp.fc2.bias": "encoders.14.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.mlp.fc2.weight": "encoders.14.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.k_proj.bias": "encoders.14.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.k_proj.weight": "encoders.14.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.out_proj.bias": "encoders.14.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.out_proj.weight": "encoders.14.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.q_proj.bias": "encoders.14.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.q_proj.weight": "encoders.14.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.v_proj.bias": "encoders.14.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.14.self_attn.v_proj.weight": "encoders.14.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.layer_norm1.bias": "encoders.15.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.layer_norm1.weight": "encoders.15.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.layer_norm2.bias": "encoders.15.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.layer_norm2.weight": "encoders.15.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.mlp.fc1.bias": "encoders.15.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.mlp.fc1.weight": "encoders.15.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.mlp.fc2.bias": "encoders.15.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.mlp.fc2.weight": "encoders.15.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.k_proj.bias": "encoders.15.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.k_proj.weight": "encoders.15.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.out_proj.bias": "encoders.15.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.out_proj.weight": "encoders.15.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.q_proj.bias": "encoders.15.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.q_proj.weight": "encoders.15.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.v_proj.bias": "encoders.15.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.15.self_attn.v_proj.weight": "encoders.15.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.layer_norm1.bias": "encoders.16.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.layer_norm1.weight": "encoders.16.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.layer_norm2.bias": "encoders.16.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.layer_norm2.weight": "encoders.16.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.mlp.fc1.bias": "encoders.16.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.mlp.fc1.weight": "encoders.16.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.mlp.fc2.bias": "encoders.16.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.mlp.fc2.weight": "encoders.16.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.k_proj.bias": "encoders.16.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.k_proj.weight": "encoders.16.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.out_proj.bias": "encoders.16.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.out_proj.weight": "encoders.16.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.q_proj.bias": "encoders.16.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.q_proj.weight": "encoders.16.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.v_proj.bias": "encoders.16.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.16.self_attn.v_proj.weight": "encoders.16.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.layer_norm1.bias": "encoders.17.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.layer_norm1.weight": "encoders.17.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.layer_norm2.bias": "encoders.17.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.layer_norm2.weight": "encoders.17.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.mlp.fc1.bias": "encoders.17.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.mlp.fc1.weight": "encoders.17.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.mlp.fc2.bias": "encoders.17.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.mlp.fc2.weight": "encoders.17.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.k_proj.bias": "encoders.17.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.k_proj.weight": "encoders.17.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.out_proj.bias": "encoders.17.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.out_proj.weight": "encoders.17.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.q_proj.bias": "encoders.17.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.q_proj.weight": "encoders.17.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.v_proj.bias": "encoders.17.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.17.self_attn.v_proj.weight": "encoders.17.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.layer_norm1.bias": "encoders.18.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.layer_norm1.weight": "encoders.18.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.layer_norm2.bias": "encoders.18.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.layer_norm2.weight": "encoders.18.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.mlp.fc1.bias": "encoders.18.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.mlp.fc1.weight": "encoders.18.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.mlp.fc2.bias": "encoders.18.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.mlp.fc2.weight": "encoders.18.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.k_proj.bias": "encoders.18.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.k_proj.weight": "encoders.18.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.out_proj.bias": "encoders.18.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.out_proj.weight": "encoders.18.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.q_proj.bias": "encoders.18.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.q_proj.weight": "encoders.18.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.v_proj.bias": "encoders.18.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.18.self_attn.v_proj.weight": "encoders.18.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.layer_norm1.bias": "encoders.19.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.layer_norm1.weight": "encoders.19.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.layer_norm2.bias": "encoders.19.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.layer_norm2.weight": "encoders.19.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.mlp.fc1.bias": "encoders.19.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.mlp.fc1.weight": "encoders.19.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.mlp.fc2.bias": "encoders.19.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.mlp.fc2.weight": "encoders.19.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.k_proj.bias": "encoders.19.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.k_proj.weight": "encoders.19.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.out_proj.bias": "encoders.19.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.out_proj.weight": "encoders.19.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.q_proj.bias": "encoders.19.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.q_proj.weight": "encoders.19.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.v_proj.bias": "encoders.19.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.19.self_attn.v_proj.weight": "encoders.19.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.layer_norm1.bias": "encoders.20.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.layer_norm1.weight": "encoders.20.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.layer_norm2.bias": "encoders.20.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.layer_norm2.weight": "encoders.20.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.mlp.fc1.bias": "encoders.20.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.mlp.fc1.weight": "encoders.20.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.mlp.fc2.bias": "encoders.20.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.mlp.fc2.weight": "encoders.20.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.k_proj.bias": "encoders.20.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.k_proj.weight": "encoders.20.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.out_proj.bias": "encoders.20.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.out_proj.weight": "encoders.20.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.q_proj.bias": "encoders.20.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.q_proj.weight": "encoders.20.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.v_proj.bias": "encoders.20.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.20.self_attn.v_proj.weight": "encoders.20.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.layer_norm1.bias": "encoders.21.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.layer_norm1.weight": "encoders.21.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.layer_norm2.bias": "encoders.21.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.layer_norm2.weight": "encoders.21.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.mlp.fc1.bias": "encoders.21.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.mlp.fc1.weight": "encoders.21.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.mlp.fc2.bias": "encoders.21.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.mlp.fc2.weight": "encoders.21.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.k_proj.bias": "encoders.21.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.k_proj.weight": "encoders.21.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.out_proj.bias": "encoders.21.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.out_proj.weight": "encoders.21.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.q_proj.bias": "encoders.21.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.q_proj.weight": "encoders.21.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.v_proj.bias": "encoders.21.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.21.self_attn.v_proj.weight": "encoders.21.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.layer_norm1.bias": "encoders.22.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.layer_norm1.weight": "encoders.22.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.layer_norm2.bias": "encoders.22.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.layer_norm2.weight": "encoders.22.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.mlp.fc1.bias": "encoders.22.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.mlp.fc1.weight": "encoders.22.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.mlp.fc2.bias": "encoders.22.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.mlp.fc2.weight": "encoders.22.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.k_proj.bias": "encoders.22.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.k_proj.weight": "encoders.22.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.out_proj.bias": "encoders.22.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.out_proj.weight": "encoders.22.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.q_proj.bias": "encoders.22.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.q_proj.weight": "encoders.22.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.v_proj.bias": "encoders.22.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.22.self_attn.v_proj.weight": "encoders.22.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.layer_norm1.bias": "encoders.23.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.layer_norm1.weight": "encoders.23.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.layer_norm2.bias": "encoders.23.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.layer_norm2.weight": "encoders.23.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.mlp.fc1.bias": "encoders.23.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.mlp.fc1.weight": "encoders.23.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.mlp.fc2.bias": "encoders.23.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.mlp.fc2.weight": "encoders.23.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.k_proj.bias": "encoders.23.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.k_proj.weight": "encoders.23.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.out_proj.bias": "encoders.23.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.out_proj.weight": "encoders.23.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.q_proj.bias": "encoders.23.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.q_proj.weight": "encoders.23.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.v_proj.bias": "encoders.23.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.23.self_attn.v_proj.weight": "encoders.23.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.layer_norm1.bias": "encoders.24.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.layer_norm1.weight": "encoders.24.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.layer_norm2.bias": "encoders.24.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.layer_norm2.weight": "encoders.24.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.mlp.fc1.bias": "encoders.24.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.mlp.fc1.weight": "encoders.24.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.mlp.fc2.bias": "encoders.24.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.mlp.fc2.weight": "encoders.24.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.k_proj.bias": "encoders.24.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.k_proj.weight": "encoders.24.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.out_proj.bias": "encoders.24.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.out_proj.weight": "encoders.24.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.q_proj.bias": "encoders.24.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.q_proj.weight": "encoders.24.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.v_proj.bias": "encoders.24.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.24.self_attn.v_proj.weight": "encoders.24.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.layer_norm1.bias": "encoders.25.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.layer_norm1.weight": "encoders.25.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.layer_norm2.bias": "encoders.25.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.layer_norm2.weight": "encoders.25.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.mlp.fc1.bias": "encoders.25.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.mlp.fc1.weight": "encoders.25.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.mlp.fc2.bias": "encoders.25.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.mlp.fc2.weight": "encoders.25.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.k_proj.bias": "encoders.25.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.k_proj.weight": "encoders.25.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.out_proj.bias": "encoders.25.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.out_proj.weight": "encoders.25.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.q_proj.bias": "encoders.25.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.q_proj.weight": "encoders.25.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.v_proj.bias": "encoders.25.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.25.self_attn.v_proj.weight": "encoders.25.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.layer_norm1.bias": "encoders.26.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.layer_norm1.weight": "encoders.26.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.layer_norm2.bias": "encoders.26.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.layer_norm2.weight": "encoders.26.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.mlp.fc1.bias": "encoders.26.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.mlp.fc1.weight": "encoders.26.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.mlp.fc2.bias": "encoders.26.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.mlp.fc2.weight": "encoders.26.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.k_proj.bias": "encoders.26.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.k_proj.weight": "encoders.26.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.out_proj.bias": "encoders.26.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.out_proj.weight": "encoders.26.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.q_proj.bias": "encoders.26.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.q_proj.weight": "encoders.26.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.v_proj.bias": "encoders.26.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.26.self_attn.v_proj.weight": "encoders.26.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.layer_norm1.bias": "encoders.27.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.layer_norm1.weight": "encoders.27.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.layer_norm2.bias": "encoders.27.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.layer_norm2.weight": "encoders.27.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.mlp.fc1.bias": "encoders.27.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.mlp.fc1.weight": "encoders.27.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.mlp.fc2.bias": "encoders.27.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.mlp.fc2.weight": "encoders.27.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.k_proj.bias": "encoders.27.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.k_proj.weight": "encoders.27.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.out_proj.bias": "encoders.27.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.out_proj.weight": "encoders.27.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.q_proj.bias": "encoders.27.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.q_proj.weight": "encoders.27.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.v_proj.bias": "encoders.27.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.27.self_attn.v_proj.weight": "encoders.27.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.layer_norm1.bias": "encoders.28.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.layer_norm1.weight": "encoders.28.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.layer_norm2.bias": "encoders.28.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.layer_norm2.weight": "encoders.28.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.mlp.fc1.bias": "encoders.28.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.mlp.fc1.weight": "encoders.28.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.mlp.fc2.bias": "encoders.28.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.mlp.fc2.weight": "encoders.28.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.k_proj.bias": "encoders.28.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.k_proj.weight": "encoders.28.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.out_proj.bias": "encoders.28.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.out_proj.weight": "encoders.28.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.q_proj.bias": "encoders.28.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.q_proj.weight": "encoders.28.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.v_proj.bias": "encoders.28.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.28.self_attn.v_proj.weight": "encoders.28.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.layer_norm1.bias": "encoders.29.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.layer_norm1.weight": "encoders.29.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.layer_norm2.bias": "encoders.29.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.layer_norm2.weight": "encoders.29.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.mlp.fc1.bias": "encoders.29.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.mlp.fc1.weight": "encoders.29.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.mlp.fc2.bias": "encoders.29.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.mlp.fc2.weight": "encoders.29.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.k_proj.bias": "encoders.29.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.k_proj.weight": "encoders.29.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.out_proj.bias": "encoders.29.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.out_proj.weight": "encoders.29.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.q_proj.bias": "encoders.29.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.q_proj.weight": "encoders.29.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.v_proj.bias": "encoders.29.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.29.self_attn.v_proj.weight": "encoders.29.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.layer_norm1.bias": "encoders.30.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.layer_norm1.weight": "encoders.30.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.layer_norm2.bias": "encoders.30.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.layer_norm2.weight": "encoders.30.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.mlp.fc1.bias": "encoders.30.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.mlp.fc1.weight": "encoders.30.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.mlp.fc2.bias": "encoders.30.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.mlp.fc2.weight": "encoders.30.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.k_proj.bias": "encoders.30.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.k_proj.weight": "encoders.30.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.out_proj.bias": "encoders.30.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.out_proj.weight": "encoders.30.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.q_proj.bias": "encoders.30.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.q_proj.weight": "encoders.30.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.v_proj.bias": "encoders.30.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.30.self_attn.v_proj.weight": "encoders.30.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.layer_norm1.bias": "encoders.31.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.layer_norm1.weight": "encoders.31.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.layer_norm2.bias": "encoders.31.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.layer_norm2.weight": "encoders.31.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.mlp.fc1.bias": "encoders.31.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.mlp.fc1.weight": "encoders.31.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.mlp.fc2.bias": "encoders.31.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.mlp.fc2.weight": "encoders.31.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.k_proj.bias": "encoders.31.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.k_proj.weight": "encoders.31.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.out_proj.bias": "encoders.31.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.out_proj.weight": "encoders.31.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.q_proj.bias": "encoders.31.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.q_proj.weight": "encoders.31.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.v_proj.bias": "encoders.31.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.31.self_attn.v_proj.weight": "encoders.31.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
+ "text_encoders.clip_g.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
+ "text_encoders.clip_g.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
+ "text_encoders.clip_g.transformer.text_projection.weight": "text_projection.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ elif ("text_encoders.clip_g.transformer." + name) in rename_dict:
+ param = state_dict[name]
+ if name == "text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict["text_encoders.clip_g.transformer." + name]] = param
+ return state_dict_
+
+
+
+class SD3TextEncoder3StateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = state_dict
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ prefix = "text_encoders.t5xxl.transformer."
+ state_dict_ = {name[len(prefix):]: param for name, param in state_dict.items() if name.startswith(prefix)}
+ if len(state_dict_) > 0:
+ return self.from_diffusers(state_dict_)
+ name_list = [
+ "encoder.block.0.layer.0.SelfAttention.k.weight",
+ "encoder.block.0.layer.0.SelfAttention.o.weight",
+ "encoder.block.0.layer.0.SelfAttention.q.weight",
+ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
+ "encoder.block.0.layer.0.SelfAttention.v.weight",
+ "encoder.block.0.layer.0.layer_norm.weight",
+ "encoder.block.0.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.0.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.0.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.0.layer.1.layer_norm.weight",
+ "encoder.block.1.layer.0.SelfAttention.k.weight",
+ "encoder.block.1.layer.0.SelfAttention.o.weight",
+ "encoder.block.1.layer.0.SelfAttention.q.weight",
+ "encoder.block.1.layer.0.SelfAttention.v.weight",
+ "encoder.block.1.layer.0.layer_norm.weight",
+ "encoder.block.1.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.1.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.1.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.1.layer.1.layer_norm.weight",
+ "encoder.block.10.layer.0.SelfAttention.k.weight",
+ "encoder.block.10.layer.0.SelfAttention.o.weight",
+ "encoder.block.10.layer.0.SelfAttention.q.weight",
+ "encoder.block.10.layer.0.SelfAttention.v.weight",
+ "encoder.block.10.layer.0.layer_norm.weight",
+ "encoder.block.10.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.10.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.10.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.10.layer.1.layer_norm.weight",
+ "encoder.block.11.layer.0.SelfAttention.k.weight",
+ "encoder.block.11.layer.0.SelfAttention.o.weight",
+ "encoder.block.11.layer.0.SelfAttention.q.weight",
+ "encoder.block.11.layer.0.SelfAttention.v.weight",
+ "encoder.block.11.layer.0.layer_norm.weight",
+ "encoder.block.11.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.11.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.11.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.11.layer.1.layer_norm.weight",
+ "encoder.block.12.layer.0.SelfAttention.k.weight",
+ "encoder.block.12.layer.0.SelfAttention.o.weight",
+ "encoder.block.12.layer.0.SelfAttention.q.weight",
+ "encoder.block.12.layer.0.SelfAttention.v.weight",
+ "encoder.block.12.layer.0.layer_norm.weight",
+ "encoder.block.12.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.12.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.12.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.12.layer.1.layer_norm.weight",
+ "encoder.block.13.layer.0.SelfAttention.k.weight",
+ "encoder.block.13.layer.0.SelfAttention.o.weight",
+ "encoder.block.13.layer.0.SelfAttention.q.weight",
+ "encoder.block.13.layer.0.SelfAttention.v.weight",
+ "encoder.block.13.layer.0.layer_norm.weight",
+ "encoder.block.13.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.13.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.13.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.13.layer.1.layer_norm.weight",
+ "encoder.block.14.layer.0.SelfAttention.k.weight",
+ "encoder.block.14.layer.0.SelfAttention.o.weight",
+ "encoder.block.14.layer.0.SelfAttention.q.weight",
+ "encoder.block.14.layer.0.SelfAttention.v.weight",
+ "encoder.block.14.layer.0.layer_norm.weight",
+ "encoder.block.14.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.14.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.14.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.14.layer.1.layer_norm.weight",
+ "encoder.block.15.layer.0.SelfAttention.k.weight",
+ "encoder.block.15.layer.0.SelfAttention.o.weight",
+ "encoder.block.15.layer.0.SelfAttention.q.weight",
+ "encoder.block.15.layer.0.SelfAttention.v.weight",
+ "encoder.block.15.layer.0.layer_norm.weight",
+ "encoder.block.15.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.15.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.15.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.15.layer.1.layer_norm.weight",
+ "encoder.block.16.layer.0.SelfAttention.k.weight",
+ "encoder.block.16.layer.0.SelfAttention.o.weight",
+ "encoder.block.16.layer.0.SelfAttention.q.weight",
+ "encoder.block.16.layer.0.SelfAttention.v.weight",
+ "encoder.block.16.layer.0.layer_norm.weight",
+ "encoder.block.16.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.16.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.16.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.16.layer.1.layer_norm.weight",
+ "encoder.block.17.layer.0.SelfAttention.k.weight",
+ "encoder.block.17.layer.0.SelfAttention.o.weight",
+ "encoder.block.17.layer.0.SelfAttention.q.weight",
+ "encoder.block.17.layer.0.SelfAttention.v.weight",
+ "encoder.block.17.layer.0.layer_norm.weight",
+ "encoder.block.17.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.17.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.17.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.17.layer.1.layer_norm.weight",
+ "encoder.block.18.layer.0.SelfAttention.k.weight",
+ "encoder.block.18.layer.0.SelfAttention.o.weight",
+ "encoder.block.18.layer.0.SelfAttention.q.weight",
+ "encoder.block.18.layer.0.SelfAttention.v.weight",
+ "encoder.block.18.layer.0.layer_norm.weight",
+ "encoder.block.18.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.18.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.18.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.18.layer.1.layer_norm.weight",
+ "encoder.block.19.layer.0.SelfAttention.k.weight",
+ "encoder.block.19.layer.0.SelfAttention.o.weight",
+ "encoder.block.19.layer.0.SelfAttention.q.weight",
+ "encoder.block.19.layer.0.SelfAttention.v.weight",
+ "encoder.block.19.layer.0.layer_norm.weight",
+ "encoder.block.19.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.19.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.19.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.19.layer.1.layer_norm.weight",
+ "encoder.block.2.layer.0.SelfAttention.k.weight",
+ "encoder.block.2.layer.0.SelfAttention.o.weight",
+ "encoder.block.2.layer.0.SelfAttention.q.weight",
+ "encoder.block.2.layer.0.SelfAttention.v.weight",
+ "encoder.block.2.layer.0.layer_norm.weight",
+ "encoder.block.2.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.2.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.2.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.2.layer.1.layer_norm.weight",
+ "encoder.block.20.layer.0.SelfAttention.k.weight",
+ "encoder.block.20.layer.0.SelfAttention.o.weight",
+ "encoder.block.20.layer.0.SelfAttention.q.weight",
+ "encoder.block.20.layer.0.SelfAttention.v.weight",
+ "encoder.block.20.layer.0.layer_norm.weight",
+ "encoder.block.20.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.20.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.20.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.20.layer.1.layer_norm.weight",
+ "encoder.block.21.layer.0.SelfAttention.k.weight",
+ "encoder.block.21.layer.0.SelfAttention.o.weight",
+ "encoder.block.21.layer.0.SelfAttention.q.weight",
+ "encoder.block.21.layer.0.SelfAttention.v.weight",
+ "encoder.block.21.layer.0.layer_norm.weight",
+ "encoder.block.21.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.21.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.21.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.21.layer.1.layer_norm.weight",
+ "encoder.block.22.layer.0.SelfAttention.k.weight",
+ "encoder.block.22.layer.0.SelfAttention.o.weight",
+ "encoder.block.22.layer.0.SelfAttention.q.weight",
+ "encoder.block.22.layer.0.SelfAttention.v.weight",
+ "encoder.block.22.layer.0.layer_norm.weight",
+ "encoder.block.22.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.22.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.22.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.22.layer.1.layer_norm.weight",
+ "encoder.block.23.layer.0.SelfAttention.k.weight",
+ "encoder.block.23.layer.0.SelfAttention.o.weight",
+ "encoder.block.23.layer.0.SelfAttention.q.weight",
+ "encoder.block.23.layer.0.SelfAttention.v.weight",
+ "encoder.block.23.layer.0.layer_norm.weight",
+ "encoder.block.23.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.23.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.23.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.23.layer.1.layer_norm.weight",
+ "encoder.block.3.layer.0.SelfAttention.k.weight",
+ "encoder.block.3.layer.0.SelfAttention.o.weight",
+ "encoder.block.3.layer.0.SelfAttention.q.weight",
+ "encoder.block.3.layer.0.SelfAttention.v.weight",
+ "encoder.block.3.layer.0.layer_norm.weight",
+ "encoder.block.3.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.3.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.3.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.3.layer.1.layer_norm.weight",
+ "encoder.block.4.layer.0.SelfAttention.k.weight",
+ "encoder.block.4.layer.0.SelfAttention.o.weight",
+ "encoder.block.4.layer.0.SelfAttention.q.weight",
+ "encoder.block.4.layer.0.SelfAttention.v.weight",
+ "encoder.block.4.layer.0.layer_norm.weight",
+ "encoder.block.4.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.4.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.4.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.4.layer.1.layer_norm.weight",
+ "encoder.block.5.layer.0.SelfAttention.k.weight",
+ "encoder.block.5.layer.0.SelfAttention.o.weight",
+ "encoder.block.5.layer.0.SelfAttention.q.weight",
+ "encoder.block.5.layer.0.SelfAttention.v.weight",
+ "encoder.block.5.layer.0.layer_norm.weight",
+ "encoder.block.5.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.5.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.5.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.5.layer.1.layer_norm.weight",
+ "encoder.block.6.layer.0.SelfAttention.k.weight",
+ "encoder.block.6.layer.0.SelfAttention.o.weight",
+ "encoder.block.6.layer.0.SelfAttention.q.weight",
+ "encoder.block.6.layer.0.SelfAttention.v.weight",
+ "encoder.block.6.layer.0.layer_norm.weight",
+ "encoder.block.6.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.6.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.6.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.6.layer.1.layer_norm.weight",
+ "encoder.block.7.layer.0.SelfAttention.k.weight",
+ "encoder.block.7.layer.0.SelfAttention.o.weight",
+ "encoder.block.7.layer.0.SelfAttention.q.weight",
+ "encoder.block.7.layer.0.SelfAttention.v.weight",
+ "encoder.block.7.layer.0.layer_norm.weight",
+ "encoder.block.7.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.7.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.7.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.7.layer.1.layer_norm.weight",
+ "encoder.block.8.layer.0.SelfAttention.k.weight",
+ "encoder.block.8.layer.0.SelfAttention.o.weight",
+ "encoder.block.8.layer.0.SelfAttention.q.weight",
+ "encoder.block.8.layer.0.SelfAttention.v.weight",
+ "encoder.block.8.layer.0.layer_norm.weight",
+ "encoder.block.8.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.8.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.8.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.8.layer.1.layer_norm.weight",
+ "encoder.block.9.layer.0.SelfAttention.k.weight",
+ "encoder.block.9.layer.0.SelfAttention.o.weight",
+ "encoder.block.9.layer.0.SelfAttention.q.weight",
+ "encoder.block.9.layer.0.SelfAttention.v.weight",
+ "encoder.block.9.layer.0.layer_norm.weight",
+ "encoder.block.9.layer.1.DenseReluDense.wi_0.weight",
+ "encoder.block.9.layer.1.DenseReluDense.wi_1.weight",
+ "encoder.block.9.layer.1.DenseReluDense.wo.weight",
+ "encoder.block.9.layer.1.layer_norm.weight",
+ "encoder.embed_tokens.weight",
+ "encoder.final_layer_norm.weight",
+ "shared.weight",
+ ]
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in name_list:
+ state_dict_[name] = param
+ return state_dict_
+
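A minimal usage sketch for the T5 converter above (illustrative only, not part of the patch; the import path is inferred from the repo layout and the checkpoint filename is a placeholder):

# Sketch: route a civitai-style SD3 checkpoint through the T5 converter.
# The module path and the checkpoint filename below are assumptions.
from safetensors.torch import load_file
from diffsynth.models.sd3_text_encoder import SD3TextEncoder3StateDictConverter

raw_state_dict = load_file("sd3_checkpoint.safetensors")  # placeholder path
converter = SD3TextEncoder3StateDictConverter()

# from_civitai() first tries to strip the "text_encoders.t5xxl.transformer." prefix;
# if no key matches, it falls back to keeping only the whitelisted T5 encoder tensors.
t5_state_dict = converter.from_civitai(raw_state_dict)
print(f"{len(t5_state_dict)} T5 encoder tensors extracted")
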
diff --git a/PusaV1/diffsynth/models/sd3_vae_decoder.py b/PusaV1/diffsynth/models/sd3_vae_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..55fd9c05bdb3b6efe4da417f77b7c7b6e1ef949b
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd3_vae_decoder.py
@@ -0,0 +1,81 @@
+import torch
+from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
+from .sd_unet import ResnetBlock, UpSampler
+from .tiler import TileWorker
+
+
+
+class SD3VAEDecoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 1.5305 # Different from SD 1.x
+ self.shift_factor = 0.0609 # Different from SD 1.x
+ self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
+
+ self.blocks = torch.nn.ModuleList([
+ # UNetMidBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ # UpDecoderBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ UpSampler(512),
+ # UpDecoderBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ UpSampler(512),
+ # UpDecoderBlock2D
+ ResnetBlock(512, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ UpSampler(256),
+ # UpDecoderBlock2D
+ ResnetBlock(256, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
+
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: self.forward(x),
+ sample,
+ tile_size,
+ tile_stride,
+ tile_device=sample.device,
+ tile_dtype=sample.dtype
+ )
+ return hidden_states
+
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
+ if tiled:
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
+
+ # 1. pre-process
+ hidden_states = sample / self.scaling_factor + self.shift_factor
+ hidden_states = self.conv_in(hidden_states)
+ time_emb = None
+ text_emb = None
+ res_stack = None
+
+ # 2. blocks
+ for i, block in enumerate(self.blocks):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 3. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SDVAEDecoderStateDictConverter()
\ No newline at end of file
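A short usage sketch for SD3VAEDecoder above (a sketch, not part of the patch; the import path follows the repo layout and the shapes are illustrative). The forward pass first undoes the SD3 latent scaling (divide by 1.5305, then add 0.0609), and the three UpSampler blocks give an overall 8x spatial upscale; passing `tiled=True` decodes via TileWorker for large latents.

# Sketch only; import path assumed from the repo layout.
import torch
from diffsynth.models.sd3_vae_decoder import SD3VAEDecoder

decoder = SD3VAEDecoder().eval()
latents = torch.randn(1, 16, 64, 64)  # SD3 latents have 16 channels

with torch.no_grad():
    image = decoder(latents)                  # (1, 3, 512, 512): 8x spatial upscale
    image_big = decoder(latents, tiled=True)  # tiled decode for memory-constrained GPUs
print(image.shape)
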
diff --git a/PusaV1/diffsynth/models/sd3_vae_encoder.py b/PusaV1/diffsynth/models/sd3_vae_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c486866b889093ac501c54e224f67c5428cf81c8
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd3_vae_encoder.py
@@ -0,0 +1,95 @@
+import torch
+from .sd_unet import ResnetBlock, DownSampler
+from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
+from .tiler import TileWorker
+from einops import rearrange
+
+
+class SD3VAEEncoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 1.5305 # Different from SD 1.x
+ self.shift_factor = 0.0609 # Different from SD 1.x
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
+
+ self.blocks = torch.nn.ModuleList([
+ # DownEncoderBlock2D
+ ResnetBlock(128, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ DownSampler(128, padding=0, extra_padding=True),
+ # DownEncoderBlock2D
+ ResnetBlock(128, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ DownSampler(256, padding=0, extra_padding=True),
+ # DownEncoderBlock2D
+ ResnetBlock(256, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ DownSampler(512, padding=0, extra_padding=True),
+ # DownEncoderBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ # UNetMidBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
+
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: self.forward(x),
+ sample,
+ tile_size,
+ tile_stride,
+ tile_device=sample.device,
+ tile_dtype=sample.dtype
+ )
+ return hidden_states
+
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+        # For VAE Encoder, we do not need to apply the tiler on each layer.
+ if tiled:
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
+
+ # 1. pre-process
+ hidden_states = self.conv_in(sample)
+ time_emb = None
+ text_emb = None
+ res_stack = None
+
+ # 2. blocks
+ for i, block in enumerate(self.blocks):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 3. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ hidden_states = hidden_states[:, :16]
+ hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
+
+ return hidden_states
+
+ def encode_video(self, sample, batch_size=8):
+ B = sample.shape[0]
+ hidden_states = []
+
+ for i in range(0, sample.shape[2], batch_size):
+
+ j = min(i + batch_size, sample.shape[2])
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
+
+ hidden_states_batch = self(sample_batch)
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
+
+ hidden_states.append(hidden_states_batch)
+
+ hidden_states = torch.concat(hidden_states, dim=2)
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SDVAEEncoderStateDictConverter()
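A usage sketch for the frame-batched `encode_video` path above (a sketch with assumed shapes and import path, not part of the patch): frames are sliced along the time axis in chunks of `batch_size`, folded into the batch dimension for the 2D encoder, then concatenated back along time.

# Sketch only; import path assumed from the repo layout.
import torch
from diffsynth.models.sd3_vae_encoder import SD3VAEEncoder

encoder = SD3VAEEncoder().eval()
video = torch.randn(1, 3, 17, 256, 256)  # (B, C, T, H, W)

with torch.no_grad():
    latents = encoder.encode_video(video, batch_size=8)  # frames encoded 8 at a time
print(latents.shape)  # (1, 16, 17, 32, 32): 16 latent channels, 8x spatial downscale
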
diff --git a/PusaV1/diffsynth/models/sd_controlnet.py b/PusaV1/diffsynth/models/sd_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..910e0dbae8dc6647e6e478f79c239450cefb2027
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd_controlnet.py
@@ -0,0 +1,589 @@
+import torch
+from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
+from .tiler import TileWorker
+
+
+class ControlNetConditioningLayer(torch.nn.Module):
+ def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
+ super().__init__()
+ self.blocks = torch.nn.ModuleList([])
+ self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
+ self.blocks.append(torch.nn.SiLU())
+ for i in range(1, len(channels) - 2):
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
+ self.blocks.append(torch.nn.SiLU())
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
+ self.blocks.append(torch.nn.SiLU())
+ self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
+
+ def forward(self, conditioning):
+ for block in self.blocks:
+ conditioning = block(conditioning)
+ return conditioning
+
+
+class SDControlNet(torch.nn.Module):
+ def __init__(self, global_pool=False):
+ super().__init__()
+ self.time_proj = Timesteps(320)
+ self.time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(320, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
+
+ self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
+
+ self.blocks = torch.nn.ModuleList([
+ # CrossAttnDownBlock2D
+ ResnetBlock(320, 320, 1280),
+ AttentionBlock(8, 40, 320, 1, 768),
+ PushBlock(),
+ ResnetBlock(320, 320, 1280),
+ AttentionBlock(8, 40, 320, 1, 768),
+ PushBlock(),
+ DownSampler(320),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(320, 640, 1280),
+ AttentionBlock(8, 80, 640, 1, 768),
+ PushBlock(),
+ ResnetBlock(640, 640, 1280),
+ AttentionBlock(8, 80, 640, 1, 768),
+ PushBlock(),
+ DownSampler(640),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(640, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768),
+ PushBlock(),
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768),
+ PushBlock(),
+ DownSampler(1280),
+ PushBlock(),
+ # DownBlock2D
+ ResnetBlock(1280, 1280, 1280),
+ PushBlock(),
+ ResnetBlock(1280, 1280, 1280),
+ PushBlock(),
+ # UNetMidBlock2DCrossAttn
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768),
+ ResnetBlock(1280, 1280, 1280),
+ PushBlock()
+ ])
+
+ self.controlnet_blocks = torch.nn.ModuleList([
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
+ ])
+
+ self.global_pool = global_pool
+
+ def forward(
+ self,
+ sample, timestep, encoder_hidden_states, conditioning,
+ tiled=False, tile_size=64, tile_stride=32,
+ **kwargs
+ ):
+ # 1. time
+ time_emb = self.time_proj(timestep).to(sample.dtype)
+ time_emb = self.time_embedding(time_emb)
+ time_emb = time_emb.repeat(sample.shape[0], 1)
+
+ # 2. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
+ text_emb = encoder_hidden_states
+ res_stack = [hidden_states]
+
+ # 3. blocks
+ for i, block in enumerate(self.blocks):
+ if tiled and not isinstance(block, PushBlock):
+ _, _, inter_height, _ = hidden_states.shape
+ resize_scale = inter_height / height
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: block(x, time_emb, text_emb, res_stack)[0],
+ hidden_states,
+ int(tile_size * resize_scale),
+ int(tile_stride * resize_scale),
+ tile_device=hidden_states.device,
+ tile_dtype=hidden_states.dtype
+ )
+ else:
+ hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 4. ControlNet blocks
+ controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
+
+ # pool
+ if self.global_pool:
+ controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
+
+ return controlnet_res_stack
+
+ @staticmethod
+ def state_dict_converter():
+ return SDControlNetStateDictConverter()
+
+
+class SDControlNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ # architecture
+ block_types = [
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
+ 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
+ ]
+
+ # controlnet_rename_dict
+ controlnet_rename_dict = {
+ "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
+ "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
+ "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
+ "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
+ "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
+ "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
+ "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
+ "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
+ "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
+ "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
+ "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
+ "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
+ "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
+ "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
+ "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
+ }
+
+ # Rename each parameter
+ name_list = sorted([name for name in state_dict])
+ rename_dict = {}
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
+ for name in name_list:
+ names = name.split(".")
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
+ pass
+ elif name in controlnet_rename_dict:
+ names = controlnet_rename_dict[name].split(".")
+ elif names[0] == "controlnet_down_blocks":
+ names[0] = "controlnet_blocks"
+ elif names[0] == "controlnet_mid_block":
+ names = ["controlnet_blocks", "12", names[-1]]
+ elif names[0] in ["time_embedding", "add_embedding"]:
+ if names[0] == "add_embedding":
+ names[0] = "add_time_embedding"
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
+ if names[0] == "mid_block":
+ names.insert(1, "0")
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
+ block_type_with_id = ".".join(names[:4])
+ if block_type_with_id != last_block_type_with_id[block_type]:
+ block_id[block_type] += 1
+ last_block_type_with_id[block_type] = block_type_with_id
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+ block_id[block_type] += 1
+ block_type_with_id = ".".join(names[:4])
+ names = ["blocks", str(block_id[block_type])] + names[4:]
+ if "ff" in names:
+ ff_index = names.index("ff")
+ component = ".".join(names[ff_index:ff_index+3])
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
+ names = names[:ff_index] + [component] + names[ff_index+3:]
+ if "to_out" in names:
+ names.pop(names.index("to_out") + 1)
+ else:
+ raise ValueError(f"Unknown parameters: {name}")
+ rename_dict[name] = ".".join(names)
+
+ # Convert state_dict
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ if rename_dict[name] in [
+ "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
+ "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
+ ]:
+ continue
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
+ # For controlnets in diffusers format
+ return self.from_diffusers(state_dict)
+ rename_dict = {
+ "control_model.time_embed.0.weight": "time_embedding.0.weight",
+ "control_model.time_embed.0.bias": "time_embedding.0.bias",
+ "control_model.time_embed.2.weight": "time_embedding.2.weight",
+ "control_model.time_embed.2.bias": "time_embedding.2.bias",
+ "control_model.input_blocks.0.0.weight": "conv_in.weight",
+ "control_model.input_blocks.0.0.bias": "conv_in.bias",
+ "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
+ "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
+ "control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
+ "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
+ "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
+ "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
+ "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
+ "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
+ "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
+ "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
+ "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
+ "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
+ "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
+ "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
+ "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
+ "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
+ "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
+ "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
+ "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
+ "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
+ "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
+ "control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
+ "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
+ "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
+ "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
+ "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
+ "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
+ "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
+ "control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
+ "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
+ "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
+ "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
+ "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
+ "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
+ "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
+ "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
+ "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
+ "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
+ "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
+ "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
+ "control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
+ "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
+ "control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
+ "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
+ "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
+ "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
+ "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
+ "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
+ "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
+ "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
+ "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
+ "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
+ "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
+ "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
+ "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
+ "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
+ "control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
+ "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
+ "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
+ "control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
+ "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
+ "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
+ "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
+ "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
+ "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
+ "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
+ "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
+ "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
+ "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
+ "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
+ "control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
+ "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
+ "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
+ "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
+ "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
+ "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
+ "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
+ "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
+ "control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
+ "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
+ "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
+ "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
+ "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
+ "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
+ "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
+ "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
+ "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
+ "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
+ "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
+ "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
+ "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
+ "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
+ "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
+ "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
+ "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
+ "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
+ "control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
+ "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
+ "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
+ "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
+ "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
+ "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
+ "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
+ "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
+ "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
+ "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
+ "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
+ "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
+ "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
+ "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
+ "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
+ "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
+ "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
+ "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
+ "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
+ "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
+ "control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
+ "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
+ "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
+ "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
+ "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
+ "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
+ "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
+ "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
+ "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
+ "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
+ "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
+ "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
+ "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
+ "control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
+ "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
+ "control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
+ "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
+ "control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
+ "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
+ "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
+ "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
+ "control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
+ "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
+ "control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
+ "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
+ "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
+ "control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
+ "control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
+ "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
+ "control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
+ "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
+ "control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
+ "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
+ "control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
+ "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
+ "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
+ "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
+ "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
+ "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
+ "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
+ "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
+ "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
+ "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
+ "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
+ "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
+ "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
+ "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
+ "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
+ "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
+ "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
+ "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
+ "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
+ "control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
+ "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
+ "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
+ "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
+ "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
+ "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
+ "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
+ "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
+ "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
+ "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
+ "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
+ "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
+ "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
+ "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
+ "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
+ "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
+ "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
+ "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
+ "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
+ "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
+ "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
+ "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
+ "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
+ "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
+ "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
+ "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
+ "control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
+ "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
+ "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
+ "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
+ "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
+ "control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
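The converter above is a pure key-renaming pass over the checkpoint. A minimal sketch of how such a rename-dict conversion is typically applied is shown below; the checkpoint path and loading call are hypothetical illustrations, not this repository's loading pipeline.

```python
# Minimal sketch, assuming a hypothetical ControlNet checkpoint file;
# rename_dict stands for the mapping defined above.
import torch
from safetensors.torch import load_file  # assumes safetensors is installed

def convert_with_rename_dict(state_dict, rename_dict):
    converted = {}
    for name, param in state_dict.items():
        if name not in rename_dict:
            continue  # keys outside the mapping are dropped
        if ".proj_in." in name or ".proj_out." in name:
            param = param.squeeze()  # 1x1 conv kernels become linear weights
        converted[rename_dict[name]] = param
    return converted

# state_dict = load_file("control_v11p_sd15_canny.safetensors")  # hypothetical path
# model.load_state_dict(convert_with_rename_dict(state_dict, rename_dict), strict=False)
```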
diff --git a/PusaV1/diffsynth/models/sd_ipadapter.py b/PusaV1/diffsynth/models/sd_ipadapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d6ebd7d5e79ccd534aab11d22e046111562ccde
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd_ipadapter.py
@@ -0,0 +1,57 @@
+from .svd_image_encoder import SVDImageEncoder
+from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
+from transformers import CLIPImageProcessor
+import torch
+
+
+class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
+ def __init__(self):
+ super().__init__()
+ self.image_processor = CLIPImageProcessor()
+
+ def forward(self, image):
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
+ return super().forward(pixel_values)
+
+
+class SDIpAdapter(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
+ self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
+ self.set_full_adapter()
+
+ def set_full_adapter(self):
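+ # Map (UNet block id, transformer id) to the index of the matching IP-Adapter module above.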
+ block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
+ self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
+
+ def set_less_adapter(self):
+ # IP-Adapter for SD v1.5 doesn't support this feature.
+ self.set_full_adapter()
+
+ def forward(self, hidden_states, scale=1.0):
+ hidden_states = self.image_proj(hidden_states)
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
+ ip_kv_dict = {}
+ for (block_id, transformer_id) in self.call_block_id:
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
+ if block_id not in ip_kv_dict:
+ ip_kv_dict[block_id] = {}
+ ip_kv_dict[block_id][transformer_id] = {
+ "ip_k": ip_k,
+ "ip_v": ip_v,
+ "scale": scale
+ }
+ return ip_kv_dict
+
+ @staticmethod
+ def state_dict_converter():
+ return SDIpAdapterStateDictConverter()
+
+
+class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
+ def __init__(self):
+ pass
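A minimal usage sketch for the SDIpAdapter class defined above. The 1024-dim input follows the `clip_embeddings_dim=1024` of the image projection model; treat the exact shapes as assumptions rather than documented API.

```python
# Minimal sketch, assuming a (1, 1024) CLIP image embedding as input.
import torch

ip_adapter = SDIpAdapter()
clip_image_embeds = torch.randn(1, 1024)          # hypothetical image embedding
ip_kv = ip_adapter(clip_image_embeds, scale=0.8)  # block_id -> transformer_id -> {"ip_k", "ip_v", "scale"}
print(sorted(ip_kv.keys()))                       # UNet block ids that receive extra K/V
```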
diff --git a/PusaV1/diffsynth/models/sd_motion.py b/PusaV1/diffsynth/models/sd_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb49138e147538537a60fb4a3e2d12a175da4a50
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd_motion.py
@@ -0,0 +1,199 @@
+from .sd_unet import SDUNet, Attention, GEGLU
+import torch
+from einops import rearrange, repeat
+
+
+class TemporalTransformerBlock(torch.nn.Module):
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
+ super().__init__()
+
+ # 1. Self-Attn
+ self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
+ self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
+ self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
+
+ # 2. Cross-Attn
+ self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
+ self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
+ self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
+
+ # 3. Feed-forward
+ self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
+ self.act_fn = GEGLU(dim, dim * 4)
+ self.ff = torch.nn.Linear(dim * 4, dim)
+
+
+ def forward(self, hidden_states, batch_size=1):
+
+ # 1. Self-Attention
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
+ attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
+ attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+ ff_output = self.act_fn(norm_hidden_states)
+ ff_output = self.ff(ff_output)
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+
+class TemporalBlock(torch.nn.Module):
+
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
+ self.proj_in = torch.nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = torch.nn.ModuleList([
+ TemporalTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim
+ )
+ for d in range(num_layers)
+ ])
+
+ self.proj_out = torch.nn.Linear(inner_dim, in_channels)
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ batch_size=batch_size
+ )
+
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = hidden_states + residual
+
+ return hidden_states, time_emb, text_emb, res_stack
+
+
+class SDMotionModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.motion_modules = torch.nn.ModuleList([
+ TemporalBlock(8, 40, 320, eps=1e-6),
+ TemporalBlock(8, 40, 320, eps=1e-6),
+ TemporalBlock(8, 80, 640, eps=1e-6),
+ TemporalBlock(8, 80, 640, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 160, 1280, eps=1e-6),
+ TemporalBlock(8, 80, 640, eps=1e-6),
+ TemporalBlock(8, 80, 640, eps=1e-6),
+ TemporalBlock(8, 80, 640, eps=1e-6),
+ TemporalBlock(8, 40, 320, eps=1e-6),
+ TemporalBlock(8, 40, 320, eps=1e-6),
+ TemporalBlock(8, 40, 320, eps=1e-6),
+ ])
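+ # Map SDUNet block id -> index into self.motion_modules (insertion points of the temporal blocks).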
+ self.call_block_id = {
+ 1: 0,
+ 4: 1,
+ 9: 2,
+ 12: 3,
+ 17: 4,
+ 20: 5,
+ 24: 6,
+ 26: 7,
+ 29: 8,
+ 32: 9,
+ 34: 10,
+ 36: 11,
+ 40: 12,
+ 43: 13,
+ 46: 14,
+ 50: 15,
+ 53: 16,
+ 56: 17,
+ 60: 18,
+ 63: 19,
+ 66: 20
+ }
+
+ def forward(self):
+ pass
+
+ @staticmethod
+ def state_dict_converter():
+ return SDMotionModelStateDictConverter()
+
+
+class SDMotionModelStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "norm": "norm",
+ "proj_in": "proj_in",
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
+ "proj_out": "proj_out",
+ }
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
+ state_dict_ = {}
+ last_prefix, module_id = "", -1
+ for name in name_list:
+ names = name.split(".")
+ prefix_index = names.index("temporal_transformer") + 1
+ prefix = ".".join(names[:prefix_index])
+ if prefix != last_prefix:
+ last_prefix = prefix
+ module_id += 1
+ middle_name = ".".join(names[prefix_index:-1])
+ suffix = names[-1]
+ if "pos_encoder" in names:
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
+ else:
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
+ state_dict_[rename] = state_dict[name]
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
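A small sketch of what `SDMotionModelStateDictConverter.from_diffusers` does with AnimateDiff-style keys; the keys and zero tensors below are fabricated for illustration only.

```python
# Minimal sketch with made-up keys/shapes; real AnimateDiff checkpoints
# contain many such entries per temporal transformer.
import torch

converter = SDMotionModelStateDictConverter()
fake_state_dict = {
    "down_blocks.0.motion_modules.0.temporal_transformer.proj_in.weight": torch.zeros(320, 320),
    "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight": torch.zeros(320, 320),
}
print(sorted(converter.from_diffusers(fake_state_dict)))
# ['motion_modules.0.proj_in.weight',
#  'motion_modules.0.transformer_blocks.0.attn1.to_q.weight']
```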
diff --git a/PusaV1/diffsynth/models/sd_text_encoder.py b/PusaV1/diffsynth/models/sd_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fe8994a10bc998b8778cae2cbd57b95545166ba
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd_text_encoder.py
@@ -0,0 +1,321 @@
+import torch
+from .attention import Attention
+
+
+class CLIPEncoderLayer(torch.nn.Module):
+ def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
+ super().__init__()
+ self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
+ self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
+ self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
+ self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
+ self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
+
+ self.use_quick_gelu = use_quick_gelu
+
+ def quickGELU(self, x):
+ return x * torch.sigmoid(1.702 * x)
+
+ def forward(self, hidden_states, attn_mask=None):
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.fc1(hidden_states)
+ if self.use_quick_gelu:
+ hidden_states = self.quickGELU(hidden_states)
+ else:
+ hidden_states = torch.nn.functional.gelu(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class SDTextEncoder(torch.nn.Module):
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
+ super().__init__()
+
+ # token_embedding
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
+
+ # position_embeds (This is a fixed tensor)
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
+
+ # encoders
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
+
+ # attn_mask
+ self.attn_mask = self.attention_mask(max_position_embeddings)
+
+ # final_layer_norm
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
+
+ def attention_mask(self, length):
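+ # Causal mask: upper-triangular entries are -inf, so each token attends only to itself and earlier tokens.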
+ mask = torch.empty(length, length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1)
+ return mask
+
+ def forward(self, input_ids, clip_skip=1):
+ embeds = self.token_embedding(input_ids) + self.position_embeds
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+ for encoder_id, encoder in enumerate(self.encoders):
+ embeds = encoder(embeds, attn_mask=attn_mask)
+ if encoder_id + clip_skip == len(self.encoders):
+ break
+ embeds = self.final_layer_norm(embeds)
+ return embeds
+
+ @staticmethod
+ def state_dict_converter():
+ return SDTextEncoderStateDictConverter()
+
+
+class SDTextEncoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
+ }
+ attn_rename_dict = {
+ "self_attn.q_proj": "attn.to_q",
+ "self_attn.k_proj": "attn.to_k",
+ "self_attn.v_proj": "attn.to_v",
+ "self_attn.out_proj": "attn.to_out",
+ "layer_norm1": "layer_norm1",
+ "layer_norm2": "layer_norm2",
+ "mlp.fc1": "fc1",
+ "mlp.fc2": "fc2",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ elif name.startswith("text_model.encoder.layers."):
+ param = state_dict[name]
+ names = name.split(".")
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+ state_dict_[name_] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
+ "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
+ "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
+ "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
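+ # add a leading batch dimension: (num_positions, dim) -> (1, num_positions, dim)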
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/sd_unet.py b/PusaV1/diffsynth/models/sd_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..33363909e4968292fe22e1953dc0c4a12e41d921
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd_unet.py
@@ -0,0 +1,1108 @@
+import math
+import torch
+from .attention import Attention
+from .tiler import TileWorker
+
+
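+# Sinusoidal timestep embedding: maps a batched diffusion timestep to a
+# num_channels-dim vector of [cos | sin] features.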
+class Timesteps(torch.nn.Module):
+ def __init__(self, num_channels):
+ super().__init__()
+ self.num_channels = num_channels
+
+ def forward(self, timesteps):
+ half_dim = self.num_channels // 2
+ exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
+ timesteps = timesteps.unsqueeze(-1)
+ emb = timesteps.float() * torch.exp(exponent)
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
+ return emb
+
+
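+# GEGLU activation for the transformer feed-forward: one linear layer produces both
+# the hidden states and a gate, and the GELU of the gate multiplies the hidden states.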
+class GEGLU(torch.nn.Module):
+
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = torch.nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, hidden_states):
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * torch.nn.functional.gelu(gate)
+
+
+class BasicTransformerBlock(torch.nn.Module):
+
+ def __init__(self, dim, num_attention_heads, attention_head_dim, cross_attention_dim):
+ super().__init__()
+
+ # 1. Self-Attn
+ self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
+ self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
+
+ # 2. Cross-Attn
+ self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
+ self.attn2 = Attention(q_dim=dim, kv_dim=cross_attention_dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
+
+ # 3. Feed-forward
+ self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
+ self.act_fn = GEGLU(dim, dim * 4)
+ self.ff = torch.nn.Linear(dim * 4, dim)
+
+
+ def forward(self, hidden_states, encoder_hidden_states, ipadapter_kwargs=None):
+ # 1. Self-Attention
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention
+ norm_hidden_states = self.norm2(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, ipadapter_kwargs=ipadapter_kwargs)
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+ ff_output = self.act_fn(norm_hidden_states)
+ ff_output = self.ff(ff_output)
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+
+class DownSampler(torch.nn.Module):
+ def __init__(self, channels, padding=1, extra_padding=False):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
+ self.extra_padding = extra_padding
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ if self.extra_padding:
+ hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
+ hidden_states = self.conv(hidden_states)
+ return hidden_states, time_emb, text_emb, res_stack
+
+
+class UpSampler(torch.nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+ hidden_states = self.conv(hidden_states)
+ return hidden_states, time_emb, text_emb, res_stack
+
+
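+# GroupNorm/SiLU/Conv residual block; when a time embedding is provided it is projected
+# and added between the two convolutions, and a 1x1 shortcut matches channels when
+# in_channels != out_channels.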
+class ResnetBlock(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
+ super().__init__()
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.nonlinearity = torch.nn.SiLU()
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ x = hidden_states
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+ x = self.conv1(x)
+ if time_emb is not None:
+ emb = self.nonlinearity(time_emb)
+ emb = self.time_emb_proj(emb)[:, :, None, None]
+ x = x + emb
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+ x = self.conv2(x)
+ if self.conv_shortcut is not None:
+ hidden_states = self.conv_shortcut(hidden_states)
+ hidden_states = hidden_states + x
+ return hidden_states, time_emb, text_emb, res_stack
+
+
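+# Spatial transformer block: GroupNorm + linear proj_in, a stack of BasicTransformerBlocks
+# over the flattened H*W tokens, then (optionally) proj_out with a residual connection.
+# Supports cross-frame attention (all frames merged into one token sequence) and tiled
+# processing via TileWorker for large resolutions.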
+class AttentionBlock(torch.nn.Module):
+
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5, need_proj_out=True):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
+ self.proj_in = torch.nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = torch.nn.ModuleList([
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim
+ )
+ for d in range(num_layers)
+ ])
+ self.need_proj_out = need_proj_out
+ if need_proj_out:
+ self.proj_out = torch.nn.Linear(inner_dim, in_channels)
+
+ def forward(
+ self,
+ hidden_states, time_emb, text_emb, res_stack,
+ cross_frame_attention=False,
+ tiled=False, tile_size=64, tile_stride=32,
+ ipadapter_kwargs_list={},
+ **kwargs
+ ):
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ if cross_frame_attention:
+ hidden_states = hidden_states.reshape(1, batch * height * width, inner_dim)
+ encoder_hidden_states = text_emb.mean(dim=0, keepdim=True)
+ else:
+ encoder_hidden_states = text_emb
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
+ encoder_hidden_states = encoder_hidden_states.repeat(hidden_states.shape[0], 1, 1)
+
+ if tiled:
+ tile_size = min(tile_size, min(height, width))
+ hidden_states = hidden_states.permute(0, 2, 1).reshape(batch, inner_dim, height, width)
+ def block_tile_forward(x):
+ b, c, h, w = x.shape
+ x = x.permute(0, 2, 3, 1).reshape(b, h*w, c)
+ x = block(x, encoder_hidden_states)
+ x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+ return x
+ for block in self.transformer_blocks:
+ hidden_states = TileWorker().tiled_forward(
+ block_tile_forward,
+ hidden_states,
+ tile_size,
+ tile_stride,
+ tile_device=hidden_states.device,
+ tile_dtype=hidden_states.dtype
+ )
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ else:
+ for block_id, block in enumerate(self.transformer_blocks):
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ ipadapter_kwargs=ipadapter_kwargs_list.get(block_id, None)
+ )
+ if cross_frame_attention:
+ hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
+
+ if self.need_proj_out:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = hidden_states + residual
+ else:
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ return hidden_states, time_emb, text_emb, res_stack
+
+
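+# Pushes the current hidden states onto the residual stack (the U-Net skip connections).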
+class PushBlock(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ res_stack.append(hidden_states)
+ return hidden_states, time_emb, text_emb, res_stack
+
+
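+# Pops the most recent skip connection and concatenates it with the hidden states along
+# the channel dimension.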
+class PopBlock(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ res_hidden_states = res_stack.pop()
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ return hidden_states, time_emb, text_emb, res_stack
+
+
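+# Stable Diffusion 1.x UNet written as a flat list of blocks; PushBlock/PopBlock reproduce
+# the skip connections of the original down/mid/up block layout.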
+class SDUNet(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.time_proj = Timesteps(320)
+ self.time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(320, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
+
+ self.blocks = torch.nn.ModuleList([
+ # CrossAttnDownBlock2D
+ ResnetBlock(320, 320, 1280),
+ AttentionBlock(8, 40, 320, 1, 768, eps=1e-6),
+ PushBlock(),
+ ResnetBlock(320, 320, 1280),
+ AttentionBlock(8, 40, 320, 1, 768, eps=1e-6),
+ PushBlock(),
+ DownSampler(320),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(320, 640, 1280),
+ AttentionBlock(8, 80, 640, 1, 768, eps=1e-6),
+ PushBlock(),
+ ResnetBlock(640, 640, 1280),
+ AttentionBlock(8, 80, 640, 1, 768, eps=1e-6),
+ PushBlock(),
+ DownSampler(640),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(640, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6),
+ PushBlock(),
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6),
+ PushBlock(),
+ DownSampler(1280),
+ PushBlock(),
+ # DownBlock2D
+ ResnetBlock(1280, 1280, 1280),
+ PushBlock(),
+ ResnetBlock(1280, 1280, 1280),
+ PushBlock(),
+ # UNetMidBlock2DCrossAttn
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6),
+ ResnetBlock(1280, 1280, 1280),
+ # UpBlock2D
+ PopBlock(),
+ ResnetBlock(2560, 1280, 1280),
+ PopBlock(),
+ ResnetBlock(2560, 1280, 1280),
+ PopBlock(),
+ ResnetBlock(2560, 1280, 1280),
+ UpSampler(1280),
+ # CrossAttnUpBlock2D
+ PopBlock(),
+ ResnetBlock(2560, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6),
+ PopBlock(),
+ ResnetBlock(2560, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6),
+ PopBlock(),
+ ResnetBlock(1920, 1280, 1280),
+ AttentionBlock(8, 160, 1280, 1, 768, eps=1e-6),
+ UpSampler(1280),
+ # CrossAttnUpBlock2D
+ PopBlock(),
+ ResnetBlock(1920, 640, 1280),
+ AttentionBlock(8, 80, 640, 1, 768, eps=1e-6),
+ PopBlock(),
+ ResnetBlock(1280, 640, 1280),
+ AttentionBlock(8, 80, 640, 1, 768, eps=1e-6),
+ PopBlock(),
+ ResnetBlock(960, 640, 1280),
+ AttentionBlock(8, 80, 640, 1, 768, eps=1e-6),
+ UpSampler(640),
+ # CrossAttnUpBlock2D
+ PopBlock(),
+ ResnetBlock(960, 320, 1280),
+ AttentionBlock(8, 40, 320, 1, 768, eps=1e-6),
+ PopBlock(),
+ ResnetBlock(640, 320, 1280),
+ AttentionBlock(8, 40, 320, 1, 768, eps=1e-6),
+ PopBlock(),
+ ResnetBlock(640, 320, 1280),
+ AttentionBlock(8, 40, 320, 1, 768, eps=1e-6),
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
+
+ def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
+ # 1. time
+ time_emb = self.time_proj(timestep).to(sample.dtype)
+ time_emb = self.time_embedding(time_emb)
+
+ # 2. pre-process
+ hidden_states = self.conv_in(sample)
+ text_emb = encoder_hidden_states
+ res_stack = [hidden_states]
+
+ # 3. blocks
+ for i, block in enumerate(self.blocks):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 4. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SDUNetStateDictConverter()
+
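+# Hypothetical usage sketch (comment only): with SD-1.x weights loaded, a single
+# denoising step could look like
+#     unet = SDUNet()
+#     noise_pred = unet(
+#         sample=torch.randn(1, 4, 64, 64),               # latents of a 512x512 image
+#         timestep=torch.tensor([981]),                    # batched diffusion timestep
+#         encoder_hidden_states=torch.randn(1, 77, 768),   # CLIP text embeddings
+#     )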
+
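+# Converts UNet checkpoints from the Diffusers layout (from_diffusers) or the original
+# Stable Diffusion / CivitAI layout (from_civitai) into the flat "blocks.*" naming used
+# by SDUNet above.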
+class SDUNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ # architecture
+ block_types = [
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
+ 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
+ ]
+
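+ # Walk the sorted parameter names and advance a per-type cursor through block_types,
+ # so each down/mid/up sub-module is assigned its index in the flat blocks list above.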
+ # Rename each parameter
+ name_list = sorted([name for name in state_dict])
+ rename_dict = {}
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
+ for name in name_list:
+ names = name.split(".")
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
+ pass
+ elif names[0] in ["time_embedding", "add_embedding"]:
+ if names[0] == "add_embedding":
+ names[0] = "add_time_embedding"
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
+ if names[0] == "mid_block":
+ names.insert(1, "0")
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
+ block_type_with_id = ".".join(names[:4])
+ if block_type_with_id != last_block_type_with_id[block_type]:
+ block_id[block_type] += 1
+ last_block_type_with_id[block_type] = block_type_with_id
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+ block_id[block_type] += 1
+ block_type_with_id = ".".join(names[:4])
+ names = ["blocks", str(block_id[block_type])] + names[4:]
+ if "ff" in names:
+ ff_index = names.index("ff")
+ component = ".".join(names[ff_index:ff_index+3])
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
+ names = names[:ff_index] + [component] + names[ff_index+3:]
+ if "to_out" in names:
+ names.pop(names.index("to_out") + 1)
+ else:
+ raise ValueError(f"Unknown parameters: {name}")
+ rename_dict[name] = ".".join(names)
+
+ # Convert state_dict
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
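+ # Static key mapping from the original LDM-style checkpoint layout
+ # ("model.diffusion_model.*") to the flat "blocks.*" names.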
+ rename_dict = {
+ "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
+ "model.diffusion_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
+ "model.diffusion_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
+ "model.diffusion_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
+ "model.diffusion_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
+ "model.diffusion_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
+ "model.diffusion_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
+ "model.diffusion_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
+ "model.diffusion_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
+ "model.diffusion_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
+ "model.diffusion_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
+ "model.diffusion_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
+ "model.diffusion_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
+ "model.diffusion_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
+ "model.diffusion_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
+ "model.diffusion_model.out.0.bias": "conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "conv_out.bias",
+ "model.diffusion_model.out.2.weight": "conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "blocks.32.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "blocks.32.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "blocks.32.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "blocks.32.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "blocks.32.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "blocks.32.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "blocks.32.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "blocks.32.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "blocks.32.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "blocks.32.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "blocks.32.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "blocks.32.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "blocks.34.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "blocks.34.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "blocks.34.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "blocks.34.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "blocks.34.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "blocks.34.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "blocks.34.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "blocks.34.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "blocks.34.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "blocks.34.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "blocks.34.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "blocks.34.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "blocks.62.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "blocks.62.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "blocks.62.norm1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "blocks.62.norm1.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "blocks.62.conv1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "blocks.62.conv1.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "blocks.62.norm2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "blocks.62.norm2.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "blocks.62.conv2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "blocks.62.conv2.weight",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "blocks.62.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "blocks.62.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.1.norm.bias": "blocks.63.norm.bias",
+ "model.diffusion_model.output_blocks.10.1.norm.weight": "blocks.63.norm.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_in.bias": "blocks.63.proj_in.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_in.weight": "blocks.63.proj_in.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_out.bias": "blocks.63.proj_out.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_out.weight": "blocks.63.proj_out.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "blocks.63.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.63.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.63.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "blocks.63.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "blocks.63.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "blocks.63.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.63.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.63.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "blocks.63.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "blocks.63.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.63.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.63.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "blocks.63.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "blocks.63.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "blocks.63.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "blocks.63.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "blocks.63.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "blocks.63.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "blocks.63.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "blocks.63.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "blocks.65.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "blocks.65.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "blocks.65.norm1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "blocks.65.norm1.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "blocks.65.conv1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "blocks.65.conv1.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "blocks.65.norm2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "blocks.65.norm2.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "blocks.65.conv2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "blocks.65.conv2.weight",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "blocks.65.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "blocks.65.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.11.1.norm.bias": "blocks.66.norm.bias",
+ "model.diffusion_model.output_blocks.11.1.norm.weight": "blocks.66.norm.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_in.bias": "blocks.66.proj_in.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_in.weight": "blocks.66.proj_in.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_out.bias": "blocks.66.proj_out.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_out.weight": "blocks.66.proj_out.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "blocks.66.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.66.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.66.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "blocks.66.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "blocks.66.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "blocks.66.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.66.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.66.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "blocks.66.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "blocks.66.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.66.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.66.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "blocks.66.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "blocks.66.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "blocks.66.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "blocks.66.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "blocks.66.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "blocks.66.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "blocks.66.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "blocks.66.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "blocks.36.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "blocks.36.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "blocks.36.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "blocks.36.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "blocks.36.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "blocks.36.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "blocks.36.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "blocks.36.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "blocks.36.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "blocks.36.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "blocks.36.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "blocks.36.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.conv.bias": "blocks.37.conv.bias",
+ "model.diffusion_model.output_blocks.2.1.conv.weight": "blocks.37.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "blocks.39.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "blocks.39.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "blocks.39.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "blocks.39.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "blocks.39.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "blocks.39.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "blocks.39.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "blocks.39.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "blocks.39.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "blocks.39.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "blocks.39.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "blocks.39.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "blocks.40.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "blocks.40.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "blocks.40.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "blocks.40.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "blocks.40.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "blocks.40.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "blocks.40.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.40.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.40.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "blocks.40.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "blocks.40.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "blocks.40.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.40.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.40.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "blocks.40.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "blocks.40.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.40.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.40.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "blocks.40.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "blocks.40.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "blocks.40.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "blocks.40.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "blocks.40.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "blocks.40.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "blocks.40.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "blocks.40.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "blocks.42.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "blocks.42.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "blocks.42.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "blocks.42.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "blocks.42.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "blocks.42.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "blocks.42.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "blocks.42.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "blocks.42.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "blocks.42.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "blocks.42.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "blocks.42.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "blocks.43.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "blocks.43.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "blocks.43.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "blocks.43.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "blocks.43.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "blocks.43.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.43.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.43.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.43.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.43.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.43.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.43.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.43.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.43.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.43.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.43.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.43.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.43.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.43.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.43.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.43.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.43.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.43.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.43.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.43.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.43.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "blocks.45.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "blocks.45.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "blocks.45.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "blocks.45.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "blocks.45.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "blocks.45.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "blocks.45.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "blocks.45.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "blocks.45.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "blocks.45.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "blocks.45.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "blocks.45.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "blocks.46.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "blocks.46.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "blocks.46.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "blocks.46.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "blocks.46.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "blocks.46.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.46.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.46.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.46.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.46.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.46.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.46.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.46.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.46.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.46.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.46.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.46.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.46.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.46.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.46.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.46.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.46.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.46.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.46.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.46.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.46.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "blocks.47.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "blocks.47.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "blocks.49.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "blocks.49.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "blocks.49.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "blocks.49.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "blocks.49.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "blocks.49.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "blocks.49.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "blocks.49.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "blocks.49.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "blocks.49.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "blocks.49.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "blocks.49.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.6.1.norm.bias": "blocks.50.norm.bias",
+ "model.diffusion_model.output_blocks.6.1.norm.weight": "blocks.50.norm.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_in.bias": "blocks.50.proj_in.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_in.weight": "blocks.50.proj_in.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_out.bias": "blocks.50.proj_out.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_out.weight": "blocks.50.proj_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "blocks.50.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.50.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.50.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "blocks.50.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "blocks.50.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "blocks.50.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.50.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.50.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "blocks.50.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "blocks.50.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.50.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.50.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "blocks.50.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "blocks.50.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "blocks.50.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "blocks.50.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "blocks.50.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "blocks.50.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "blocks.50.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "blocks.50.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "blocks.52.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "blocks.52.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "blocks.52.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "blocks.52.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "blocks.52.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "blocks.52.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "blocks.52.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "blocks.52.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "blocks.52.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "blocks.52.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "blocks.52.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "blocks.52.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.1.norm.bias": "blocks.53.norm.bias",
+ "model.diffusion_model.output_blocks.7.1.norm.weight": "blocks.53.norm.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_in.bias": "blocks.53.proj_in.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_in.weight": "blocks.53.proj_in.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_out.bias": "blocks.53.proj_out.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_out.weight": "blocks.53.proj_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.53.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.53.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.53.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.53.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.53.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.53.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.53.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.53.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.53.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.53.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.53.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.53.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.53.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.53.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.53.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.53.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.53.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.53.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.53.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.53.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "blocks.55.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "blocks.55.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "blocks.55.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "blocks.55.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "blocks.55.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "blocks.55.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "blocks.55.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "blocks.55.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "blocks.55.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "blocks.55.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "blocks.55.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "blocks.55.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.1.norm.bias": "blocks.56.norm.bias",
+ "model.diffusion_model.output_blocks.8.1.norm.weight": "blocks.56.norm.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_in.bias": "blocks.56.proj_in.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_in.weight": "blocks.56.proj_in.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_out.bias": "blocks.56.proj_out.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_out.weight": "blocks.56.proj_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.56.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.56.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.56.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.56.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.56.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.56.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.56.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.56.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.56.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.56.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.56.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.56.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.56.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.56.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.56.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.56.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.56.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.56.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.56.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.56.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.2.conv.bias": "blocks.57.conv.bias",
+ "model.diffusion_model.output_blocks.8.2.conv.weight": "blocks.57.conv.weight",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "blocks.59.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "blocks.59.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "blocks.59.norm1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "blocks.59.norm1.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "blocks.59.conv1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "blocks.59.conv1.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "blocks.59.norm2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "blocks.59.norm2.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "blocks.59.conv2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "blocks.59.conv2.weight",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "blocks.59.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "blocks.59.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.9.1.norm.bias": "blocks.60.norm.bias",
+ "model.diffusion_model.output_blocks.9.1.norm.weight": "blocks.60.norm.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_in.bias": "blocks.60.proj_in.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_in.weight": "blocks.60.proj_in.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_out.bias": "blocks.60.proj_out.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_out.weight": "blocks.60.proj_out.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "blocks.60.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.60.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.60.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "blocks.60.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "blocks.60.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "blocks.60.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.60.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.60.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "blocks.60.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "blocks.60.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.60.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.60.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "blocks.60.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "blocks.60.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "blocks.60.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "blocks.60.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "blocks.60.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "blocks.60.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "blocks.60.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "blocks.60.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.time_embed.0.bias": "time_embedding.0.bias",
+ "model.diffusion_model.time_embed.0.weight": "time_embedding.0.weight",
+ "model.diffusion_model.time_embed.2.bias": "time_embedding.2.bias",
+ "model.diffusion_model.time_embed.2.weight": "time_embedding.2.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
\ No newline at end of file
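
Note: the conversion loop that closes the file above keeps only checkpoint keys listed in the rename table, renames them to the internal "blocks.*" layout, and squeezes the proj_in/proj_out weights (stored as 1x1 convolutions in the original checkpoint) down to the 2-D shape the internal transformer blocks expect. A minimal, self-contained sketch of that pattern follows; the helper name and the tensor shape are illustrative only and are not part of the repository.

import torch

def convert_with_rename(state_dict, rename_dict):
    # Keep only known keys, rename them, and squeeze proj_in/proj_out weights
    # from the checkpoint's [out, in, 1, 1] conv shape to [out, in].
    converted = {}
    for name, param in state_dict.items():
        if name not in rename_dict:
            continue
        if ".proj_in." in name or ".proj_out." in name:
            param = param.squeeze()
        converted[rename_dict[name]] = param
    return converted

# Toy example: the key/mapping pair is taken from the table above, the shape is made up.
key = "model.diffusion_model.output_blocks.10.1.proj_in.weight"
raw = {key: torch.randn(640, 640, 1, 1)}
mapping = {key: "blocks.63.proj_in.weight"}
print(convert_with_rename(raw, mapping)["blocks.63.proj_in.weight"].shape)  # torch.Size([640, 640])
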
diff --git a/PusaV1/diffsynth/models/sd_vae_decoder.py b/PusaV1/diffsynth/models/sd_vae_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..93f015a63ededec53507ea73d34e2d904f5bed06
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd_vae_decoder.py
@@ -0,0 +1,336 @@
+import torch
+from .attention import Attention
+from .sd_unet import ResnetBlock, UpSampler
+from .tiler import TileWorker
+
+
+class VAEAttentionBlock(torch.nn.Module):
+
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.transformer_blocks = torch.nn.ModuleList([
+ Attention(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ bias_q=True,
+ bias_kv=True,
+ bias_out=True
+ )
+ for d in range(num_layers)
+ ])
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = hidden_states + residual
+
+ return hidden_states, time_emb, text_emb, res_stack
+
+
+class SDVAEDecoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.18215
+ self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
+ self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
+
+ self.blocks = torch.nn.ModuleList([
+ # UNetMidBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ # UpDecoderBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ UpSampler(512),
+ # UpDecoderBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ UpSampler(512),
+ # UpDecoderBlock2D
+ ResnetBlock(512, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ UpSampler(256),
+ # UpDecoderBlock2D
+ ResnetBlock(256, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
+
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: self.forward(x),
+ sample,
+ tile_size,
+ tile_stride,
+ tile_device=sample.device,
+ tile_dtype=sample.dtype
+ )
+ return hidden_states
+
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+ original_dtype = sample.dtype
+ sample = sample.to(dtype=next(iter(self.parameters())).dtype)
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
+ if tiled:
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
+
+ # 1. pre-process
+ sample = sample / self.scaling_factor
+ hidden_states = self.post_quant_conv(sample)
+ hidden_states = self.conv_in(hidden_states)
+ time_emb = None
+ text_emb = None
+ res_stack = None
+
+ # 2. blocks
+ for i, block in enumerate(self.blocks):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 3. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ hidden_states = hidden_states.to(original_dtype)
+
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SDVAEDecoderStateDictConverter()
+
+
+class SDVAEDecoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ # architecture
+ block_types = [
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
+ ]
+
+ # Rename each parameter
+ local_rename_dict = {
+ "post_quant_conv": "post_quant_conv",
+ "decoder.conv_in": "conv_in",
+ "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
+ "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
+ "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
+ "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
+ "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
+ "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
+ "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
+ "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
+ "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
+ "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
+ "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
+ "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
+ "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
+ "decoder.conv_norm_out": "conv_norm_out",
+ "decoder.conv_out": "conv_out",
+ }
+ name_list = sorted([name for name in state_dict])
+ rename_dict = {}
+ block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
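+ # Names covered by local_rename_dict (mid-block, IO layers) are handled directly above.
+ # For "decoder.up_blocks.*" parameters, the counters in block_id start after the three
+ # mid-block entries and advance through block_types, so each new diffusers sub-block is
+ # assigned the next matching slot in self.blocks.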
+ for name in name_list:
+ names = name.split(".")
+ name_prefix = ".".join(names[:-1])
+ if name_prefix in local_rename_dict:
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
+ elif name.startswith("decoder.up_blocks"):
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
+ block_type_with_id = ".".join(names[:5])
+ if block_type_with_id != last_block_type_with_id[block_type]:
+ block_id[block_type] += 1
+ last_block_type_with_id[block_type] = block_type_with_id
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+ block_id[block_type] += 1
+ block_type_with_id = ".".join(names[:5])
+ names = ["blocks", str(block_id[block_type])] + names[5:]
+ rename_dict[name] = ".".join(names)
+
+ # Convert state_dict
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
+ "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
+ "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
+ "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
+ "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if "transformer_blocks" in rename_dict[name]:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/sd_vae_encoder.py b/PusaV1/diffsynth/models/sd_vae_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..60965c591c01bd12dce5f0abdbfd121c033c47c6
--- /dev/null
+++ b/PusaV1/diffsynth/models/sd_vae_encoder.py
@@ -0,0 +1,282 @@
+import torch
+from .sd_unet import ResnetBlock, DownSampler
+from .sd_vae_decoder import VAEAttentionBlock
+from .tiler import TileWorker
+from einops import rearrange
+
+
+class SDVAEEncoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.18215
+ self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
+
+ self.blocks = torch.nn.ModuleList([
+ # DownEncoderBlock2D
+ ResnetBlock(128, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ DownSampler(128, padding=0, extra_padding=True),
+ # DownEncoderBlock2D
+ ResnetBlock(128, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ DownSampler(256, padding=0, extra_padding=True),
+ # DownEncoderBlock2D
+ ResnetBlock(256, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ DownSampler(512, padding=0, extra_padding=True),
+ # DownEncoderBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ # UNetMidBlock2D
+ ResnetBlock(512, 512, eps=1e-6),
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
+
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
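+        # Split the input into overlapping tiles via TileWorker so large images can be encoded with bounded memory.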
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: self.forward(x),
+ sample,
+ tile_size,
+ tile_stride,
+ tile_device=sample.device,
+ tile_dtype=sample.dtype
+ )
+ return hidden_states
+
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
+ original_dtype = sample.dtype
+ sample = sample.to(dtype=next(iter(self.parameters())).dtype)
+        # For the VAE encoder, we do not need to apply the tiler on each layer.
+ if tiled:
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
+
+ # 1. pre-process
+ hidden_states = self.conv_in(sample)
+ time_emb = None
+ text_emb = None
+ res_stack = None
+
+ # 2. blocks
+ for i, block in enumerate(self.blocks):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 3. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ hidden_states = self.quant_conv(hidden_states)
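+        # quant_conv yields 8 channels (mean and log-variance of the latent distribution); keep only the 4 mean channels.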
+ hidden_states = hidden_states[:, :4]
+ hidden_states *= self.scaling_factor
+ hidden_states = hidden_states.to(original_dtype)
+
+ return hidden_states
+
+ def encode_video(self, sample, batch_size=8):
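+        # Encode a (B, C, T, H, W) video by folding frames into the batch dimension, batch_size frames at a time.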
+ B = sample.shape[0]
+ hidden_states = []
+
+ for i in range(0, sample.shape[2], batch_size):
+
+ j = min(i + batch_size, sample.shape[2])
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
+
+ hidden_states_batch = self(sample_batch)
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
+
+ hidden_states.append(hidden_states_batch)
+
+ hidden_states = torch.concat(hidden_states, dim=2)
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SDVAEEncoderStateDictConverter()
+
+
+class SDVAEEncoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ # architecture
+ block_types = [
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
+ 'ResnetBlock', 'ResnetBlock',
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
+ ]
+
+ # Rename each parameter
+ local_rename_dict = {
+ "quant_conv": "quant_conv",
+ "encoder.conv_in": "conv_in",
+ "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
+ "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
+ "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
+ "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
+ "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
+ "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
+ "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
+ "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
+ "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
+ "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
+ "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
+ "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
+ "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
+ "encoder.conv_norm_out": "conv_norm_out",
+ "encoder.conv_out": "conv_out",
+ }
+ name_list = sorted([name for name in state_dict])
+ rename_dict = {}
+ block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
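+        # Assign consecutive block indices to down-block layers by advancing through block_types until the types match.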
+ for name in name_list:
+ names = name.split(".")
+ name_prefix = ".".join(names[:-1])
+ if name_prefix in local_rename_dict:
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
+ elif name.startswith("encoder.down_blocks"):
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
+ block_type_with_id = ".".join(names[:5])
+ if block_type_with_id != last_block_type_with_id[block_type]:
+ block_id[block_type] += 1
+ last_block_type_with_id[block_type] = block_type_with_id
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+ block_id[block_type] += 1
+ block_type_with_id = ".".join(names[:5])
+ names = ["blocks", str(block_id[block_type])] + names[5:]
+ rename_dict[name] = ".".join(names)
+
+ # Convert state_dict
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "first_stage_model.encoder.conv_in.bias": "conv_in.bias",
+ "first_stage_model.encoder.conv_in.weight": "conv_in.weight",
+ "first_stage_model.encoder.conv_out.bias": "conv_out.bias",
+ "first_stage_model.encoder.conv_out.weight": "conv_out.weight",
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
+ "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
+ "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
+ "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
+ "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
+ "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
+ "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
+ "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
+ "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
+ "first_stage_model.quant_conv.bias": "quant_conv.bias",
+ "first_stage_model.quant_conv.weight": "quant_conv.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if "transformer_blocks" in rename_dict[name]:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/sdxl_controlnet.py b/PusaV1/diffsynth/models/sdxl_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..acddf1cc4109af01bd9c06121f6cd4d8604ce945
--- /dev/null
+++ b/PusaV1/diffsynth/models/sdxl_controlnet.py
@@ -0,0 +1,318 @@
+import torch
+from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
+from .sdxl_unet import SDXLUNet
+from .tiler import TileWorker
+from .sd_controlnet import ControlNetConditioningLayer
+from collections import OrderedDict
+
+
+
+class QuickGELU(torch.nn.Module):
+
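+    # GELU approximation used by CLIP: x * sigmoid(1.702 * x).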
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+
+class ResidualAttentionBlock(torch.nn.Module):
+
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = torch.nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = torch.nn.LayerNorm(d_model)
+ self.mlp = torch.nn.Sequential(OrderedDict([
+ ("c_fc", torch.nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", torch.nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = torch.nn.LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+
+class SDXLControlNetUnion(torch.nn.Module):
+ def __init__(self, global_pool=False):
+ super().__init__()
+ self.time_proj = Timesteps(320)
+ self.time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(320, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.add_time_proj = Timesteps(256)
+ self.add_time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(2816, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.control_type_proj = Timesteps(256)
+ self.control_type_embedding = torch.nn.Sequential(
+ torch.nn.Linear(256 * 8, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
+
+ self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
+ self.controlnet_transformer = ResidualAttentionBlock(320, 8)
+ self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
+ self.spatial_ch_projs = torch.nn.Linear(320, 320)
+
+ self.blocks = torch.nn.ModuleList([
+ # DownBlock2D
+ ResnetBlock(320, 320, 1280),
+ PushBlock(),
+ ResnetBlock(320, 320, 1280),
+ PushBlock(),
+ DownSampler(320),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(320, 640, 1280),
+ AttentionBlock(10, 64, 640, 2, 2048),
+ PushBlock(),
+ ResnetBlock(640, 640, 1280),
+ AttentionBlock(10, 64, 640, 2, 2048),
+ PushBlock(),
+ DownSampler(640),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(640, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ PushBlock(),
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ PushBlock(),
+ # UNetMidBlock2DCrossAttn
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ ResnetBlock(1280, 1280, 1280),
+ PushBlock()
+ ])
+
+ self.controlnet_blocks = torch.nn.ModuleList([
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
+ ])
+
+ self.global_pool = global_pool
+
+ # 0 -- openpose
+ # 1 -- depth
+ # 2 -- hed/pidi/scribble/ted
+ # 3 -- canny/lineart/anime_lineart/mlsd
+ # 4 -- normal
+ # 5 -- segment
+ # 6 -- tile
+ # 7 -- repaint
+ self.task_id = {
+ "openpose": 0,
+ "depth": 1,
+ "softedge": 2,
+ "canny": 3,
+ "lineart": 3,
+ "lineart_anime": 3,
+ "tile": 6,
+ "inpaint": 7
+ }
+
+
+ def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
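+        # Embed the conditioning image, add the task embedding for the selected control type, and let a small
+        # attention block decide how strongly the condition is blended into the UNet input features.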
+ controlnet_cond = self.controlnet_conv_in(conditioning)
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
+ feat_seq = feat_seq + self.task_embedding[task_id]
+ x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
+ x = self.controlnet_transformer(x)
+
+ alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
+ controlnet_cond_fuser = controlnet_cond + alpha
+
+ hidden_states = hidden_states + controlnet_cond_fuser
+ return hidden_states
+
+
+ def forward(
+ self,
+ sample, timestep, encoder_hidden_states,
+ conditioning, processor_id, add_time_id, add_text_embeds,
+ tiled=False, tile_size=64, tile_stride=32,
+        unet: SDXLUNet = None,
+ **kwargs
+ ):
+ task_id = self.task_id[processor_id]
+
+ # 1. time
+ t_emb = self.time_proj(timestep).to(sample.dtype)
+ t_emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(add_time_id)
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
+ add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(sample.dtype)
+ if unet is not None and unet.is_kolors:
+ add_embeds = unet.add_time_embedding(add_embeds)
+ else:
+ add_embeds = self.add_time_embedding(add_embeds)
+
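+        # One-hot encode the active control type, embed it like a timestep, and add it to the time embedding.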
+ control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
+ control_type[:, task_id] = 1
+ control_embeds = self.control_type_proj(control_type.flatten())
+ control_embeds = control_embeds.reshape((sample.shape[0], -1))
+ control_embeds = control_embeds.to(sample.dtype)
+ control_embeds = self.control_type_embedding(control_embeds)
+ time_emb = t_emb + add_embeds + control_embeds
+
+ # 2. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = self.conv_in(sample)
+ hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
+ text_emb = encoder_hidden_states
+ if unet is not None and unet.is_kolors:
+ text_emb = unet.text_intermediate_proj(text_emb)
+ res_stack = [hidden_states]
+
+ # 3. blocks
+ for i, block in enumerate(self.blocks):
+ if tiled and not isinstance(block, PushBlock):
+ _, _, inter_height, _ = hidden_states.shape
+ resize_scale = inter_height / height
+ hidden_states = TileWorker().tiled_forward(
+ lambda x: block(x, time_emb, text_emb, res_stack)[0],
+ hidden_states,
+ int(tile_size * resize_scale),
+ int(tile_stride * resize_scale),
+ tile_device=hidden_states.device,
+ tile_dtype=hidden_states.dtype
+ )
+ else:
+ hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 4. ControlNet blocks
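+        # Project each stored residual through its 1x1 conv; the outputs are returned as additive control residuals for the UNet.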
+ controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
+
+ # pool
+ if self.global_pool:
+ controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
+
+ return controlnet_res_stack
+
+ @staticmethod
+ def state_dict_converter():
+ return SDXLControlNetUnionStateDictConverter()
+
+
+
+class SDXLControlNetUnionStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ # architecture
+ block_types = [
+ "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
+ "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
+ "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
+ "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
+ ]
+
+ # controlnet_rename_dict
+ controlnet_rename_dict = {
+ "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
+ "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
+ "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
+ "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
+ "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
+ "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
+ "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
+ "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
+ "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
+ "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
+ "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
+ "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
+ "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
+ "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
+ "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
+ "control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
+ "control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
+ "control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
+ "control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
+ }
+
+ # Rename each parameter
+ name_list = sorted([name for name in state_dict])
+ rename_dict = {}
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
+ for name in name_list:
+ names = name.split(".")
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
+ pass
+ elif name in controlnet_rename_dict:
+ names = controlnet_rename_dict[name].split(".")
+ elif names[0] == "controlnet_down_blocks":
+ names[0] = "controlnet_blocks"
+ elif names[0] == "controlnet_mid_block":
+ names = ["controlnet_blocks", "9", names[-1]]
+ elif names[0] in ["time_embedding", "add_embedding"]:
+ if names[0] == "add_embedding":
+ names[0] = "add_time_embedding"
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
+ elif names[0] == "control_add_embedding":
+ names[0] = "control_type_embedding"
+ elif names[0] == "transformer_layes":
+ names[0] = "controlnet_transformer"
+ names.pop(1)
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
+ if names[0] == "mid_block":
+ names.insert(1, "0")
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
+ block_type_with_id = ".".join(names[:4])
+ if block_type_with_id != last_block_type_with_id[block_type]:
+ block_id[block_type] += 1
+ last_block_type_with_id[block_type] = block_type_with_id
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+ block_id[block_type] += 1
+ block_type_with_id = ".".join(names[:4])
+ names = ["blocks", str(block_id[block_type])] + names[4:]
+ if "ff" in names:
+ ff_index = names.index("ff")
+ component = ".".join(names[ff_index:ff_index+3])
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
+ names = names[:ff_index] + [component] + names[ff_index+3:]
+ if "to_out" in names:
+ names.pop(names.index("to_out") + 1)
+ else:
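+                # Keys that match no known pattern are reported and skipped (the strict check below is disabled).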
+ print(name, state_dict[name].shape)
+ # raise ValueError(f"Unknown parameters: {name}")
+ rename_dict[name] = ".".join(names)
+
+ # Convert state_dict
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name not in rename_dict:
+ continue
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
\ No newline at end of file
diff --git a/PusaV1/diffsynth/models/sdxl_ipadapter.py b/PusaV1/diffsynth/models/sdxl_ipadapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d959d3b9249e690c6240670fc191e11e27deaaa2
--- /dev/null
+++ b/PusaV1/diffsynth/models/sdxl_ipadapter.py
@@ -0,0 +1,122 @@
+from .svd_image_encoder import SVDImageEncoder
+from transformers import CLIPImageProcessor
+import torch
+
+
+class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
+ def __init__(self):
+ super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
+ self.image_processor = CLIPImageProcessor()
+
+ def forward(self, image):
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
+ return super().forward(pixel_values)
+
+
+class IpAdapterImageProjModel(torch.nn.Module):
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
+ super().__init__()
+ self.cross_attention_dim = cross_attention_dim
+ self.clip_extra_context_tokens = clip_extra_context_tokens
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds):
+ clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+ return clip_extra_context_tokens
+
+
+class IpAdapterModule(torch.nn.Module):
+ def __init__(self, input_dim, output_dim):
+ super().__init__()
+ self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+ self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+
+ def forward(self, hidden_states):
+ ip_k = self.to_k_ip(hidden_states)
+ ip_v = self.to_v_ip(hidden_states)
+ return ip_k, ip_v
+
+
+class SDXLIpAdapter(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
+ self.image_proj = IpAdapterImageProjModel()
+ self.set_full_adapter()
+
+ def set_full_adapter(self):
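+        # Map every (UNet block id, transformer layer id) pair to a consecutive IP-Adapter module index.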
+ map_list = sum([
+ [(7, i) for i in range(2)],
+ [(10, i) for i in range(2)],
+ [(15, i) for i in range(10)],
+ [(18, i) for i in range(10)],
+ [(25, i) for i in range(10)],
+ [(28, i) for i in range(10)],
+ [(31, i) for i in range(10)],
+ [(35, i) for i in range(2)],
+ [(38, i) for i in range(2)],
+ [(41, i) for i in range(2)],
+ [(21, i) for i in range(10)],
+ ], [])
+ self.call_block_id = {i: j for j, i in enumerate(map_list)}
+
+ def set_less_adapter(self):
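+        # Same mapping as set_full_adapter, but only modules 34-43 remain active, reducing the adapter's influence.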
+ map_list = sum([
+ [(7, i) for i in range(2)],
+ [(10, i) for i in range(2)],
+ [(15, i) for i in range(10)],
+ [(18, i) for i in range(10)],
+ [(25, i) for i in range(10)],
+ [(28, i) for i in range(10)],
+ [(31, i) for i in range(10)],
+ [(35, i) for i in range(2)],
+ [(38, i) for i in range(2)],
+ [(41, i) for i in range(2)],
+ [(21, i) for i in range(10)],
+ ], [])
+        self.call_block_id = {i: j for j, i in enumerate(map_list) if 34 <= j < 44}
+
+ def forward(self, hidden_states, scale=1.0):
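+        # Project the image embedding to extra context tokens, then build per-layer key/value tensors,
+        # keyed by (block id, transformer id), for injection into the UNet's cross-attention.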
+ hidden_states = self.image_proj(hidden_states)
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
+ ip_kv_dict = {}
+ for (block_id, transformer_id) in self.call_block_id:
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
+ if block_id not in ip_kv_dict:
+ ip_kv_dict[block_id] = {}
+ ip_kv_dict[block_id][transformer_id] = {
+ "ip_k": ip_k,
+ "ip_v": ip_v,
+ "scale": scale
+ }
+ return ip_kv_dict
+
+ @staticmethod
+ def state_dict_converter():
+ return SDXLIpAdapterStateDictConverter()
+
+
+class SDXLIpAdapterStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ state_dict_ = {}
+ for name in state_dict["ip_adapter"]:
+ names = name.split(".")
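+            # The diffusers checkpoint indexes every other attention processor, so dividing by 2 yields consecutive module ids.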
+ layer_id = str(int(names[0]) // 2)
+ name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
+ state_dict_[name_] = state_dict["ip_adapter"][name]
+ for name in state_dict["image_proj"]:
+ name_ = "image_proj." + name
+ state_dict_[name_] = state_dict["image_proj"][name]
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
+
diff --git a/PusaV1/diffsynth/models/sdxl_motion.py b/PusaV1/diffsynth/models/sdxl_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..268c3e96f006e697eed5ac03fad3f5c995cfa319
--- /dev/null
+++ b/PusaV1/diffsynth/models/sdxl_motion.py
@@ -0,0 +1,104 @@
+from .sd_motion import TemporalBlock
+import torch
+
+
+
+class SDXLMotionModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.motion_modules = torch.nn.ModuleList([
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
+
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
+
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
+
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
+ ])
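+        # Maps a UNet block index to the temporal (motion) module invoked at that position.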
+ self.call_block_id = {
+ 0: 0,
+ 2: 1,
+ 7: 2,
+ 10: 3,
+ 15: 4,
+ 18: 5,
+ 25: 6,
+ 28: 7,
+ 31: 8,
+ 35: 9,
+ 38: 10,
+ 41: 11,
+ 44: 12,
+ 46: 13,
+ 48: 14,
+ }
+
+ def forward(self):
+ pass
+
+ @staticmethod
+ def state_dict_converter():
+ return SDMotionModelStateDictConverter()
+
+
+class SDMotionModelStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "norm": "norm",
+ "proj_in": "proj_in",
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
+ "proj_out": "proj_out",
+ }
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
+ state_dict_ = {}
+ last_prefix, module_id = "", -1
+ for name in name_list:
+ names = name.split(".")
+ prefix_index = names.index("temporal_transformer") + 1
+ prefix = ".".join(names[:prefix_index])
+ if prefix != last_prefix:
+ last_prefix = prefix
+ module_id += 1
+ middle_name = ".".join(names[prefix_index:-1])
+ suffix = names[-1]
+ if "pos_encoder" in names:
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
+ else:
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
+ state_dict_[rename] = state_dict[name]
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ return self.from_diffusers(state_dict)
diff --git a/PusaV1/diffsynth/models/sdxl_text_encoder.py b/PusaV1/diffsynth/models/sdxl_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d30c7d4056bf37abbb341b2807aa47a67785023
--- /dev/null
+++ b/PusaV1/diffsynth/models/sdxl_text_encoder.py
@@ -0,0 +1,759 @@
+import torch
+from .sd_text_encoder import CLIPEncoderLayer
+
+
+class SDXLTextEncoder(torch.nn.Module):
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
+ super().__init__()
+
+ # token_embedding
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
+
+ # position_embeds (This is a fixed tensor)
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
+
+ # encoders
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
+
+ # attn_mask
+ self.attn_mask = self.attention_mask(max_position_embeddings)
+
+    # The text encoder is different from that in Stable Diffusion 1.x.
+ # It does not include final_layer_norm.
+
+ def attention_mask(self, length):
+ mask = torch.empty(length, length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1)
+ return mask
+
+ def forward(self, input_ids, clip_skip=1):
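+        # clip_skip counts from the end: clip_skip=1 runs every encoder layer, clip_skip=2 stops one layer early, and so on.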
+ embeds = self.token_embedding(input_ids) + self.position_embeds
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+ for encoder_id, encoder in enumerate(self.encoders):
+ embeds = encoder(embeds, attn_mask=attn_mask)
+ if encoder_id + clip_skip == len(self.encoders):
+ break
+ return embeds
+
+ @staticmethod
+ def state_dict_converter():
+ return SDXLTextEncoderStateDictConverter()
+
+
+class SDXLTextEncoder2(torch.nn.Module):
+ def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
+ super().__init__()
+
+ # token_embedding
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
+
+ # position_embeds (This is a fixed tensor)
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
+
+ # encoders
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])
+
+ # attn_mask
+ self.attn_mask = self.attention_mask(max_position_embeddings)
+
+ # final_layer_norm
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
+
+ # text_projection
+ self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
+
+ def attention_mask(self, length):
+ mask = torch.empty(length, length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1)
+ return mask
+
+ def forward(self, input_ids, clip_skip=2):
+ embeds = self.token_embedding(input_ids) + self.position_embeds
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+ for encoder_id, encoder in enumerate(self.encoders):
+ embeds = encoder(embeds, attn_mask=attn_mask)
+ if encoder_id + clip_skip == len(self.encoders):
+ hidden_states = embeds
+ embeds = self.final_layer_norm(embeds)
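+        # The pooled embedding is read at the end-of-text token, located via argmax since it has the largest id in the CLIP vocabulary.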
+ pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
+ pooled_embeds = self.text_projection(pooled_embeds)
+ return pooled_embeds, hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SDXLTextEncoder2StateDictConverter()
+
+
+class SDXLTextEncoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
+ }
+ attn_rename_dict = {
+ "self_attn.q_proj": "attn.to_q",
+ "self_attn.k_proj": "attn.to_k",
+ "self_attn.v_proj": "attn.to_v",
+ "self_attn.out_proj": "attn.to_out",
+ "layer_norm1": "layer_norm1",
+ "layer_norm2": "layer_norm2",
+ "mlp.fc1": "fc1",
+ "mlp.fc2": "fc2",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ elif name.startswith("text_model.encoder.layers."):
+ param = state_dict[name]
+ names = name.split(".")
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+ state_dict_[name_] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
+ "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
+ }
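+ # Copy only the parameters listed above; the position embedding table is reshaped to (1, seq_len, dim) before it is stored.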
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
+
+
+class SDXLTextEncoder2StateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias",
+ "text_projection.weight": "text_projection.weight"
+ }
+ attn_rename_dict = {
+ "self_attn.q_proj": "attn.to_q",
+ "self_attn.k_proj": "attn.to_k",
+ "self_attn.v_proj": "attn.to_v",
+ "self_attn.out_proj": "attn.to_out",
+ "layer_norm1": "layer_norm1",
+ "layer_norm2": "layer_norm2",
+ "mlp.fc1": "fc1",
+ "mlp.fc2": "fc2",
+ }
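+ # Top-level keys are renamed via rename_dict; per-layer keys under text_model.encoder.layers.<i> are rebuilt as encoders.<i>.* using attn_rename_dict.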
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "text_model.embeddings.position_embedding.weight":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ state_dict_[rename_dict[name]] = param
+ elif name.startswith("text_model.encoder.layers."):
+ param = state_dict[name]
+ names = name.split(".")
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+ state_dict_[name_] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
+ "conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
+ "conditioner.embedders.1.model.positional_embedding": "position_embeds",
+ "conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
+ "conditioner.embedders.1.model.text_projection": "text_projection.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "conditioner.embedders.1.model.positional_embedding":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ elif name == "conditioner.embedders.1.model.text_projection":
+ param = param.T
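+ # Fused in_proj tensors map to a list of (q, k, v) target names and are split into three equal chunks along dim 0 below.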
+ if isinstance(rename_dict[name], str):
+ state_dict_[rename_dict[name]] = param
+ else:
+ length = param.shape[0] // 3
+ for i, rename in enumerate(rename_dict[name]):
+ state_dict_[rename] = param[i*length: i*length+length]
+ return state_dict_
\ No newline at end of file
diff --git a/PusaV1/diffsynth/models/sdxl_unet.py b/PusaV1/diffsynth/models/sdxl_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc63e63181c999f94421b72843c7b9e03b31d4a
--- /dev/null
+++ b/PusaV1/diffsynth/models/sdxl_unet.py
@@ -0,0 +1,1901 @@
+import torch
+from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock, DownSampler, UpSampler
+
+
+class SDXLUNet(torch.nn.Module):
+ def __init__(self, is_kolors=False):
+ super().__init__()
+ self.time_proj = Timesteps(320)
+ self.time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(320, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.add_time_proj = Timesteps(256)
+ self.add_time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(5632 if is_kolors else 2816, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
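+ # For Kolors checkpoints the 4096-dim text embeddings are projected down to the 2048-dim context used by the attention blocks.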
+ self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
+
+ self.blocks = torch.nn.ModuleList([
+ # DownBlock2D
+ ResnetBlock(320, 320, 1280),
+ PushBlock(),
+ ResnetBlock(320, 320, 1280),
+ PushBlock(),
+ DownSampler(320),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(320, 640, 1280),
+ AttentionBlock(10, 64, 640, 2, 2048),
+ PushBlock(),
+ ResnetBlock(640, 640, 1280),
+ AttentionBlock(10, 64, 640, 2, 2048),
+ PushBlock(),
+ DownSampler(640),
+ PushBlock(),
+ # CrossAttnDownBlock2D
+ ResnetBlock(640, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ PushBlock(),
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ PushBlock(),
+ # UNetMidBlock2DCrossAttn
+ ResnetBlock(1280, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ ResnetBlock(1280, 1280, 1280),
+ # CrossAttnUpBlock2D
+ PopBlock(),
+ ResnetBlock(2560, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ PopBlock(),
+ ResnetBlock(2560, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ PopBlock(),
+ ResnetBlock(1920, 1280, 1280),
+ AttentionBlock(20, 64, 1280, 10, 2048),
+ UpSampler(1280),
+ # CrossAttnUpBlock2D
+ PopBlock(),
+ ResnetBlock(1920, 640, 1280),
+ AttentionBlock(10, 64, 640, 2, 2048),
+ PopBlock(),
+ ResnetBlock(1280, 640, 1280),
+ AttentionBlock(10, 64, 640, 2, 2048),
+ PopBlock(),
+ ResnetBlock(960, 640, 1280),
+ AttentionBlock(10, 64, 640, 2, 2048),
+ UpSampler(640),
+ # UpBlock2D
+ PopBlock(),
+ ResnetBlock(960, 320, 1280),
+ PopBlock(),
+ ResnetBlock(640, 320, 1280),
+ PopBlock(),
+ ResnetBlock(640, 320, 1280)
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
+
+ self.is_kolors = is_kolors
+
+ def forward(
+ self,
+ sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
+ tiled=False, tile_size=64, tile_stride=8,
+ use_gradient_checkpointing=False,
+ **kwargs
+ ):
+ # 1. time
+ t_emb = self.time_proj(timestep).to(sample.dtype)
+ t_emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(add_time_id)
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
+ add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(sample.dtype)
+ add_embeds = self.add_time_embedding(add_embeds)
+
+ time_emb = t_emb + add_embeds
+
+ # 2. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = self.conv_in(sample)
+ text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
+ res_stack = [hidden_states]
+
+ # 3. blocks
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
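+ # PushBlock/PopBlock only move tensors on/off res_stack, so they run outside
+ # gradient checkpointing; recomputation is reserved for the heavy blocks.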
+ for i, block in enumerate(self.blocks):
+ if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
+ hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states, time_emb, text_emb, res_stack,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states, time_emb, text_emb, res_stack = block(
+ hidden_states, time_emb, text_emb, res_stack,
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+ )
+
+ # 4. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SDXLUNetStateDictConverter()
+
+
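+# Hedged usage sketch (illustrative, not part of the original code): a
+# Diffusers-format checkpoint could be remapped and loaded roughly like
+#
+#     converted = SDXLUNet.state_dict_converter().from_diffusers(raw_state_dict)
+#     if isinstance(converted, tuple):  # Kolors checkpoints also return kwargs
+#         converted, extra_kwargs = converted
+#         unet = SDXLUNet(**extra_kwargs)
+#     else:
+#         unet = SDXLUNet()
+#     unet.load_state_dict(converted)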
+class SDXLUNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ # architecture
+ block_types = [
+ 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock',
+ 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
+ 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock'
+ ]
+
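+ # Walk the sorted diffusers parameter names and advance a per-type cursor through
+ # block_types so each resnet/attention/sampler lands on its index in self.blocks.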
+ # Rename each parameter
+ name_list = sorted([name for name in state_dict])
+ rename_dict = {}
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
+ for name in name_list:
+ names = name.split(".")
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
+ pass
+ elif names[0] in ["encoder_hid_proj"]:
+ names[0] = "text_intermediate_proj"
+ elif names[0] in ["time_embedding", "add_embedding"]:
+ if names[0] == "add_embedding":
+ names[0] = "add_time_embedding"
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
+ if names[0] == "mid_block":
+ names.insert(1, "0")
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
+ block_type_with_id = ".".join(names[:4])
+ if block_type_with_id != last_block_type_with_id[block_type]:
+ block_id[block_type] += 1
+ last_block_type_with_id[block_type] = block_type_with_id
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
+ block_id[block_type] += 1
+ block_type_with_id = ".".join(names[:4])
+ names = ["blocks", str(block_id[block_type])] + names[4:]
+ if "ff" in names:
+ ff_index = names.index("ff")
+ component = ".".join(names[ff_index:ff_index+3])
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
+ names = names[:ff_index] + [component] + names[ff_index+3:]
+ if "to_out" in names:
+ names.pop(names.index("to_out") + 1)
+ else:
+ raise ValueError(f"Unknown parameter: {name}")
+ rename_dict[name] = ".".join(names)
+
+ # Convert state_dict
+ state_dict_ = {}
+ for name, param in state_dict.items():
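+ # some checkpoints store proj_in/proj_out as 1x1 convs; drop the trailing
+ # singleton dims so the weights fit the linear proj_in/proj_out layers used here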
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ if "text_intermediate_proj.weight" in state_dict_:
+ return state_dict_, {"is_kolors": True}
+ else:
+ return state_dict_
+
+ def from_civitai(self, state_dict):
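+ # Explicit one-to-one name map from the original Stability ("civitai") checkpoint
+ # layout (model.diffusion_model.*) onto the flat blocks.N layout defined above.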
+ rename_dict = {
+ "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "blocks.2.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "blocks.2.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "blocks.2.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "blocks.2.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "blocks.2.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "blocks.2.conv2.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "blocks.4.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "blocks.4.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "blocks.6.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "blocks.6.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "blocks.6.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "blocks.6.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "blocks.6.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "blocks.6.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "blocks.6.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "blocks.6.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "blocks.6.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "blocks.6.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "blocks.6.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "blocks.6.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "blocks.7.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "blocks.7.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "blocks.7.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "blocks.7.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "blocks.7.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "blocks.7.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.7.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.7.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.7.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.7.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.7.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.7.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.7.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.7.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.7.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.7.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.7.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.7.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.7.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.7.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.7.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.7.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.7.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.7.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.7.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.7.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "blocks.7.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.7.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.7.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "blocks.7.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "blocks.7.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "blocks.7.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.7.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.7.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "blocks.7.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "blocks.7.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.7.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.7.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "blocks.7.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "blocks.7.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "blocks.7.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "blocks.7.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "blocks.7.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "blocks.7.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "blocks.7.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "blocks.7.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "blocks.9.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "blocks.9.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "blocks.9.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "blocks.9.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "blocks.9.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "blocks.9.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "blocks.9.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "blocks.9.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "blocks.9.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "blocks.9.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "blocks.10.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "blocks.10.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "blocks.10.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "blocks.10.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "blocks.10.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "blocks.10.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.10.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.10.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.10.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.10.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.10.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.10.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.10.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.10.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.10.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.10.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.10.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.10.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.10.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.10.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.10.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.10.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.10.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.10.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.10.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.10.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "blocks.10.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.10.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.10.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "blocks.10.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "blocks.10.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "blocks.10.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.10.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.10.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "blocks.10.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "blocks.10.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.10.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.10.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "blocks.10.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "blocks.10.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "blocks.10.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "blocks.10.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "blocks.10.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "blocks.10.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "blocks.10.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "blocks.10.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "blocks.12.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "blocks.12.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "blocks.14.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "blocks.14.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "blocks.14.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "blocks.14.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "blocks.14.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "blocks.14.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "blocks.14.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "blocks.14.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "blocks.14.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "blocks.14.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "blocks.14.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "blocks.14.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "blocks.15.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "blocks.15.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "blocks.15.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "blocks.15.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "blocks.15.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "blocks.15.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.15.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.15.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.15.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.15.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.15.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.15.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.15.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.15.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.15.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.15.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.15.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.15.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.15.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.15.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.15.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.15.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.15.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.15.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.15.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.15.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "blocks.15.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.15.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.15.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "blocks.15.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "blocks.15.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "blocks.15.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.15.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.15.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "blocks.15.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "blocks.15.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.15.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.15.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "blocks.15.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "blocks.15.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "blocks.15.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "blocks.15.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "blocks.15.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "blocks.15.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "blocks.15.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "blocks.15.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "blocks.15.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.15.transformer_blocks.2.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.15.transformer_blocks.2.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "blocks.15.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "blocks.15.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "blocks.15.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.15.transformer_blocks.2.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.15.transformer_blocks.2.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "blocks.15.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "blocks.15.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.15.transformer_blocks.2.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.15.transformer_blocks.2.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "blocks.15.transformer_blocks.2.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "blocks.15.transformer_blocks.2.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "blocks.15.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "blocks.15.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "blocks.15.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "blocks.15.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "blocks.15.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "blocks.15.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "blocks.15.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.15.transformer_blocks.3.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.15.transformer_blocks.3.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "blocks.15.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "blocks.15.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "blocks.15.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.15.transformer_blocks.3.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.15.transformer_blocks.3.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "blocks.15.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "blocks.15.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.15.transformer_blocks.3.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.15.transformer_blocks.3.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "blocks.15.transformer_blocks.3.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "blocks.15.transformer_blocks.3.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "blocks.15.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "blocks.15.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "blocks.15.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "blocks.15.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "blocks.15.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "blocks.15.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight": "blocks.15.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.15.transformer_blocks.4.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.15.transformer_blocks.4.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_q.weight": "blocks.15.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_v.weight": "blocks.15.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_k.weight": "blocks.15.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.15.transformer_blocks.4.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.15.transformer_blocks.4.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_q.weight": "blocks.15.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_v.weight": "blocks.15.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.15.transformer_blocks.4.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.15.transformer_blocks.4.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.bias": "blocks.15.transformer_blocks.4.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.weight": "blocks.15.transformer_blocks.4.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.bias": "blocks.15.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.weight": "blocks.15.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.bias": "blocks.15.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.weight": "blocks.15.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.bias": "blocks.15.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.weight": "blocks.15.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_k.weight": "blocks.15.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.15.transformer_blocks.5.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.15.transformer_blocks.5.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight": "blocks.15.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_v.weight": "blocks.15.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_k.weight": "blocks.15.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.15.transformer_blocks.5.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.15.transformer_blocks.5.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_q.weight": "blocks.15.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_v.weight": "blocks.15.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.15.transformer_blocks.5.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.15.transformer_blocks.5.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.bias": "blocks.15.transformer_blocks.5.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.weight": "blocks.15.transformer_blocks.5.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.bias": "blocks.15.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.weight": "blocks.15.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.bias": "blocks.15.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.weight": "blocks.15.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.bias": "blocks.15.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.weight": "blocks.15.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_k.weight": "blocks.15.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.15.transformer_blocks.6.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.15.transformer_blocks.6.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_q.weight": "blocks.15.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_v.weight": "blocks.15.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_k.weight": "blocks.15.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.15.transformer_blocks.6.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.15.transformer_blocks.6.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_q.weight": "blocks.15.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_v.weight": "blocks.15.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.15.transformer_blocks.6.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.15.transformer_blocks.6.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.bias": "blocks.15.transformer_blocks.6.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.weight": "blocks.15.transformer_blocks.6.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.bias": "blocks.15.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.weight": "blocks.15.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.bias": "blocks.15.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.weight": "blocks.15.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.bias": "blocks.15.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.weight": "blocks.15.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_k.weight": "blocks.15.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.15.transformer_blocks.7.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.15.transformer_blocks.7.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_q.weight": "blocks.15.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_v.weight": "blocks.15.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_k.weight": "blocks.15.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.15.transformer_blocks.7.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.15.transformer_blocks.7.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_q.weight": "blocks.15.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_v.weight": "blocks.15.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.15.transformer_blocks.7.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.15.transformer_blocks.7.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.bias": "blocks.15.transformer_blocks.7.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.weight": "blocks.15.transformer_blocks.7.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.bias": "blocks.15.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.weight": "blocks.15.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.bias": "blocks.15.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.weight": "blocks.15.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.bias": "blocks.15.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.weight": "blocks.15.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_k.weight": "blocks.15.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.15.transformer_blocks.8.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.15.transformer_blocks.8.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_q.weight": "blocks.15.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_v.weight": "blocks.15.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_k.weight": "blocks.15.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.15.transformer_blocks.8.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.15.transformer_blocks.8.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_q.weight": "blocks.15.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_v.weight": "blocks.15.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.15.transformer_blocks.8.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.15.transformer_blocks.8.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.bias": "blocks.15.transformer_blocks.8.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.weight": "blocks.15.transformer_blocks.8.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.bias": "blocks.15.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.weight": "blocks.15.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.bias": "blocks.15.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.weight": "blocks.15.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.bias": "blocks.15.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.weight": "blocks.15.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_k.weight": "blocks.15.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.15.transformer_blocks.9.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.15.transformer_blocks.9.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_q.weight": "blocks.15.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_v.weight": "blocks.15.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_k.weight": "blocks.15.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.15.transformer_blocks.9.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.15.transformer_blocks.9.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_q.weight": "blocks.15.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_v.weight": "blocks.15.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.15.transformer_blocks.9.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.15.transformer_blocks.9.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.bias": "blocks.15.transformer_blocks.9.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.weight": "blocks.15.transformer_blocks.9.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.bias": "blocks.15.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.weight": "blocks.15.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.bias": "blocks.15.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.weight": "blocks.15.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.bias": "blocks.15.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.weight": "blocks.15.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "blocks.17.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "blocks.17.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "blocks.17.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "blocks.17.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "blocks.17.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "blocks.17.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "blocks.17.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "blocks.17.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "blocks.17.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "blocks.17.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "blocks.18.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "blocks.18.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "blocks.18.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "blocks.18.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "blocks.18.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "blocks.18.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.18.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.18.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.18.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.18.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.18.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.18.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.18.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.18.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.18.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.18.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.18.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.18.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.18.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.18.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.18.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.18.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.18.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.18.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.18.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.18.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "blocks.18.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.18.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.18.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "blocks.18.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "blocks.18.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "blocks.18.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.18.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.18.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "blocks.18.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "blocks.18.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.18.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.18.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "blocks.18.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "blocks.18.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "blocks.18.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "blocks.18.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "blocks.18.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "blocks.18.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "blocks.18.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "blocks.18.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "blocks.18.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.18.transformer_blocks.2.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.18.transformer_blocks.2.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "blocks.18.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "blocks.18.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "blocks.18.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.18.transformer_blocks.2.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.18.transformer_blocks.2.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "blocks.18.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "blocks.18.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.18.transformer_blocks.2.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.18.transformer_blocks.2.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "blocks.18.transformer_blocks.2.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "blocks.18.transformer_blocks.2.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "blocks.18.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "blocks.18.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "blocks.18.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "blocks.18.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "blocks.18.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "blocks.18.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "blocks.18.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.18.transformer_blocks.3.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.18.transformer_blocks.3.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "blocks.18.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "blocks.18.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "blocks.18.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.18.transformer_blocks.3.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.18.transformer_blocks.3.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "blocks.18.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "blocks.18.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.18.transformer_blocks.3.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.18.transformer_blocks.3.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "blocks.18.transformer_blocks.3.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "blocks.18.transformer_blocks.3.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "blocks.18.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "blocks.18.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "blocks.18.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "blocks.18.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "blocks.18.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "blocks.18.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_k.weight": "blocks.18.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.18.transformer_blocks.4.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.18.transformer_blocks.4.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_q.weight": "blocks.18.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_v.weight": "blocks.18.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_k.weight": "blocks.18.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.18.transformer_blocks.4.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.18.transformer_blocks.4.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_q.weight": "blocks.18.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_v.weight": "blocks.18.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.18.transformer_blocks.4.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.18.transformer_blocks.4.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.bias": "blocks.18.transformer_blocks.4.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.weight": "blocks.18.transformer_blocks.4.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.bias": "blocks.18.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.weight": "blocks.18.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.bias": "blocks.18.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.weight": "blocks.18.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.bias": "blocks.18.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.weight": "blocks.18.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_k.weight": "blocks.18.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.18.transformer_blocks.5.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.18.transformer_blocks.5.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_q.weight": "blocks.18.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_v.weight": "blocks.18.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_k.weight": "blocks.18.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.18.transformer_blocks.5.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.18.transformer_blocks.5.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_q.weight": "blocks.18.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_v.weight": "blocks.18.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.18.transformer_blocks.5.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.18.transformer_blocks.5.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.bias": "blocks.18.transformer_blocks.5.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.weight": "blocks.18.transformer_blocks.5.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.bias": "blocks.18.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.weight": "blocks.18.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.bias": "blocks.18.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.weight": "blocks.18.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.bias": "blocks.18.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.weight": "blocks.18.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_k.weight": "blocks.18.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.18.transformer_blocks.6.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.18.transformer_blocks.6.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_q.weight": "blocks.18.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_v.weight": "blocks.18.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_k.weight": "blocks.18.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.18.transformer_blocks.6.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.18.transformer_blocks.6.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_q.weight": "blocks.18.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_v.weight": "blocks.18.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.18.transformer_blocks.6.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.18.transformer_blocks.6.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.bias": "blocks.18.transformer_blocks.6.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.weight": "blocks.18.transformer_blocks.6.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.bias": "blocks.18.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.weight": "blocks.18.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.bias": "blocks.18.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.weight": "blocks.18.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.bias": "blocks.18.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.weight": "blocks.18.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_k.weight": "blocks.18.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.18.transformer_blocks.7.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.18.transformer_blocks.7.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_q.weight": "blocks.18.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_v.weight": "blocks.18.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_k.weight": "blocks.18.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.18.transformer_blocks.7.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.18.transformer_blocks.7.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_q.weight": "blocks.18.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_v.weight": "blocks.18.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.18.transformer_blocks.7.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.18.transformer_blocks.7.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.bias": "blocks.18.transformer_blocks.7.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.weight": "blocks.18.transformer_blocks.7.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.bias": "blocks.18.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.weight": "blocks.18.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.bias": "blocks.18.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.weight": "blocks.18.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.bias": "blocks.18.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.weight": "blocks.18.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_k.weight": "blocks.18.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.18.transformer_blocks.8.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.18.transformer_blocks.8.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_q.weight": "blocks.18.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_v.weight": "blocks.18.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_k.weight": "blocks.18.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.18.transformer_blocks.8.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.18.transformer_blocks.8.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_q.weight": "blocks.18.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_v.weight": "blocks.18.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.18.transformer_blocks.8.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.18.transformer_blocks.8.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.bias": "blocks.18.transformer_blocks.8.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.weight": "blocks.18.transformer_blocks.8.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.bias": "blocks.18.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.weight": "blocks.18.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.bias": "blocks.18.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.weight": "blocks.18.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.bias": "blocks.18.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.weight": "blocks.18.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_k.weight": "blocks.18.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.18.transformer_blocks.9.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.18.transformer_blocks.9.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_q.weight": "blocks.18.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_v.weight": "blocks.18.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_k.weight": "blocks.18.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.18.transformer_blocks.9.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.18.transformer_blocks.9.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_q.weight": "blocks.18.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_v.weight": "blocks.18.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.18.transformer_blocks.9.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.18.transformer_blocks.9.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.bias": "blocks.18.transformer_blocks.9.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.weight": "blocks.18.transformer_blocks.9.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.bias": "blocks.18.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.weight": "blocks.18.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.bias": "blocks.18.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.weight": "blocks.18.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.bias": "blocks.18.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.weight": "blocks.18.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.label_emb.0.0.bias": "add_time_embedding.0.bias",
+ "model.diffusion_model.label_emb.0.0.weight": "add_time_embedding.0.weight",
+ "model.diffusion_model.label_emb.0.2.bias": "add_time_embedding.2.bias",
+ "model.diffusion_model.label_emb.0.2.weight": "add_time_embedding.2.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "blocks.20.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "blocks.20.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "blocks.20.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "blocks.20.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "blocks.20.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "blocks.20.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "blocks.20.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "blocks.20.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "blocks.20.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "blocks.20.conv2.weight",
+ "model.diffusion_model.middle_block.1.norm.bias": "blocks.21.norm.bias",
+ "model.diffusion_model.middle_block.1.norm.weight": "blocks.21.norm.weight",
+ "model.diffusion_model.middle_block.1.proj_in.bias": "blocks.21.proj_in.bias",
+ "model.diffusion_model.middle_block.1.proj_in.weight": "blocks.21.proj_in.weight",
+ "model.diffusion_model.middle_block.1.proj_out.bias": "blocks.21.proj_out.bias",
+ "model.diffusion_model.middle_block.1.proj_out.weight": "blocks.21.proj_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.21.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.21.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.21.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.21.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.21.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.21.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.21.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.21.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.21.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.21.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.21.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.21.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.21.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.21.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.21.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.21.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.21.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.21.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.21.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.21.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "blocks.21.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.21.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.21.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "blocks.21.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "blocks.21.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "blocks.21.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.21.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.21.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "blocks.21.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "blocks.21.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.21.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.21.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "blocks.21.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "blocks.21.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "blocks.21.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "blocks.21.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "blocks.21.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "blocks.21.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "blocks.21.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "blocks.21.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": "blocks.21.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.21.transformer_blocks.2.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.21.transformer_blocks.2.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "blocks.21.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "blocks.21.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "blocks.21.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.21.transformer_blocks.2.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.21.transformer_blocks.2.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "blocks.21.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "blocks.21.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.21.transformer_blocks.2.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.21.transformer_blocks.2.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "blocks.21.transformer_blocks.2.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": "blocks.21.transformer_blocks.2.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": "blocks.21.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "blocks.21.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "blocks.21.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "blocks.21.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "blocks.21.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "blocks.21.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "blocks.21.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.21.transformer_blocks.3.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.21.transformer_blocks.3.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "blocks.21.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "blocks.21.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "blocks.21.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.21.transformer_blocks.3.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.21.transformer_blocks.3.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "blocks.21.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "blocks.21.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.21.transformer_blocks.3.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.21.transformer_blocks.3.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "blocks.21.transformer_blocks.3.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "blocks.21.transformer_blocks.3.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "blocks.21.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "blocks.21.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "blocks.21.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "blocks.21.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "blocks.21.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "blocks.21.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_k.weight": "blocks.21.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.21.transformer_blocks.4.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.21.transformer_blocks.4.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_q.weight": "blocks.21.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_v.weight": "blocks.21.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_k.weight": "blocks.21.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.21.transformer_blocks.4.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.21.transformer_blocks.4.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_q.weight": "blocks.21.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_v.weight": "blocks.21.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.21.transformer_blocks.4.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.21.transformer_blocks.4.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.bias": "blocks.21.transformer_blocks.4.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.weight": "blocks.21.transformer_blocks.4.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.bias": "blocks.21.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.weight": "blocks.21.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.bias": "blocks.21.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.weight": "blocks.21.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.bias": "blocks.21.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.weight": "blocks.21.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_k.weight": "blocks.21.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.21.transformer_blocks.5.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.21.transformer_blocks.5.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_q.weight": "blocks.21.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_v.weight": "blocks.21.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_k.weight": "blocks.21.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.21.transformer_blocks.5.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.21.transformer_blocks.5.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_q.weight": "blocks.21.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_v.weight": "blocks.21.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.21.transformer_blocks.5.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.21.transformer_blocks.5.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.bias": "blocks.21.transformer_blocks.5.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.weight": "blocks.21.transformer_blocks.5.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.bias": "blocks.21.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.weight": "blocks.21.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.bias": "blocks.21.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.weight": "blocks.21.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.bias": "blocks.21.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.weight": "blocks.21.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_k.weight": "blocks.21.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.21.transformer_blocks.6.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.21.transformer_blocks.6.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_q.weight": "blocks.21.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_v.weight": "blocks.21.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_k.weight": "blocks.21.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.21.transformer_blocks.6.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.21.transformer_blocks.6.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_q.weight": "blocks.21.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_v.weight": "blocks.21.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.21.transformer_blocks.6.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.21.transformer_blocks.6.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.bias": "blocks.21.transformer_blocks.6.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.weight": "blocks.21.transformer_blocks.6.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.bias": "blocks.21.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.weight": "blocks.21.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.bias": "blocks.21.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.weight": "blocks.21.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.bias": "blocks.21.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.weight": "blocks.21.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_k.weight": "blocks.21.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.21.transformer_blocks.7.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.21.transformer_blocks.7.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_q.weight": "blocks.21.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_v.weight": "blocks.21.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_k.weight": "blocks.21.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.21.transformer_blocks.7.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.21.transformer_blocks.7.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_q.weight": "blocks.21.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_v.weight": "blocks.21.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.21.transformer_blocks.7.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.21.transformer_blocks.7.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.bias": "blocks.21.transformer_blocks.7.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.weight": "blocks.21.transformer_blocks.7.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.bias": "blocks.21.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.weight": "blocks.21.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.bias": "blocks.21.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.weight": "blocks.21.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.bias": "blocks.21.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.weight": "blocks.21.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_k.weight": "blocks.21.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.21.transformer_blocks.8.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.21.transformer_blocks.8.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_q.weight": "blocks.21.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_v.weight": "blocks.21.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_k.weight": "blocks.21.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.21.transformer_blocks.8.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.21.transformer_blocks.8.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_q.weight": "blocks.21.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_v.weight": "blocks.21.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.21.transformer_blocks.8.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.21.transformer_blocks.8.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.bias": "blocks.21.transformer_blocks.8.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.weight": "blocks.21.transformer_blocks.8.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.bias": "blocks.21.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.weight": "blocks.21.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.bias": "blocks.21.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.weight": "blocks.21.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.bias": "blocks.21.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.weight": "blocks.21.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_k.weight": "blocks.21.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.21.transformer_blocks.9.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.21.transformer_blocks.9.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_q.weight": "blocks.21.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_v.weight": "blocks.21.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_k.weight": "blocks.21.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.21.transformer_blocks.9.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.21.transformer_blocks.9.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_q.weight": "blocks.21.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_v.weight": "blocks.21.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.21.transformer_blocks.9.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.21.transformer_blocks.9.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.bias": "blocks.21.transformer_blocks.9.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.weight": "blocks.21.transformer_blocks.9.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.bias": "blocks.21.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.weight": "blocks.21.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.bias": "blocks.21.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.weight": "blocks.21.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.bias": "blocks.21.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.weight": "blocks.21.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.middle_block.2.emb_layers.1.bias": "blocks.22.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "blocks.22.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.in_layers.0.bias": "blocks.22.norm1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.0.weight": "blocks.22.norm1.weight",
+ "model.diffusion_model.middle_block.2.in_layers.2.bias": "blocks.22.conv1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.2.weight": "blocks.22.conv1.weight",
+ "model.diffusion_model.middle_block.2.out_layers.0.bias": "blocks.22.norm2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.0.weight": "blocks.22.norm2.weight",
+ "model.diffusion_model.middle_block.2.out_layers.3.bias": "blocks.22.conv2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.3.weight": "blocks.22.conv2.weight",
+ "model.diffusion_model.out.0.bias": "conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "conv_out.bias",
+ "model.diffusion_model.out.2.weight": "conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "blocks.24.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "blocks.24.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "blocks.24.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "blocks.24.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "blocks.24.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "blocks.24.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "blocks.24.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "blocks.24.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "blocks.24.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "blocks.24.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.0.1.norm.bias": "blocks.25.norm.bias",
+ "model.diffusion_model.output_blocks.0.1.norm.weight": "blocks.25.norm.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_in.bias": "blocks.25.proj_in.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_in.weight": "blocks.25.proj_in.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_out.bias": "blocks.25.proj_out.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_out.weight": "blocks.25.proj_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "blocks.25.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.25.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.25.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "blocks.25.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "blocks.25.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "blocks.25.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.25.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.25.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "blocks.25.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "blocks.25.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.25.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.25.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "blocks.25.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "blocks.25.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "blocks.25.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "blocks.25.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "blocks.25.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "blocks.25.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "blocks.25.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "blocks.25.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "blocks.25.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.25.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.25.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "blocks.25.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "blocks.25.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "blocks.25.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.25.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.25.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "blocks.25.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "blocks.25.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.25.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.25.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "blocks.25.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "blocks.25.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "blocks.25.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "blocks.25.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "blocks.25.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "blocks.25.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "blocks.25.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "blocks.25.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "blocks.25.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.25.transformer_blocks.2.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.25.transformer_blocks.2.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "blocks.25.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "blocks.25.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "blocks.25.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.25.transformer_blocks.2.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.25.transformer_blocks.2.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "blocks.25.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "blocks.25.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.25.transformer_blocks.2.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.25.transformer_blocks.2.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": "blocks.25.transformer_blocks.2.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": "blocks.25.transformer_blocks.2.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "blocks.25.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "blocks.25.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "blocks.25.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": "blocks.25.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "blocks.25.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "blocks.25.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "blocks.25.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.25.transformer_blocks.3.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.25.transformer_blocks.3.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "blocks.25.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "blocks.25.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": "blocks.25.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.25.transformer_blocks.3.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.25.transformer_blocks.3.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "blocks.25.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": "blocks.25.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.25.transformer_blocks.3.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.25.transformer_blocks.3.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "blocks.25.transformer_blocks.3.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "blocks.25.transformer_blocks.3.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "blocks.25.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "blocks.25.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": "blocks.25.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "blocks.25.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": "blocks.25.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "blocks.25.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_k.weight": "blocks.25.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.25.transformer_blocks.4.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.25.transformer_blocks.4.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_q.weight": "blocks.25.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_v.weight": "blocks.25.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_k.weight": "blocks.25.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.25.transformer_blocks.4.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.25.transformer_blocks.4.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_q.weight": "blocks.25.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_v.weight": "blocks.25.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.25.transformer_blocks.4.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.25.transformer_blocks.4.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.bias": "blocks.25.transformer_blocks.4.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.weight": "blocks.25.transformer_blocks.4.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.bias": "blocks.25.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.weight": "blocks.25.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.bias": "blocks.25.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.weight": "blocks.25.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.bias": "blocks.25.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.weight": "blocks.25.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_k.weight": "blocks.25.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.25.transformer_blocks.5.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.25.transformer_blocks.5.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_q.weight": "blocks.25.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_v.weight": "blocks.25.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_k.weight": "blocks.25.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.25.transformer_blocks.5.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.25.transformer_blocks.5.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_q.weight": "blocks.25.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_v.weight": "blocks.25.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.25.transformer_blocks.5.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.25.transformer_blocks.5.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.bias": "blocks.25.transformer_blocks.5.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.weight": "blocks.25.transformer_blocks.5.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.bias": "blocks.25.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.weight": "blocks.25.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.bias": "blocks.25.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.weight": "blocks.25.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.bias": "blocks.25.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.weight": "blocks.25.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_k.weight": "blocks.25.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.25.transformer_blocks.6.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.25.transformer_blocks.6.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_q.weight": "blocks.25.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_v.weight": "blocks.25.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_k.weight": "blocks.25.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.25.transformer_blocks.6.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.25.transformer_blocks.6.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_q.weight": "blocks.25.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_v.weight": "blocks.25.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.25.transformer_blocks.6.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.25.transformer_blocks.6.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.bias": "blocks.25.transformer_blocks.6.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.weight": "blocks.25.transformer_blocks.6.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.bias": "blocks.25.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.weight": "blocks.25.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.bias": "blocks.25.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.weight": "blocks.25.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.bias": "blocks.25.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.weight": "blocks.25.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_k.weight": "blocks.25.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.25.transformer_blocks.7.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.25.transformer_blocks.7.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_q.weight": "blocks.25.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_v.weight": "blocks.25.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_k.weight": "blocks.25.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.25.transformer_blocks.7.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.25.transformer_blocks.7.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_q.weight": "blocks.25.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_v.weight": "blocks.25.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.25.transformer_blocks.7.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.25.transformer_blocks.7.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.bias": "blocks.25.transformer_blocks.7.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.weight": "blocks.25.transformer_blocks.7.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.bias": "blocks.25.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.weight": "blocks.25.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.bias": "blocks.25.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.weight": "blocks.25.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.bias": "blocks.25.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.weight": "blocks.25.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_k.weight": "blocks.25.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.25.transformer_blocks.8.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.25.transformer_blocks.8.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_q.weight": "blocks.25.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_v.weight": "blocks.25.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_k.weight": "blocks.25.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.25.transformer_blocks.8.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.25.transformer_blocks.8.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_q.weight": "blocks.25.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_v.weight": "blocks.25.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.25.transformer_blocks.8.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.25.transformer_blocks.8.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.bias": "blocks.25.transformer_blocks.8.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.weight": "blocks.25.transformer_blocks.8.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.bias": "blocks.25.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.weight": "blocks.25.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.bias": "blocks.25.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.weight": "blocks.25.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.bias": "blocks.25.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.weight": "blocks.25.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_k.weight": "blocks.25.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.25.transformer_blocks.9.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.25.transformer_blocks.9.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_q.weight": "blocks.25.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_v.weight": "blocks.25.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_k.weight": "blocks.25.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.25.transformer_blocks.9.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.25.transformer_blocks.9.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_q.weight": "blocks.25.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_v.weight": "blocks.25.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.25.transformer_blocks.9.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.25.transformer_blocks.9.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.bias": "blocks.25.transformer_blocks.9.ff.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.weight": "blocks.25.transformer_blocks.9.ff.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.bias": "blocks.25.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.weight": "blocks.25.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.bias": "blocks.25.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.weight": "blocks.25.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.bias": "blocks.25.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.weight": "blocks.25.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "blocks.27.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "blocks.27.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "blocks.27.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "blocks.27.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "blocks.27.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "blocks.27.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "blocks.27.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "blocks.27.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "blocks.27.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "blocks.27.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "blocks.27.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "blocks.27.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.1.norm.bias": "blocks.28.norm.bias",
+ "model.diffusion_model.output_blocks.1.1.norm.weight": "blocks.28.norm.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_in.bias": "blocks.28.proj_in.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_in.weight": "blocks.28.proj_in.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_out.bias": "blocks.28.proj_out.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_out.weight": "blocks.28.proj_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.28.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.28.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.28.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.28.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.28.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.28.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.28.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.28.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.28.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.28.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.28.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.28.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.28.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.28.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.28.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.28.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.28.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.28.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.28.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.28.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "blocks.28.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.28.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.28.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "blocks.28.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "blocks.28.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "blocks.28.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.28.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.28.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "blocks.28.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "blocks.28.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.28.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.28.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "blocks.28.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "blocks.28.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "blocks.28.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "blocks.28.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "blocks.28.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "blocks.28.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "blocks.28.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "blocks.28.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "blocks.28.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.28.transformer_blocks.2.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.28.transformer_blocks.2.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": "blocks.28.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "blocks.28.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "blocks.28.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.28.transformer_blocks.2.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.28.transformer_blocks.2.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": "blocks.28.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "blocks.28.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.28.transformer_blocks.2.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.28.transformer_blocks.2.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "blocks.28.transformer_blocks.2.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "blocks.28.transformer_blocks.2.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "blocks.28.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "blocks.28.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "blocks.28.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "blocks.28.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": "blocks.28.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "blocks.28.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "blocks.28.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.28.transformer_blocks.3.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.28.transformer_blocks.3.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "blocks.28.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "blocks.28.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "blocks.28.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.28.transformer_blocks.3.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.28.transformer_blocks.3.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "blocks.28.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "blocks.28.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.28.transformer_blocks.3.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.28.transformer_blocks.3.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "blocks.28.transformer_blocks.3.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "blocks.28.transformer_blocks.3.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "blocks.28.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": "blocks.28.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "blocks.28.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "blocks.28.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "blocks.28.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "blocks.28.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_k.weight": "blocks.28.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.28.transformer_blocks.4.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.28.transformer_blocks.4.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_q.weight": "blocks.28.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_v.weight": "blocks.28.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_k.weight": "blocks.28.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.28.transformer_blocks.4.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.28.transformer_blocks.4.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_q.weight": "blocks.28.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_v.weight": "blocks.28.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.28.transformer_blocks.4.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.28.transformer_blocks.4.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.bias": "blocks.28.transformer_blocks.4.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.weight": "blocks.28.transformer_blocks.4.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.bias": "blocks.28.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.weight": "blocks.28.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.bias": "blocks.28.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.weight": "blocks.28.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.bias": "blocks.28.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.weight": "blocks.28.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_k.weight": "blocks.28.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.28.transformer_blocks.5.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.28.transformer_blocks.5.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_q.weight": "blocks.28.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_v.weight": "blocks.28.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_k.weight": "blocks.28.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.28.transformer_blocks.5.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.28.transformer_blocks.5.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_q.weight": "blocks.28.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_v.weight": "blocks.28.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.28.transformer_blocks.5.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.28.transformer_blocks.5.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.bias": "blocks.28.transformer_blocks.5.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.weight": "blocks.28.transformer_blocks.5.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.bias": "blocks.28.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.weight": "blocks.28.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.bias": "blocks.28.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.weight": "blocks.28.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.bias": "blocks.28.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.weight": "blocks.28.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_k.weight": "blocks.28.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.28.transformer_blocks.6.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.28.transformer_blocks.6.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_q.weight": "blocks.28.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_v.weight": "blocks.28.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_k.weight": "blocks.28.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.28.transformer_blocks.6.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.28.transformer_blocks.6.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_q.weight": "blocks.28.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_v.weight": "blocks.28.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.28.transformer_blocks.6.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.28.transformer_blocks.6.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.bias": "blocks.28.transformer_blocks.6.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.weight": "blocks.28.transformer_blocks.6.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.bias": "blocks.28.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.weight": "blocks.28.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.bias": "blocks.28.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.weight": "blocks.28.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.bias": "blocks.28.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.weight": "blocks.28.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_k.weight": "blocks.28.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.28.transformer_blocks.7.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.28.transformer_blocks.7.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_q.weight": "blocks.28.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_v.weight": "blocks.28.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_k.weight": "blocks.28.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.28.transformer_blocks.7.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.28.transformer_blocks.7.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_q.weight": "blocks.28.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_v.weight": "blocks.28.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.28.transformer_blocks.7.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.28.transformer_blocks.7.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.bias": "blocks.28.transformer_blocks.7.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.weight": "blocks.28.transformer_blocks.7.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.bias": "blocks.28.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.weight": "blocks.28.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.bias": "blocks.28.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.weight": "blocks.28.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.bias": "blocks.28.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.weight": "blocks.28.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_k.weight": "blocks.28.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.28.transformer_blocks.8.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.28.transformer_blocks.8.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_q.weight": "blocks.28.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_v.weight": "blocks.28.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_k.weight": "blocks.28.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.28.transformer_blocks.8.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.28.transformer_blocks.8.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_q.weight": "blocks.28.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_v.weight": "blocks.28.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.28.transformer_blocks.8.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.28.transformer_blocks.8.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.bias": "blocks.28.transformer_blocks.8.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.weight": "blocks.28.transformer_blocks.8.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.bias": "blocks.28.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.weight": "blocks.28.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.bias": "blocks.28.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.weight": "blocks.28.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.bias": "blocks.28.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.weight": "blocks.28.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_k.weight": "blocks.28.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.28.transformer_blocks.9.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.28.transformer_blocks.9.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_q.weight": "blocks.28.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_v.weight": "blocks.28.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_k.weight": "blocks.28.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.28.transformer_blocks.9.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.28.transformer_blocks.9.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_q.weight": "blocks.28.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_v.weight": "blocks.28.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.28.transformer_blocks.9.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.28.transformer_blocks.9.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.bias": "blocks.28.transformer_blocks.9.ff.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.weight": "blocks.28.transformer_blocks.9.ff.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.bias": "blocks.28.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.weight": "blocks.28.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.bias": "blocks.28.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.weight": "blocks.28.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.bias": "blocks.28.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.weight": "blocks.28.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "blocks.30.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "blocks.30.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "blocks.30.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "blocks.30.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "blocks.30.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "blocks.30.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "blocks.30.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "blocks.30.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "blocks.30.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "blocks.30.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.norm.bias": "blocks.31.norm.bias",
+ "model.diffusion_model.output_blocks.2.1.norm.weight": "blocks.31.norm.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_in.bias": "blocks.31.proj_in.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_in.weight": "blocks.31.proj_in.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_out.bias": "blocks.31.proj_out.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_out.weight": "blocks.31.proj_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.31.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.31.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.31.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.31.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.31.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.31.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.31.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.31.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.31.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.31.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.31.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.31.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.31.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.31.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.31.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.31.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.31.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.31.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.31.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.31.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "blocks.31.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.31.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.31.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "blocks.31.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "blocks.31.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "blocks.31.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.31.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.31.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "blocks.31.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "blocks.31.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.31.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.31.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "blocks.31.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "blocks.31.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "blocks.31.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "blocks.31.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "blocks.31.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "blocks.31.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "blocks.31.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "blocks.31.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "blocks.31.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": "blocks.31.transformer_blocks.2.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": "blocks.31.transformer_blocks.2.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "blocks.31.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "blocks.31.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "blocks.31.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "blocks.31.transformer_blocks.2.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "blocks.31.transformer_blocks.2.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "blocks.31.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "blocks.31.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "blocks.31.transformer_blocks.2.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "blocks.31.transformer_blocks.2.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "blocks.31.transformer_blocks.2.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "blocks.31.transformer_blocks.2.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "blocks.31.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "blocks.31.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "blocks.31.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "blocks.31.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "blocks.31.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": "blocks.31.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "blocks.31.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "blocks.31.transformer_blocks.3.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "blocks.31.transformer_blocks.3.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "blocks.31.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "blocks.31.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": "blocks.31.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "blocks.31.transformer_blocks.3.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "blocks.31.transformer_blocks.3.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "blocks.31.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "blocks.31.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "blocks.31.transformer_blocks.3.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "blocks.31.transformer_blocks.3.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": "blocks.31.transformer_blocks.3.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "blocks.31.transformer_blocks.3.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "blocks.31.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "blocks.31.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "blocks.31.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "blocks.31.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "blocks.31.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "blocks.31.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "blocks.31.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "blocks.31.transformer_blocks.4.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "blocks.31.transformer_blocks.4.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "blocks.31.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "blocks.31.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": "blocks.31.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "blocks.31.transformer_blocks.4.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "blocks.31.transformer_blocks.4.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": "blocks.31.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "blocks.31.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "blocks.31.transformer_blocks.4.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "blocks.31.transformer_blocks.4.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": "blocks.31.transformer_blocks.4.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "blocks.31.transformer_blocks.4.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "blocks.31.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "blocks.31.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "blocks.31.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "blocks.31.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": "blocks.31.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "blocks.31.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "blocks.31.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "blocks.31.transformer_blocks.5.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "blocks.31.transformer_blocks.5.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "blocks.31.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "blocks.31.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": "blocks.31.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "blocks.31.transformer_blocks.5.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "blocks.31.transformer_blocks.5.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "blocks.31.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "blocks.31.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "blocks.31.transformer_blocks.5.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": "blocks.31.transformer_blocks.5.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": "blocks.31.transformer_blocks.5.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "blocks.31.transformer_blocks.5.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "blocks.31.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "blocks.31.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "blocks.31.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "blocks.31.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": "blocks.31.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "blocks.31.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "blocks.31.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "blocks.31.transformer_blocks.6.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "blocks.31.transformer_blocks.6.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "blocks.31.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "blocks.31.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "blocks.31.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "blocks.31.transformer_blocks.6.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "blocks.31.transformer_blocks.6.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "blocks.31.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "blocks.31.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "blocks.31.transformer_blocks.6.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "blocks.31.transformer_blocks.6.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "blocks.31.transformer_blocks.6.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "blocks.31.transformer_blocks.6.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": "blocks.31.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "blocks.31.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "blocks.31.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "blocks.31.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "blocks.31.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": "blocks.31.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "blocks.31.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "blocks.31.transformer_blocks.7.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "blocks.31.transformer_blocks.7.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "blocks.31.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "blocks.31.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "blocks.31.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "blocks.31.transformer_blocks.7.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "blocks.31.transformer_blocks.7.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "blocks.31.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": "blocks.31.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "blocks.31.transformer_blocks.7.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "blocks.31.transformer_blocks.7.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "blocks.31.transformer_blocks.7.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "blocks.31.transformer_blocks.7.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "blocks.31.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "blocks.31.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "blocks.31.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": "blocks.31.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "blocks.31.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "blocks.31.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "blocks.31.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "blocks.31.transformer_blocks.8.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": "blocks.31.transformer_blocks.8.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "blocks.31.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "blocks.31.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "blocks.31.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "blocks.31.transformer_blocks.8.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "blocks.31.transformer_blocks.8.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "blocks.31.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": "blocks.31.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "blocks.31.transformer_blocks.8.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "blocks.31.transformer_blocks.8.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "blocks.31.transformer_blocks.8.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "blocks.31.transformer_blocks.8.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "blocks.31.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "blocks.31.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "blocks.31.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "blocks.31.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "blocks.31.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "blocks.31.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": "blocks.31.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "blocks.31.transformer_blocks.9.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": "blocks.31.transformer_blocks.9.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "blocks.31.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "blocks.31.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "blocks.31.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "blocks.31.transformer_blocks.9.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "blocks.31.transformer_blocks.9.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "blocks.31.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "blocks.31.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "blocks.31.transformer_blocks.9.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "blocks.31.transformer_blocks.9.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": "blocks.31.transformer_blocks.9.ff.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "blocks.31.transformer_blocks.9.ff.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "blocks.31.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "blocks.31.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "blocks.31.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "blocks.31.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "blocks.31.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "blocks.31.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.output_blocks.2.2.conv.bias": "blocks.32.conv.bias",
+ "model.diffusion_model.output_blocks.2.2.conv.weight": "blocks.32.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "blocks.34.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "blocks.34.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "blocks.34.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "blocks.34.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "blocks.34.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "blocks.34.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "blocks.34.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "blocks.34.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "blocks.34.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "blocks.34.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "blocks.34.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "blocks.34.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "blocks.35.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "blocks.35.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "blocks.35.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "blocks.35.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "blocks.35.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "blocks.35.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "blocks.35.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.35.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.35.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "blocks.35.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "blocks.35.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "blocks.35.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.35.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.35.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "blocks.35.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "blocks.35.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.35.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.35.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "blocks.35.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "blocks.35.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "blocks.35.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "blocks.35.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "blocks.35.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "blocks.35.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "blocks.35.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "blocks.35.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "blocks.35.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.35.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.35.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "blocks.35.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "blocks.35.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "blocks.35.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.35.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.35.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "blocks.35.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "blocks.35.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.35.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.35.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "blocks.35.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "blocks.35.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "blocks.35.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "blocks.35.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "blocks.35.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "blocks.35.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "blocks.35.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "blocks.35.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "blocks.37.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "blocks.37.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "blocks.37.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "blocks.37.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "blocks.37.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "blocks.37.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "blocks.37.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "blocks.37.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "blocks.37.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "blocks.37.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "blocks.37.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "blocks.37.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "blocks.38.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "blocks.38.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "blocks.38.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "blocks.38.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "blocks.38.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "blocks.38.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.38.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.38.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.38.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.38.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.38.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.38.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.38.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.38.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.38.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.38.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.38.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.38.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.38.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.38.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.38.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.38.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.38.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.38.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.38.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.38.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "blocks.38.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.38.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.38.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "blocks.38.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "blocks.38.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "blocks.38.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.38.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.38.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "blocks.38.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "blocks.38.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.38.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.38.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "blocks.38.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "blocks.38.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "blocks.38.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "blocks.38.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "blocks.38.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "blocks.38.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "blocks.38.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "blocks.38.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "blocks.40.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "blocks.40.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "blocks.40.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "blocks.40.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "blocks.40.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "blocks.40.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "blocks.40.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "blocks.40.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "blocks.40.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "blocks.40.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "blocks.40.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "blocks.40.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "blocks.41.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "blocks.41.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "blocks.41.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "blocks.41.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "blocks.41.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "blocks.41.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.41.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.41.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.41.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.41.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.41.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.41.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.41.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.41.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.41.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.41.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.41.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.41.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.41.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.41.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.41.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.41.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.41.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.41.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.41.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.41.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "blocks.41.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "blocks.41.transformer_blocks.1.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "blocks.41.transformer_blocks.1.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "blocks.41.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "blocks.41.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "blocks.41.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "blocks.41.transformer_blocks.1.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "blocks.41.transformer_blocks.1.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "blocks.41.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "blocks.41.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "blocks.41.transformer_blocks.1.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "blocks.41.transformer_blocks.1.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "blocks.41.transformer_blocks.1.ff.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "blocks.41.transformer_blocks.1.ff.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": "blocks.41.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "blocks.41.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "blocks.41.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "blocks.41.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "blocks.41.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "blocks.41.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "blocks.42.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "blocks.42.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "blocks.44.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "blocks.44.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "blocks.44.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "blocks.44.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "blocks.44.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "blocks.44.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "blocks.44.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "blocks.44.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "blocks.44.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "blocks.44.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "blocks.44.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "blocks.44.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "blocks.46.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "blocks.46.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "blocks.46.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "blocks.46.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "blocks.46.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "blocks.46.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "blocks.46.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "blocks.46.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "blocks.46.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "blocks.46.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "blocks.46.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "blocks.46.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "blocks.48.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "blocks.48.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "blocks.48.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "blocks.48.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "blocks.48.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "blocks.48.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "blocks.48.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "blocks.48.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "blocks.48.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "blocks.48.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "blocks.48.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "blocks.48.conv_shortcut.weight",
+ "model.diffusion_model.time_embed.0.bias": "time_embedding.0.bias",
+ "model.diffusion_model.time_embed.0.weight": "time_embedding.0.weight",
+ "model.diffusion_model.time_embed.2.bias": "time_embedding.2.bias",
+ "model.diffusion_model.time_embed.2.weight": "time_embedding.2.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
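+ # Note (assumption): some checkpoints store proj_in/proj_out as 1x1 conv kernels ([out, in, 1, 1]); squeeze() drops the trailing singleton dims so the weights fit Linear layers, and is a no-op when they are already 2-D.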
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
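+ # Checkpoints that ship an extra text_intermediate_proj layer are treated as Kolors UNets; the flag lets the caller special-case them when loading.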
+ if "text_intermediate_proj.weight" in state_dict_:
+ return state_dict_, {"is_kolors": True}
+ else:
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/sdxl_vae_decoder.py b/PusaV1/diffsynth/models/sdxl_vae_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..290c7851e3619f61a1baa2ddb3ad77180809297d
--- /dev/null
+++ b/PusaV1/diffsynth/models/sdxl_vae_decoder.py
@@ -0,0 +1,24 @@
+from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
+
+
+class SDXLVAEDecoder(SDVAEDecoder):
+ def __init__(self, upcast_to_float32=True):
+ super().__init__()
+ self.scaling_factor = 0.13025
+
+ @staticmethod
+ def state_dict_converter():
+ return SDXLVAEDecoderStateDictConverter()
+
+
+class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
+ def __init__(self):
+ super().__init__()
+
+ def from_diffusers(self, state_dict):
+ state_dict = super().from_diffusers(state_dict)
+ return state_dict, {"upcast_to_float32": True}
+
+ def from_civitai(self, state_dict):
+ state_dict = super().from_civitai(state_dict)
+ return state_dict, {"upcast_to_float32": True}
diff --git a/PusaV1/diffsynth/models/sdxl_vae_encoder.py b/PusaV1/diffsynth/models/sdxl_vae_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..14af09cd33c41452b9777daa6819115eb900b788
--- /dev/null
+++ b/PusaV1/diffsynth/models/sdxl_vae_encoder.py
@@ -0,0 +1,24 @@
+from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
+
+
+class SDXLVAEEncoder(SDVAEEncoder):
+ def __init__(self, upcast_to_float32=True):
+ super().__init__()
+ self.scaling_factor = 0.13025
+
+ @staticmethod
+ def state_dict_converter():
+ return SDXLVAEEncoderStateDictConverter()
+
+
+class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
+ def __init__(self):
+ super().__init__()
+
+ def from_diffusers(self, state_dict):
+ state_dict = super().from_diffusers(state_dict)
+ return state_dict, {"upcast_to_float32": True}
+
+ def from_civitai(self, state_dict):
+ state_dict = super().from_civitai(state_dict)
+ return state_dict, {"upcast_to_float32": True}
diff --git a/PusaV1/diffsynth/models/stepvideo_dit.py b/PusaV1/diffsynth/models/stepvideo_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..10576e77dbbcd9ac10e1e5b4ed8d52362e6ad82a
--- /dev/null
+++ b/PusaV1/diffsynth/models/stepvideo_dit.py
@@ -0,0 +1,940 @@
+# Copyright 2025 StepFun Inc. All Rights Reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+# ==============================================================================
+from typing import Dict, Optional, Tuple, Union, List
+import torch, math
+from torch import nn
+from einops import rearrange, repeat
+from tqdm import tqdm
+
+
+class RMSNorm(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ elementwise_affine=True,
+ eps: float = 1e-6,
+ device=None,
+ dtype=None,
+ ):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.eps = eps
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ if hasattr(self, "weight"):
+ output = output * self.weight
+ return output
+
+
+ACTIVATION_FUNCTIONS = {
+ "swish": nn.SiLU(),
+ "silu": nn.SiLU(),
+ "mish": nn.Mish(),
+ "gelu": nn.GELU(),
+ "relu": nn.ReLU(),
+}
+
+
+def get_activation(act_fn: str) -> nn.Module:
+ """Helper function to get activation function from string.
+
+ Args:
+ act_fn (str): Name of activation function.
+
+ Returns:
+ nn.Module: Activation function.
+ """
+
+ act_fn = act_fn.lower()
+ if act_fn in ACTIVATION_FUNCTIONS:
+ return ACTIVATION_FUNCTIONS[act_fn]
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
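+ # log-spaced frequencies: max_period ** (-i / (half_dim - downscale_freq_shift)) for i in range(half_dim)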
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ sample_proj_bias=True
+ ):
+ super().__init__()
+ linear_cls = nn.Linear
+
+ self.linear_1 = linear_cls(
+ in_channels,
+ time_embed_dim,
+ bias=sample_proj_bias,
+ )
+
+ if cond_proj_dim is not None:
+ self.cond_proj = linear_cls(
+ cond_proj_dim,
+ in_channels,
+ bias=False,
+ )
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+
+ self.linear_2 = linear_cls(
+ time_embed_dim,
+ time_embed_dim_out,
+ bias=sample_proj_bias,
+ )
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if self.use_additional_conditions:
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timestep, resolution=None, nframe=None, fps=None):
+ hidden_dtype = timestep.dtype
+
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ batch_size = timestep.shape[0]
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+ nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
+ nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
+ conditioning = timesteps_emb + resolution_emb + nframe_emb
+
+ if fps is not None:
+ fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
+ fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
+ conditioning = conditioning + fps_emb
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
+ super().__init__()
+
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
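+ # 6 * dim: shift/scale/gate for the attention branch plus shift/scale/gate for the feed-forward branch (adaLN-single, as in PixArt-Alpha)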
+
+ self.time_step_rescale = time_step_rescale # timesteps usually arrive in [0, 1]; rescale to [0, 1000] for numerical stability
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs)
+
+ out = self.linear(self.silu(embedded_timestep))
+
+ return out, embedded_timestep
+
+
+class PixArtAlphaTextProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size):
+ super().__init__()
+ self.linear_1 = nn.Linear(
+ in_features,
+ hidden_size,
+ bias=True,
+ )
+ self.act_1 = nn.GELU(approximate="tanh")
+ self.linear_2 = nn.Linear(
+ hidden_size,
+ hidden_size,
+ bias=True,
+ )
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class Attention(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def attn_processor(self, attn_type):
+ if attn_type == 'torch':
+ return self.torch_attn_func
+ elif attn_type == 'parallel':
+ return self.parallel_attn_func
+ else:
+ raise ValueError(f'Unsupported attention type: {attn_type}')
+
+ def torch_attn_func(
+ self,
+ q,
+ k,
+ v,
+ attn_mask=None,
+ causal=False,
+ drop_rate=0.0,
+ **kwargs
+ ):
+
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(q.dtype)
+
+ if attn_mask is not None and attn_mask.ndim == 3: # mask has no head dimension; repeat it across heads
+ n_heads = q.shape[2]
+ attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
+
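+ # scaled_dot_product_attention expects (batch, heads, seq, dim); inputs here are (batch, seq, heads, dim)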
+ q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
+ if attn_mask is not None:
+ attn_mask = attn_mask.to(q.device)
+ x = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
+ )
+ x = rearrange(x, 'b h s d -> b s h d')
+ return x
+
+
+class RoPE1D:
+ def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
+ self.base = freq
+ self.F0 = F0
+ self.scaling_factor = scaling_factor
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D, seq_len, device, dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return self.cache[D, seq_len, device, dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
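+ # standard rotary embedding: rotate each (x1, x2) channel pair by a position-dependent angle, i.e. tokens * cos + rotate_half(tokens) * sin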
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim == 2
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def __call__(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x ntokens x nheads x dim
+ * positions: batch_size x ntokens (t position of each token)
+ output:
+ * tokens after applying RoPE1D (batch_size x ntokens x nheads x dim)
+ """
+ D = tokens.size(3)
+ assert positions.ndim == 2 # Batch, Seq
+ cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
+ tokens = self.apply_rope1d(tokens, positions, cos, sin)
+ return tokens
+
+
+class RoPE3D(RoPE1D):
+ def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
+ super(RoPE3D, self).__init__(freq, F0, scaling_factor)
+ self.position_cache = {}
+
+ def get_mesh_3d(self, rope_positions, bsz):
+ f, h, w = rope_positions
+
+ if f"{f}-{h}-{w}" not in self.position_cache:
+ x = torch.arange(f, device='cpu')
+ y = torch.arange(h, device='cpu')
+ z = torch.arange(w, device='cpu')
+ self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
+ return self.position_cache[f"{f}-{h}-{w}"]
+
+ def __call__(self, tokens, rope_positions, ch_split, parallel=False):
+ """
+ input:
+ * tokens: batch_size x ntokens x nheads x dim
+ * rope_positions: list of (f, h, w)
+ output:
+            * tokens after applying RoPE3D (batch_size x ntokens x nheads x dim)
+ """
+        assert sum(ch_split) == tokens.size(-1), "ch_split must sum to the channel dimension of tokens"
+
+ mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
+ out = []
+ for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
+ cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
+
+ if parallel:
+ pass
+ else:
+ mesh = mesh_grid[:, :, i].clone()
+ x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
+ out.append(x)
+
+ tokens = torch.cat(out, dim=-1)
+ return tokens
+
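+# Illustrative usage of RoPE3D (a sketch): the sequence is a flattened (f, h, w) grid
+# and the head dimension is split per axis; [64, 32, 32] matches the split used by
+# SelfAttention below (time gets 64 channels, height and width 32 each).
+#
+#   rope = RoPE3D(freq=1e4)
+#   f, h, w = 4, 2, 2
+#   tokens = torch.randn(1, f * h * w, 8, 128)               # (batch, f*h*w, nheads, dim)
+#   tokens = rope(tokens, (f, h, w), ch_split=[64, 32, 32])  # same shape, per-axis RoPE applied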
+
+class SelfAttention(Attention):
+ def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
+ super().__init__()
+ self.head_dim = head_dim
+ self.n_heads = hidden_dim // head_dim
+
+ self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
+ self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
+
+ self.with_rope = with_rope
+ self.with_qk_norm = with_qk_norm
+ if self.with_qk_norm:
+ self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
+ self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
+
+ if self.with_rope:
+ self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
+ self.rope_ch_split = [64, 32, 32]
+
+ self.core_attention = self.attn_processor(attn_type=attn_type)
+ self.parallel = attn_type=='parallel'
+
+ def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
+ x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
+ return x
+
+ def forward(
+ self,
+ x,
+ cu_seqlens=None,
+ max_seqlen=None,
+ rope_positions=None,
+ attn_mask=None
+ ):
+ xqkv = self.wqkv(x)
+ xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
+
+ xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
+
+ if self.with_qk_norm:
+ xq = self.q_norm(xq)
+ xk = self.k_norm(xk)
+
+ if self.with_rope:
+ xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
+ xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
+
+ output = self.core_attention(
+ xq,
+ xk,
+ xv,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ attn_mask=attn_mask
+ )
+ output = rearrange(output, 'b s h d -> b s (h d)')
+ output = self.wo(output)
+
+ return output
+
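+# Illustrative usage of SelfAttention (a sketch): with RoPE enabled the head dimension
+# must be 128 to match the hard-coded rope_ch_split of [64, 32, 32], and the sequence
+# length must equal f * h * w of `rope_positions`.
+#
+#   attn = SelfAttention(hidden_dim=256, head_dim=128, attn_type='torch')
+#   f, h, w = 4, 2, 2
+#   x = torch.randn(1, f * h * w, 256)
+#   out = attn(x, rope_positions=[f, h, w])   # -> (1, 16, 256)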
+
+class CrossAttention(Attention):
+ def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
+ super().__init__()
+ self.head_dim = head_dim
+ self.n_heads = hidden_dim // head_dim
+
+ self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
+ self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
+ self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
+
+ self.with_qk_norm = with_qk_norm
+ if self.with_qk_norm:
+ self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
+ self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
+
+ self.core_attention = self.attn_processor(attn_type=attn_type)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attn_mask=None
+ ):
+ xq = self.wq(x)
+ xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
+
+ xkv = self.wkv(encoder_hidden_states)
+ xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
+
+ xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
+
+ if self.with_qk_norm:
+ xq = self.q_norm(xq)
+ xk = self.k_norm(xk)
+
+ output = self.core_attention(
+ xq,
+ xk,
+ xv,
+ attn_mask=attn_mask
+ )
+
+ output = rearrange(output, 'b s h d -> b s (h d)')
+ output = self.wo(output)
+
+ return output
+
+
+class GELU(nn.Module):
+ r"""
+    GELU activation function with optional tanh approximation, enabled via `approximate="tanh"`.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gelu(gate, approximate=self.approximate)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ inner_dim: Optional[int] = None,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ bias: bool = False,
+ ):
+ super().__init__()
+ inner_dim = dim*mult if inner_dim is None else inner_dim
+ dim_out = dim if dim_out is None else dim_out
+ self.net = nn.ModuleList([
+ GELU(dim, inner_dim, approximate="tanh", bias=bias),
+ nn.Identity(),
+ nn.Linear(inner_dim, dim_out, bias=bias)
+ ])
+
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+def modulate(x, scale, shift):
+ x = x * (1 + scale) + shift
+ return x
+
+
+def gate(x, gate):
+ x = gate * x
+ return x
+
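+# `modulate` and `gate` implement the AdaLN-single pattern used by the transformer block
+# below: the per-block scale_shift_table plus the projected timestep is chunked into
+# (shift, scale, gate) triples, the normalized input is transformed as x * (1 + scale) + shift
+# before each sub-layer, and the sub-layer output is multiplied by `gate` before the residual
+# add. A minimal numeric sketch:
+#
+#   x = torch.ones(1, 4, 8)
+#   modulate(x, scale=torch.full_like(x, 0.5), shift=torch.zeros_like(x))   # -> all 1.5
+#   gate(x, torch.zeros_like(x))   # -> all 0.0 (sub-layer contribution suppressed)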
+
+class StepVideoTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ attention_head_dim: int,
+ norm_eps: float = 1e-5,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = False,
+ attention_type: str = 'parallel'
+ ):
+ super().__init__()
+ self.dim = dim
+ self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
+ self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
+
+ self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
+ self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
+
+ self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
+
+ @torch.no_grad()
+ def forward(
+ self,
+ q: torch.Tensor,
+ kv: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ attn_mask = None,
+ rope_positions: list = None,
+ ) -> torch.Tensor:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ torch.clone(chunk) for chunk in (self.scale_shift_table[None].to(dtype=q.dtype, device=q.device) + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
+ )
+
+ scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
+
+ attn_q = self.attn1(
+ scale_shift_q,
+ rope_positions=rope_positions
+ )
+
+ q = gate(attn_q, gate_msa) + q
+
+ attn_q = self.attn2(
+ q,
+ kv,
+ attn_mask
+ )
+
+ q = attn_q + q
+
+ scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
+
+ ff_output = self.ff(scale_shift_q)
+
+ q = gate(ff_output, gate_mlp) + q
+
+ return q
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ patch_size=64,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ ):
+ super().__init__()
+
+ self.flatten = flatten
+        self.layer_norm = layer_norm
+        # `layer_norm` defaults to False; define the norm layer so forward() is valid when it is
+        # enabled (the elementwise_affine/eps choice follows the common PatchEmbed convention and
+        # is an assumption here).
+        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+
+ def forward(self, latent):
+ latent = self.proj(latent).to(latent.dtype)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+
+ return latent
+
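+# Illustrative shapes for PatchEmbed (a sketch): a convolution with kernel size equal to
+# its stride turns BCHW latents into a token sequence of length (H / patch) * (W / patch).
+#
+#   embed = PatchEmbed(patch_size=2, in_channels=64, embed_dim=512)
+#   latent = torch.randn(3, 64, 32, 32)   # (B, C, H, W)
+#   tokens = embed(latent)                # -> (3, 256, 512), i.e. (B, 16 * 16, embed_dim)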
+
+class StepVideoModel(torch.nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int = 48,
+ attention_head_dim: int = 128,
+ in_channels: int = 64,
+ out_channels: Optional[int] = 64,
+ num_layers: int = 48,
+ dropout: float = 0.0,
+ patch_size: int = 1,
+ norm_type: str = "ada_norm_single",
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ use_additional_conditions: Optional[bool] = False,
+ caption_channels: Optional[Union[int, List, Tuple]] = [6144, 1024],
+ attention_type: Optional[str] = "torch",
+ ):
+ super().__init__()
+
+ # Set some common variables used across the board.
+ self.inner_dim = num_attention_heads * attention_head_dim
+ self.out_channels = in_channels if out_channels is None else out_channels
+
+ self.use_additional_conditions = use_additional_conditions
+
+ self.pos_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ StepVideoTransformerBlock(
+ dim=self.inner_dim,
+ attention_head_dim=attention_head_dim,
+ attention_type=attention_type
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 3. Output blocks.
+ self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
+ self.patch_size = patch_size
+
+ self.adaln_single = AdaLayerNormSingle(
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
+ )
+
+ if isinstance(caption_channels, int):
+ caption_channel = caption_channels
+ else:
+ caption_channel, clip_channel = caption_channels
+ self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
+
+ self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+ self.caption_projection = PixArtAlphaTextProjection(
+ in_features=caption_channel, hidden_size=self.inner_dim
+ )
+
+ self.parallel = attention_type=='parallel'
+
+ def patchfy(self, hidden_states):
+ hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
+ hidden_states = self.pos_embed(hidden_states)
+ return hidden_states
+
+ def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
+ kv_seqlens = encoder_attention_mask.sum(dim=1).int()
+ mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
+ encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
+ for i, kv_len in enumerate(kv_seqlens):
+ mask[i, :, :kv_len] = 1
+ return encoder_hidden_states, mask
+
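+    # Illustrative behaviour of `prepare_attn_mask` (a sketch, assuming `model` is a constructed
+    # StepVideoModel): captions are trimmed to the longest valid length in the batch and a
+    # boolean cross-attention mask of shape (batch, q_seqlen, max_kv_len) is built.
+    # With per-sample valid lengths [3, 5]:
+    #
+    #   enc_mask = torch.zeros(2, 320, dtype=torch.long)
+    #   enc_mask[0, :3] = 1; enc_mask[1, :5] = 1
+    #   enc_states = torch.randn(2, 320, 6144)
+    #   enc_states, mask = model.prepare_attn_mask(enc_mask, enc_states, q_seqlen=1024)
+    #   # enc_states -> (2, 5, 6144); mask -> (2, 1024, 5), False beyond each caption's valid length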
+
+ def block_forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ rope_positions=None,
+ attn_mask=None,
+ parallel=True
+ ):
+ for block in tqdm(self.transformer_blocks, desc="Transformer blocks"):
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ timestep=timestep,
+ attn_mask=attn_mask,
+ rope_positions=rope_positions
+ )
+
+ return hidden_states
+
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ fps: torch.Tensor=None,
+ return_dict: bool = False,
+ ):
+        assert hidden_states.ndim == 5, "hidden_states should have shape (bsz, f, ch, h, w)"
+
+ bsz, frame, _, height, width = hidden_states.shape
+ height, width = height // self.patch_size, width // self.patch_size
+
+ hidden_states = self.patchfy(hidden_states)
+ len_frame = hidden_states.shape[1]
+
+ if self.use_additional_conditions:
+ added_cond_kwargs = {
+ "resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
+ "nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
+ "fps": fps
+ }
+ else:
+ added_cond_kwargs = {}
+
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs=added_cond_kwargs
+ )
+
+ encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
+
+ if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
+ clip_embedding = self.clip_projection(encoder_hidden_states_2)
+ encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
+
+ hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
+ encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
+
+ hidden_states = self.block_forward(
+ hidden_states,
+ encoder_hidden_states,
+ timestep=timestep,
+ rope_positions=[frame, height, width],
+ attn_mask=attn_mask,
+ parallel=self.parallel
+ )
+
+ hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
+
+ embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
+
+ shift, scale = (self.scale_shift_table[None].to(dtype=embedded_timestep.dtype, device=embedded_timestep.device) + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+
+ hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+
+ output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
+
+ if return_dict:
+ return {'x': output}
+ return output
+
+ @staticmethod
+ def state_dict_converter():
+ return StepVideoDiTStateDictConverter()
+
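+# End-to-end shape sketch for StepVideoModel (illustrative only; a deliberately tiny config,
+# and it assumes AdaLayerNormSingle defined above follows the PixArt convention of returning
+# a (batch, 6 * dim) modulation tensor plus a (batch, dim) embedding):
+#
+#   model = StepVideoModel(num_attention_heads=2, attention_head_dim=128, in_channels=64,
+#                          num_layers=1, attention_type="torch")
+#   latents = torch.randn(1, 4, 64, 16, 16)                  # (b, f, c, h, w)
+#   captions = torch.randn(1, 32, 6144)
+#   caption_mask = torch.ones(1, 32, dtype=torch.long)
+#   out = model(latents, encoder_hidden_states=captions, timestep=torch.tensor([500]),
+#               encoder_attention_mask=caption_mask)         # -> (1, 4, 64, 16, 16)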
+
+class StepVideoDiTStateDictConverter:
+ def __init__(self):
+ super().__init__()
+
+ def from_diffusers(self, state_dict):
+ return state_dict
+
+ def from_civitai(self, state_dict):
+ return state_dict
+
+
+
\ No newline at end of file
diff --git a/PusaV1/diffsynth/models/stepvideo_text_encoder.py b/PusaV1/diffsynth/models/stepvideo_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..598825a9402ea15183c9ff1488943f3bb4e5a548
--- /dev/null
+++ b/PusaV1/diffsynth/models/stepvideo_text_encoder.py
@@ -0,0 +1,553 @@
+# Copyright 2025 StepFun Inc. All Rights Reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+# ==============================================================================
+import os
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .stepvideo_dit import RMSNorm
+from safetensors.torch import load_file
+from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
+from einops import rearrange
+import json
+from typing import List
+from functools import wraps
+import warnings
+
+
+
+class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
+ def __init__(self, device=None):
+ self.device = device
+
+ def __torch_function__(self, func, types, args=(), kwargs=None):
+ kwargs = kwargs or {}
+ if getattr(func, '__module__', None) == 'torch.nn.init':
+ if 'tensor' in kwargs:
+ return kwargs['tensor']
+ else:
+ return args[0]
+ if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
+ kwargs['device'] = self.device
+ return func(*args, **kwargs)
+
+
+def with_empty_init(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ with EmptyInitOnDevice('cpu'):
+ return func(*args, **kwargs)
+ return wrapper
+
+
+
+class LLaMaEmbedding(nn.Module):
+ """Language model embeddings.
+
+ Arguments:
+ hidden_size: hidden size
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(self,
+ cfg,
+ ):
+ super().__init__()
+ self.hidden_size = cfg.hidden_size
+ self.params_dtype = cfg.params_dtype
+ self.fp32_residual_connection = cfg.fp32_residual_connection
+ self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
+ self.word_embeddings = torch.nn.Embedding(
+ cfg.padded_vocab_size, self.hidden_size,
+ )
+ self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
+
+ def forward(self, input_ids):
+ # Embeddings.
+ if self.embedding_weights_in_fp32:
+ self.word_embeddings = self.word_embeddings.to(torch.float32)
+ embeddings = self.word_embeddings(input_ids)
+ if self.embedding_weights_in_fp32:
+ embeddings = embeddings.to(self.params_dtype)
+ self.word_embeddings = self.word_embeddings.to(self.params_dtype)
+
+ # Data format change to avoid explicit transposes : [b s h] --> [s b h].
+ embeddings = embeddings.transpose(0, 1).contiguous()
+
+ # If the input flag for fp32 residual connection is set, convert for float.
+ if self.fp32_residual_connection:
+ embeddings = embeddings.float()
+
+ # Dropout.
+ embeddings = self.embedding_dropout(embeddings)
+
+ return embeddings
+
+
+
+class StepChatTokenizer:
+ """Step Chat Tokenizer"""
+
+ def __init__(
+ self, model_file, name="StepChatTokenizer",
+ bot_token="<|BOT|>", # Begin of Turn
+ eot_token="<|EOT|>", # End of Turn
+ call_start_token="<|CALL_START|>", # Call Start
+ call_end_token="<|CALL_END|>", # Call End
+ think_start_token="<|THINK_START|>", # Think Start
+ think_end_token="<|THINK_END|>", # Think End
+ mask_start_token="<|MASK_1e69f|>", # Mask start
+ mask_end_token="<|UNMASK_1e69f|>", # Mask end
+ ):
+ import sentencepiece
+
+ self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
+
+ self._vocab = {}
+ self._inv_vocab = {}
+
+ self._special_tokens = {}
+ self._inv_special_tokens = {}
+
+ self._t5_tokens = []
+
+ for idx in range(self._tokenizer.get_piece_size()):
+ text = self._tokenizer.id_to_piece(idx)
+ self._inv_vocab[idx] = text
+ self._vocab[text] = idx
+
+ if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
+ self._special_tokens[text] = idx
+ self._inv_special_tokens[idx] = text
+
+ self._unk_id = self._tokenizer.unk_id()
+ self._bos_id = self._tokenizer.bos_id()
+ self._eos_id = self._tokenizer.eos_id()
+
+ for token in [
+ bot_token, eot_token, call_start_token, call_end_token,
+ think_start_token, think_end_token
+ ]:
+ assert token in self._vocab, f"Token '{token}' not found in tokenizer"
+ assert token in self._special_tokens, f"Token '{token}' is not a special token"
+
+ for token in [mask_start_token, mask_end_token]:
+ assert token in self._vocab, f"Token '{token}' not found in tokenizer"
+
+ self._bot_id = self._tokenizer.piece_to_id(bot_token)
+ self._eot_id = self._tokenizer.piece_to_id(eot_token)
+ self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
+ self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
+ self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
+ self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
+ self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
+ self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
+
+ self._underline_id = self._tokenizer.piece_to_id("\u2581")
+
+ @property
+ def vocab(self):
+ return self._vocab
+
+ @property
+ def inv_vocab(self):
+ return self._inv_vocab
+
+ @property
+ def vocab_size(self):
+ return self._tokenizer.vocab_size()
+
+ def tokenize(self, text: str) -> List[int]:
+ return self._tokenizer.encode_as_ids(text)
+
+ def detokenize(self, token_ids: List[int]) -> str:
+ return self._tokenizer.decode_ids(token_ids)
+
+
+class Tokens:
+ def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
+ self.input_ids = input_ids
+ self.attention_mask = attention_mask
+ self.cu_input_ids = cu_input_ids
+ self.cu_seqlens = cu_seqlens
+ self.max_seq_len = max_seq_len
+ def to(self, device):
+ self.input_ids = self.input_ids.to(device)
+ self.attention_mask = self.attention_mask.to(device)
+ self.cu_input_ids = self.cu_input_ids.to(device)
+ self.cu_seqlens = self.cu_seqlens.to(device)
+ return self
+
+class Wrapped_StepChatTokenizer(StepChatTokenizer):
+ def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
+ # [bos, ..., eos, pad, pad, ..., pad]
+ self.BOS = 1
+ self.EOS = 2
+ self.PAD = 2
+ out_tokens = []
+ attn_mask = []
+ if len(text) == 0:
+ part_tokens = [self.BOS] + [self.EOS]
+ valid_size = len(part_tokens)
+ if len(part_tokens) < max_length:
+ part_tokens += [self.PAD] * (max_length - valid_size)
+ out_tokens.append(part_tokens)
+ attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
+ else:
+ for part in text:
+ part_tokens = self.tokenize(part)
+ part_tokens = part_tokens[:(max_length - 2)] # leave 2 space for bos and eos
+ part_tokens = [self.BOS] + part_tokens + [self.EOS]
+ valid_size = len(part_tokens)
+ if len(part_tokens) < max_length:
+ part_tokens += [self.PAD] * (max_length - valid_size)
+ out_tokens.append(part_tokens)
+ attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
+
+ out_tokens = torch.tensor(out_tokens, dtype=torch.long)
+ attn_mask = torch.tensor(attn_mask, dtype=torch.long)
+
+ # padding y based on tp size
+ padded_len = 0
+        padded_flag = padded_len > 0
+ if padded_flag:
+ pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
+ pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
+ out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
+ attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
+
+ # cu_seqlens
+ cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
+ seqlen = attn_mask.sum(dim=1).tolist()
+ cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
+ max_seq_len = max(seqlen)
+ return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
+
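+# Illustrative output of Wrapped_StepChatTokenizer (a sketch; the .model path is whatever
+# SentencePiece file ships with the checkpoint):
+#
+#   tok = Wrapped_StepChatTokenizer("step1_chat_tokenizer.model")
+#   tokens = tok(["a cat playing piano"], max_length=320)
+#   tokens.input_ids.shape        # (1, 320): [BOS, ..., EOS, PAD, PAD, ...]
+#   tokens.attention_mask         # 1 for valid positions, 0 for padding
+#   tokens.cu_seqlens             # cumulative valid lengths, e.g. tensor([0, n], dtype=int32)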
+
+
+def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
+ return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
+ softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
+ if hasattr(torch.ops.Optimus, "fwd"):
+ results = torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
+ else:
+ warnings.warn("Cannot load `torch.ops.Optimus.fwd`. Using `torch.nn.functional.scaled_dot_product_attention` instead.")
+ results = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=softmax_scale).transpose(1, 2)
+ return results
+
+
+class FlashSelfAttention(torch.nn.Module):
+ def __init__(
+ self,
+ attention_dropout=0.0,
+ ):
+ super().__init__()
+ self.dropout_p = attention_dropout
+
+
+ def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
+ if cu_seqlens is None:
+ output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
+ else:
+ raise ValueError('cu_seqlens is not supported!')
+
+ return output
+
+
+
+def safediv(n, d):
+ q, r = divmod(n, d)
+ assert r == 0
+ return q
+
+
+class MultiQueryAttention(nn.Module):
+ def __init__(self, cfg, layer_id=None):
+ super().__init__()
+
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
+ self.max_seq_len = cfg.seq_length
+ self.use_flash_attention = cfg.use_flash_attn
+ assert self.use_flash_attention, 'FlashAttention is required!'
+
+ self.n_groups = cfg.num_attention_groups
+ self.tp_size = 1
+ self.n_local_heads = cfg.num_attention_heads
+ self.n_local_groups = self.n_groups
+
+ self.wqkv = nn.Linear(
+ cfg.hidden_size,
+ cfg.hidden_size + self.head_dim * 2 * self.n_groups,
+ bias=False,
+ )
+ self.wo = nn.Linear(
+ cfg.hidden_size,
+ cfg.hidden_size,
+ bias=False,
+ )
+
+ assert self.use_flash_attention, 'non-Flash attention not supported yet.'
+ self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
+
+ self.layer_id = layer_id
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ cu_seqlens: Optional[torch.Tensor],
+ max_seq_len: Optional[torch.Tensor],
+ ):
+ seqlen, bsz, dim = x.shape
+ xqkv = self.wqkv(x)
+
+ xq, xkv = torch.split(
+ xqkv,
+ (dim // self.tp_size,
+ self.head_dim*2*self.n_groups // self.tp_size
+ ),
+ dim=-1,
+ )
+
+ # gather on 1st dimension
+ xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
+ xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
+ xk, xv = xkv.chunk(2, -1)
+
+ # rotary embedding + flash attn
+ xq = rearrange(xq, "s b h d -> b s h d")
+ xk = rearrange(xk, "s b h d -> b s h d")
+ xv = rearrange(xv, "s b h d -> b s h d")
+
+ q_per_kv = self.n_local_heads // self.n_local_groups
+ if q_per_kv > 1:
+ b, s, h, d = xk.size()
+ if h == 1:
+ xk = xk.expand(b, s, q_per_kv, d)
+ xv = xv.expand(b, s, q_per_kv, d)
+ else:
+ ''' To cover the cases where h > 1, we have
+ the following implementation, which is equivalent to:
+ xk = xk.repeat_interleave(q_per_kv, dim=-2)
+ xv = xv.repeat_interleave(q_per_kv, dim=-2)
+ but can avoid calling aten::item() that involves cpu.
+ '''
+ idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
+ xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
+ xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
+
+ if self.use_flash_attention:
+ output = self.core_attention(xq, xk, xv,
+ cu_seqlens=cu_seqlens,
+ max_seq_len=max_seq_len)
+ # reduce-scatter only support first dimension now
+ output = rearrange(output, "b s h d -> s b (h d)").contiguous()
+ else:
+ xq, xk, xv = [
+ rearrange(x, "b s ... -> s b ...").contiguous()
+ for x in (xq, xk, xv)
+ ]
+ output = self.core_attention(xq, xk, xv, mask)
+ output = self.wo(output)
+ return output
+
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ cfg,
+ dim: int,
+ hidden_dim: int,
+ layer_id: int,
+ multiple_of: int=256,
+ ):
+ super().__init__()
+
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+ def swiglu(x):
+ x = torch.chunk(x, 2, dim=-1)
+ return F.silu(x[0]) * x[1]
+ self.swiglu = swiglu
+
+ self.w1 = nn.Linear(
+ dim,
+ 2 * hidden_dim,
+ bias=False,
+ )
+ self.w2 = nn.Linear(
+ hidden_dim,
+ dim,
+ bias=False,
+ )
+
+ def forward(self, x):
+ x = self.swiglu(self.w1(x))
+ output = self.w2(x)
+ return output
+
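+# The MLP above is a standard SwiGLU block: `w1` produces a 2 * hidden_dim projection,
+# `swiglu` splits it in half and computes silu(a) * b, and `w2` projects back to `dim`
+# (hidden_dim is first rounded up to a multiple of `multiple_of`). A minimal sketch:
+#
+#   ffn = FeedForward(cfg=None, dim=512, hidden_dim=1000, layer_id=0, multiple_of=256)
+#   ffn.w1.out_features               # 2048 (1000 rounds up to 1024, then doubled)
+#   y = ffn(torch.randn(2, 4, 512))   # -> (2, 4, 512)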
+
+
+class TransformerBlock(nn.Module):
+ def __init__(
+ self, cfg, layer_id: int
+ ):
+ super().__init__()
+
+ self.n_heads = cfg.num_attention_heads
+ self.dim = cfg.hidden_size
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
+ self.attention = MultiQueryAttention(
+ cfg,
+ layer_id=layer_id,
+ )
+
+ self.feed_forward = FeedForward(
+ cfg,
+ dim=cfg.hidden_size,
+ hidden_dim=cfg.ffn_hidden_size,
+ layer_id=layer_id,
+ )
+ self.layer_id = layer_id
+ self.attention_norm = RMSNorm(
+ cfg.hidden_size,
+ eps=cfg.layernorm_epsilon,
+ )
+ self.ffn_norm = RMSNorm(
+ cfg.hidden_size,
+ eps=cfg.layernorm_epsilon,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ cu_seqlens: Optional[torch.Tensor],
+ max_seq_len: Optional[torch.Tensor],
+ ):
+ residual = self.attention.forward(
+ self.attention_norm(x), mask,
+ cu_seqlens, max_seq_len
+ )
+ h = x + residual
+ ffn_res = self.feed_forward.forward(self.ffn_norm(h))
+ out = h + ffn_res
+ return out
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ config,
+ max_seq_size=8192,
+ ):
+ super().__init__()
+ self.num_layers = config.num_layers
+ self.layers = self._build_layers(config)
+
+ def _build_layers(self, config):
+ layers = torch.nn.ModuleList()
+ for layer_id in range(self.num_layers):
+ layers.append(
+ TransformerBlock(
+ config,
+ layer_id=layer_id + 1 ,
+ )
+ )
+ return layers
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ cu_seqlens=None,
+ max_seq_len=None,
+ ):
+
+ if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
+ max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
+
+ for lid, layer in enumerate(self.layers):
+ hidden_states = layer(
+ hidden_states,
+ attention_mask,
+ cu_seqlens,
+ max_seq_len,
+ )
+ return hidden_states
+
+
+class Step1Model(PreTrainedModel):
+ config_class=PretrainedConfig
+ @with_empty_init
+ def __init__(
+ self,
+ config,
+ ):
+ super().__init__(config)
+ self.tok_embeddings = LLaMaEmbedding(config)
+ self.transformer = Transformer(config)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ ):
+
+ hidden_states = self.tok_embeddings(input_ids)
+
+ hidden_states = self.transformer(
+ hidden_states,
+ attention_mask,
+ )
+ return hidden_states
+
+
+
+class STEP1TextEncoder(torch.nn.Module):
+ def __init__(self, model_dir, max_length=320):
+ super(STEP1TextEncoder, self).__init__()
+ self.max_length = max_length
+ self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
+ text_encoder = Step1Model.from_pretrained(model_dir)
+ self.text_encoder = text_encoder.eval().to(torch.bfloat16)
+
+ @staticmethod
+ def from_pretrained(path, torch_dtype=torch.bfloat16):
+ model = STEP1TextEncoder(path).to(torch_dtype)
+ return model
+
+    @torch.no_grad()
+ def forward(self, prompts, with_mask=True, max_length=None, device="cuda"):
+ self.device = device
+ with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
+ if type(prompts) is str:
+ prompts = [prompts]
+
+ txt_tokens = self.text_tokenizer(
+ prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ y = self.text_encoder(
+ txt_tokens.input_ids.to(self.device),
+ attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
+ )
+ y_mask = txt_tokens.attention_mask
+ return y.transpose(0,1), y_mask
+
diff --git a/PusaV1/diffsynth/models/stepvideo_vae.py b/PusaV1/diffsynth/models/stepvideo_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..db244c00de53c29f87de67956e472fabad17256b
--- /dev/null
+++ b/PusaV1/diffsynth/models/stepvideo_vae.py
@@ -0,0 +1,1132 @@
+# Copyright 2025 StepFun Inc. All Rights Reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+# ==============================================================================
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+from tqdm import tqdm
+from einops import repeat
+
+
+class BaseGroupNorm(nn.GroupNorm):
+ def __init__(self, num_groups, num_channels):
+ super().__init__(num_groups=num_groups, num_channels=num_channels)
+
+ def forward(self, x, zero_pad=False, **kwargs):
+ if zero_pad:
+ return base_group_norm_with_zero_pad(x, self, **kwargs)
+ else:
+ return base_group_norm(x, self, **kwargs)
+
+
+def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
+ if hasattr(base_group_norm, 'spatial') and base_group_norm.spatial:
+ assert channel_last == True
+ x_shape = x.shape
+ x = x.flatten(0, 1)
+ if channel_last:
+ # Permute to NCHW format
+ x = x.permute(0, 3, 1, 2)
+
+ out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias, norm_layer.eps)
+ if act_silu:
+ out = F.silu(out)
+
+ if channel_last:
+ # Permute back to NHWC format
+ out = out.permute(0, 2, 3, 1)
+
+ out = out.view(x_shape)
+ else:
+ if channel_last:
+ # Permute to NCHW format
+ x = x.permute(0, 3, 1, 2)
+ out = F.group_norm(x.contiguous(), norm_layer.num_groups, norm_layer.weight, norm_layer.bias, norm_layer.eps)
+ if act_silu:
+ out = F.silu(out)
+ if channel_last:
+ # Permute back to NHWC format
+ out = out.permute(0, 2, 3, 1)
+ return out
+
+def base_conv2d(x, conv_layer, channel_last=False, residual=None):
+ if channel_last:
+ x = x.permute(0, 3, 1, 2) # NHWC to NCHW
+ out = F.conv2d(x, conv_layer.weight, conv_layer.bias, stride=conv_layer.stride, padding=conv_layer.padding)
+ if residual is not None:
+ if channel_last:
+ residual = residual.permute(0, 3, 1, 2) # NHWC to NCHW
+ out += residual
+ if channel_last:
+ out = out.permute(0, 2, 3, 1) # NCHW to NHWC
+ return out
+
+def base_conv3d(x, conv_layer, channel_last=False, residual=None, only_return_output=False):
+ if only_return_output:
+ size = cal_outsize(x.shape, conv_layer.weight.shape, conv_layer.stride, conv_layer.padding)
+ return torch.empty(size, device=x.device, dtype=x.dtype)
+ if channel_last:
+ x = x.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW
+ out = F.conv3d(x, conv_layer.weight, conv_layer.bias, stride=conv_layer.stride, padding=conv_layer.padding)
+ if residual is not None:
+ if channel_last:
+ residual = residual.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW
+ out += residual
+ if channel_last:
+ out = out.permute(0, 2, 3, 4, 1) # NCDHW to NDHWC
+ return out
+
+
+def cal_outsize(input_sizes, kernel_sizes, stride, padding):
+ stride_d, stride_h, stride_w = stride
+ padding_d, padding_h, padding_w = padding
+ dilation_d, dilation_h, dilation_w = 1, 1, 1
+
+ in_d = input_sizes[1]
+ in_h = input_sizes[2]
+ in_w = input_sizes[3]
+ in_channel = input_sizes[4]
+
+
+ kernel_d = kernel_sizes[2]
+ kernel_h = kernel_sizes[3]
+ kernel_w = kernel_sizes[4]
+ out_channels = kernel_sizes[0]
+
+ out_d = calc_out_(in_d, padding_d, dilation_d, kernel_d, stride_d)
+ out_h = calc_out_(in_h, padding_h, dilation_h, kernel_h, stride_h)
+ out_w = calc_out_(in_w, padding_w, dilation_w, kernel_w, stride_w)
+ size = [input_sizes[0], out_d, out_h, out_w, out_channels]
+ return size
+
+
+
+
+def calc_out_(in_size, padding, dilation, kernel, stride):
+ return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
+
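+# Standard convolution output-size arithmetic; e.g. a spatial extent of 64 with kernel 3,
+# stride 2, padding 1, dilation 1:
+#
+#   calc_out_(64, padding=1, dilation=1, kernel=3, stride=2)
+#   # (64 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 63 // 2 + 1 = 32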
+
+
+def base_conv3d_channel_last(x, conv_layer, residual=None):
+ in_numel = x.numel()
+ out_numel = int(x.numel() * conv_layer.out_channels / conv_layer.in_channels)
+ if (in_numel >= 2**30) or (out_numel >= 2**30):
+ assert conv_layer.stride[0] == 1, "time split asks time stride = 1"
+
+ B,T,H,W,C = x.shape
+ K = conv_layer.kernel_size[0]
+
+ chunks = 4
+ chunk_size = T // chunks
+
+ if residual is None:
+ out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual, only_return_output=True)
+ else:
+ out_nhwc = residual
+
+ assert B == 1
+ outs = []
+ for i in range(chunks):
+ if i == chunks-1:
+ xi = x[:1,chunk_size*i:]
+ out_nhwci = out_nhwc[:1,chunk_size*i:]
+ else:
+ xi = x[:1,chunk_size*i:chunk_size*(i+1)+K-1]
+ out_nhwci = out_nhwc[:1,chunk_size*i:chunk_size*(i+1)]
+ if residual is not None:
+ if i == chunks-1:
+ ri = residual[:1,chunk_size*i:]
+ else:
+ ri = residual[:1,chunk_size*i:chunk_size*(i+1)]
+ else:
+ ri = None
+ out_nhwci.copy_(base_conv3d(xi, conv_layer, channel_last=True, residual=ri))
+ else:
+ out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual)
+ return out_nhwc
+
+
+
+class Upsample2D(nn.Module):
+ def __init__(self,
+ channels,
+ use_conv=False,
+ use_conv_transpose=False,
+ out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+
+ if use_conv:
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+ else:
+ assert "Not Supported"
+ self.conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+
+ def forward(self, x, output_size=None):
+ assert x.shape[-1] == self.channels
+
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ if output_size is None:
+ x = F.interpolate(
+ x.permute(0,3,1,2).to(memory_format=torch.channels_last),
+ scale_factor=2.0, mode='nearest').permute(0,2,3,1).contiguous()
+ else:
+ x = F.interpolate(
+ x.permute(0,3,1,2).to(memory_format=torch.channels_last),
+ size=output_size, mode='nearest').permute(0,2,3,1).contiguous()
+
+ # x = self.conv(x)
+ x = base_conv2d(x, self.conv, channel_last=True)
+ return x
+
+
+class Downsample2D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+
+ if use_conv:
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ self.conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[-1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 0, 0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+
+ assert x.shape[-1] == self.channels
+ # x = self.conv(x)
+ x = base_conv2d(x, self.conv, channel_last=True)
+ return x
+
+
+
+class CausalConv(nn.Module):
+ def __init__(self,
+ chan_in,
+ chan_out,
+ kernel_size,
+ **kwargs
+ ):
+ super().__init__()
+
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size,) * 3
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ self.dilation = kwargs.pop('dilation', 1)
+ self.stride = kwargs.pop('stride', 1)
+ if isinstance(self.stride, int):
+ self.stride = (self.stride, 1, 1)
+ time_pad = self.dilation * (time_kernel_size - 1) + max((1 - self.stride[0]), 0)
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+ self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)
+
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
+ self.is_first_run = True
+
+ def forward(self, x, is_init=True, residual=None):
+ x = nn.functional.pad(x,
+ self.time_causal_padding if is_init else self.time_uncausal_padding)
+
+ x = self.conv(x)
+ if residual is not None:
+ x.add_(residual)
+ return x
+
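+# CausalConv pads only on the "past" side of the time axis (time_pad = dilation * (k_t - 1)
+# at stride 1), so frame t never sees future frames and the temporal length is preserved.
+# A sketch with a 3x3x3 kernel:
+#
+#   conv = CausalConv(chan_in=16, chan_out=32, kernel_size=3)
+#   x = torch.randn(1, 16, 9, 32, 32)   # (B, C, T, H, W)
+#   y = conv(x)                         # -> (1, 32, 9, 32, 32)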
+
+class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.factor = factor
+ assert out_channels * factor**3 % in_channels == 0
+ self.repeats = out_channels * factor**3 // in_channels
+
+ def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
+ x = x.repeat_interleave(self.repeats, dim=1)
+ x = x.view(x.size(0), self.out_channels, self.factor, self.factor, self.factor, x.size(2), x.size(3), x.size(4))
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+ x = x.view(x.size(0), self.out_channels, x.size(2)*self.factor, x.size(4)*self.factor, x.size(6)*self.factor)
+ x = x[:, :, self.factor - 1:, :, :]
+ return x
+
+class ConvPixelShuffleUpSampleLayer3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.factor = factor
+ out_ratio = factor**3
+ self.conv = CausalConv(
+ in_channels,
+ out_channels * out_ratio,
+ kernel_size=kernel_size
+ )
+
+ def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
+ x = self.conv(x, is_init)
+ x = self.pixel_shuffle_3d(x, self.factor)
+ return x
+
+ @staticmethod
+ def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
+ batch_size, channels, depth, height, width = x.size()
+ new_channels = channels // (factor ** 3)
+ new_depth = depth * factor
+ new_height = height * factor
+ new_width = width * factor
+
+ x = x.view(batch_size, new_channels, factor, factor, factor, depth, height, width)
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+ x = x.view(batch_size, new_channels, new_depth, new_height, new_width)
+ x = x[:, :, factor - 1:, :, :]
+ return x
+
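+# Shape sketch for pixel_shuffle_3d with factor=2: (B, 8*C, D, H, W) becomes
+# (B, C, 2*D, 2*H, 2*W), then the first factor-1 temporal slices are dropped to keep
+# the upsampling causal, e.g.
+#
+#   x = torch.randn(1, 64, 3, 8, 8)
+#   ConvPixelShuffleUpSampleLayer3D.pixel_shuffle_3d(x, factor=2).shape   # (1, 8, 5, 16, 16)
+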
+class ConvPixelUnshuffleDownSampleLayer3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.factor = factor
+ out_ratio = factor**3
+ assert out_channels % out_ratio == 0
+ self.conv = CausalConv(
+ in_channels,
+ out_channels // out_ratio,
+ kernel_size=kernel_size
+ )
+
+ def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
+ x = self.conv(x, is_init)
+ x = self.pixel_unshuffle_3d(x, self.factor)
+ return x
+
+ @staticmethod
+ def pixel_unshuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
+ pad = (0, 0, 0, 0, factor-1, 0) # (left, right, top, bottom, front, back)
+ x = F.pad(x, pad)
+ B, C, D, H, W = x.shape
+ x = x.view(B, C, D // factor, factor, H // factor, factor, W // factor, factor)
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+ x = x.view(B, C * factor**3, D // factor, H // factor, W // factor)
+ return x
+
+class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ factor: int,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.factor = factor
+ assert in_channels * factor**3 % out_channels == 0
+ self.group_size = in_channels * factor**3 // out_channels
+
+ def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
+ pad = (0, 0, 0, 0, self.factor-1, 0) # (left, right, top, bottom, front, back)
+ x = F.pad(x, pad)
+ B, C, D, H, W = x.shape
+ x = x.view(B, C, D // self.factor, self.factor, H // self.factor, self.factor, W // self.factor, self.factor)
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+ x = x.view(B, C * self.factor**3, D // self.factor, H // self.factor, W // self.factor)
+ x = x.view(B, self.out_channels, self.group_size, D // self.factor, H // self.factor, W // self.factor)
+ x = x.mean(dim=2)
+ return x
+
+
+
+
+
+def base_group_norm_with_zero_pad(x, norm_layer, act_silu=True, pad_size=2):
+ out_shape = list(x.shape)
+ out_shape[1] += pad_size
+ out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
+ out[:, pad_size:] = base_group_norm(x, norm_layer, act_silu=act_silu, channel_last=True)
+ out[:, :pad_size] = 0
+ return out
+
+
+class CausalConvChannelLast(CausalConv):
+ def __init__(self,
+ chan_in,
+ chan_out,
+ kernel_size,
+ **kwargs
+ ):
+ super().__init__(
+ chan_in, chan_out, kernel_size, **kwargs)
+
+ self.time_causal_padding = (0, 0) + self.time_causal_padding
+ self.time_uncausal_padding = (0, 0) + self.time_uncausal_padding
+
+ def forward(self, x, is_init=True, residual=None):
+ if self.is_first_run:
+ self.is_first_run = False
+ # self.conv.weight = nn.Parameter(self.conv.weight.permute(0,2,3,4,1).contiguous())
+
+ x = nn.functional.pad(x,
+ self.time_causal_padding if is_init else self.time_uncausal_padding)
+
+ x = base_conv3d_channel_last(x, self.conv, residual=residual)
+ return x
+
+class CausalConvAfterNorm(CausalConv):
+ def __init__(self,
+ chan_in,
+ chan_out,
+ kernel_size,
+ **kwargs
+ ):
+ super().__init__(
+ chan_in, chan_out, kernel_size, **kwargs)
+
+ if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, padding=(0, 1, 1), **kwargs)
+ else:
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=self.stride, dilation=self.dilation, **kwargs)
+ self.is_first_run = True
+
+ def forward(self, x, is_init=True, residual=None):
+ if self.is_first_run:
+ self.is_first_run = False
+
+ if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
+ pass
+ else:
+ x = nn.functional.pad(x, self.time_causal_padding).contiguous()
+
+ x = base_conv3d_channel_last(x, self.conv, residual=residual)
+ return x
+
+class AttnBlock(nn.Module):
+ def __init__(self,
+ in_channels
+ ):
+ super().__init__()
+
+ self.norm = BaseGroupNorm(num_groups=32, num_channels=in_channels)
+ self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
+ self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
+ self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
+ self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
+
+ def attention(self, x, is_init=True):
+ x = self.norm(x, act_silu=False, channel_last=True)
+ q = self.q(x, is_init)
+ k = self.k(x, is_init)
+ v = self.v(x, is_init)
+
+ b, t, h, w, c = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b t h w c -> b 1 (t h w) c"), (q, k, v))
+ x = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
+ x = rearrange(x, "b 1 (t h w) c -> b t h w c", t=t, h=h, w=w)
+
+ return x
+
+ def forward(self, x):
+ x = x.permute(0,2,3,4,1).contiguous()
+ h = self.attention(x)
+ x = self.proj_out(h, residual=x)
+ x = x.permute(0,4,1,2,3)
+ return x
+
+class Resnet3DBlock(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels=None,
+ temb_channels=512,
+ conv_shortcut=False,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = BaseGroupNorm(num_groups=32, num_channels=in_channels)
+ self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
+ if temb_channels > 0:
+ self.temb_proj = nn.Linear(temb_channels, out_channels)
+
+ self.norm2 = BaseGroupNorm(num_groups=32, num_channels=out_channels)
+ self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3)
+
+ assert conv_shortcut is False
+ self.use_conv_shortcut = conv_shortcut
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
+ else:
+ self.nin_shortcut = CausalConvAfterNorm(in_channels, out_channels, kernel_size=1)
+
+ def forward(self, x, temb=None, is_init=True):
+ x = x.permute(0,2,3,4,1).contiguous()
+
+ h = self.norm1(x, zero_pad=True, act_silu=True, pad_size=2)
+ h = self.conv1(h)
+ if temb is not None:
+ h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None]
+
+ x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
+
+ h = self.norm2(h, zero_pad=True, act_silu=True, pad_size=2)
+ x = self.conv2(h, residual=x)
+
+ x = x.permute(0,4,1,2,3)
+ return x
+
+
+class Downsample3D(nn.Module):
+ def __init__(self,
+ in_channels,
+ with_conv,
+ stride
+ ):
+ super().__init__()
+
+ self.with_conv = with_conv
+ if with_conv:
+ self.conv = CausalConv(in_channels, in_channels, kernel_size=3, stride=stride)
+
+ def forward(self, x, is_init=True):
+ if self.with_conv:
+ x = self.conv(x, is_init)
+ else:
+ x = nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
+ return x
+
+class VideoEncoder(nn.Module):
+ def __init__(self,
+ ch=32,
+ ch_mult=(4, 8, 16, 16),
+ num_res_blocks=2,
+ in_channels=3,
+ z_channels=16,
+ double_z=True,
+ down_sampling_layer=[1, 2],
+ resamp_with_conv=True,
+ version=1,
+ ):
+ super().__init__()
+
+ temb_ch = 0
+
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+
+ # downsampling
+ self.conv_in = CausalConv(in_channels, ch, kernel_size=3)
+ self.down_sampling_layer = down_sampling_layer
+
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
+ block_in = block_out
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ if i_level in self.down_sampling_layer:
+ down.downsample = Downsample3D(block_in, resamp_with_conv, stride=(2, 2, 2))
+ else:
+ down.downsample = Downsample2D(block_in, resamp_with_conv, padding=0) #DIFF
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
+ self.version = version
+ if version == 2:
+ channels = 4 * z_channels * 2 ** 3
+ self.conv_patchify = ConvPixelUnshuffleDownSampleLayer3D(block_in, channels, kernel_size=3, factor=2)
+ self.shortcut_pathify = PixelUnshuffleChannelAveragingDownSampleLayer3D(block_in, channels, 2)
+ self.shortcut_out = PixelUnshuffleChannelAveragingDownSampleLayer3D(channels, 2 * z_channels if double_z else z_channels, 1)
+ self.conv_out = CausalConvChannelLast(channels, 2 * z_channels if double_z else z_channels, kernel_size=3)
+ else:
+ self.conv_out = CausalConvAfterNorm(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3)
+
+ @torch.inference_mode()
+ def forward(self, x, video_frame_num, is_init=True):
+ # timestep embedding
+ temb = None
+
+ t = video_frame_num
+
+ # downsampling
+ h = self.conv_in(x, is_init)
+
+ # make it real channel last, but behave like normal layout
+ h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
+
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](h, temb, is_init)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+
+ if i_level != self.num_resolutions - 1:
+ if isinstance(self.down[i_level].downsample, Downsample2D):
+ _, _, t, _, _ = h.shape
+ h = rearrange(h, "b c t h w -> (b t) h w c", t=t)
+ h = self.down[i_level].downsample(h)
+ h = rearrange(h, "(b t) h w c -> b c t h w", t=t)
+ else:
+ h = self.down[i_level].downsample(h, is_init)
+
+ h = self.mid.block_1(h, temb, is_init)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb, is_init)
+
+ h = h.permute(0,2,3,4,1).contiguous() # b c l h w -> b l h w c
+ if self.version == 2:
+ h = base_group_norm(h, self.norm_out, act_silu=True, channel_last=True)
+ h = h.permute(0,4,1,2,3).contiguous()
+ shortcut = self.shortcut_pathify(h, is_init)
+ h = self.conv_patchify(h, is_init)
+ h = h.add_(shortcut)
+ shortcut = self.shortcut_out(h, is_init).permute(0,2,3,4,1)
+ h = self.conv_out(h.permute(0,2,3,4,1).contiguous(), is_init)
+ h = h.add_(shortcut)
+ else:
+ h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
+ h = self.conv_out(h, is_init)
+ h = h.permute(0,4,1,2,3) # b l h w c -> b c l h w
+
+ h = rearrange(h, "b c t h w -> b t c h w")
+ return h
+
+
+class Res3DBlockUpsample(nn.Module):
+ def __init__(self,
+ input_filters,
+ num_filters,
+ down_sampling_stride,
+ down_sampling=False
+ ):
+ super().__init__()
+
+ self.input_filters = input_filters
+ self.num_filters = num_filters
+
+ self.act_ = nn.SiLU(inplace=True)
+
+ self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
+ self.norm1 = BaseGroupNorm(32, num_filters)
+
+ self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
+ self.norm2 = BaseGroupNorm(32, num_filters)
+
+ self.down_sampling = down_sampling
+ if down_sampling:
+ self.down_sampling_stride = down_sampling_stride
+ else:
+ self.down_sampling_stride = [1, 1, 1]
+
+ if num_filters != input_filters or down_sampling:
+ self.conv3 = CausalConvChannelLast(input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride)
+ self.norm3 = BaseGroupNorm(32, num_filters)
+
+ def forward(self, x, is_init=False):
+ x = x.permute(0,2,3,4,1).contiguous()
+
+ residual = x
+
+ h = self.conv1(x, is_init)
+ h = self.norm1(h, act_silu=True, channel_last=True)
+
+ h = self.conv2(h, is_init)
+ h = self.norm2(h, act_silu=False, channel_last=True)
+
+ if self.down_sampling or self.num_filters != self.input_filters:
+ x = self.conv3(x, is_init)
+ x = self.norm3(x, act_silu=False, channel_last=True)
+
+ h.add_(x)
+ h = self.act_(h)
+ if residual is not None:
+ h.add_(residual)
+
+ h = h.permute(0,4,1,2,3)
+ return h
+
+class Upsample3D(nn.Module):
+ def __init__(self,
+ in_channels,
+ scale_factor=2
+ ):
+ super().__init__()
+
+ self.scale_factor = scale_factor
+ self.conv3d = Res3DBlockUpsample(input_filters=in_channels,
+ num_filters=in_channels,
+ down_sampling_stride=(1, 1, 1),
+ down_sampling=False)
+
+ def forward(self, x, is_init=True, is_split=True):
+ b, c, t, h, w = x.shape
+
+ # x = x.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3).to(memory_format=torch.channels_last_3d)
+ if is_split:
+ split_size = c // 8
+ x_slices = torch.split(x, split_size, dim=1)
+ x = [nn.functional.interpolate(x, scale_factor=self.scale_factor) for x in x_slices]
+ x = torch.cat(x, dim=1)
+ else:
+ x = nn.functional.interpolate(x, scale_factor=self.scale_factor)
+
+ x = self.conv3d(x, is_init)
+ return x
+
+class VideoDecoder(nn.Module):
+ def __init__(self,
+ ch=128,
+ z_channels=16,
+ out_channels=3,
+ ch_mult=(1, 2, 4, 4),
+ num_res_blocks=2,
+ temporal_up_layers=[2, 3],
+ temporal_downsample=4,
+ resamp_with_conv=True,
+ version=1,
+ ):
+ super().__init__()
+
+ temb_ch = 0
+
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.temporal_downsample = temporal_downsample
+
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ self.version = version
+ if version == 2:
+ channels = 4 * z_channels * 2 ** 3
+ self.conv_in = CausalConv(z_channels, channels, kernel_size=3)
+ self.shortcut_in = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(z_channels, channels, 1)
+ self.conv_unpatchify = ConvPixelShuffleUpSampleLayer3D(channels, block_in, kernel_size=3, factor=2)
+ self.shortcut_unpathify = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(channels, block_in, 2)
+ else:
+ self.conv_in = CausalConv(z_channels, block_in, kernel_size=3)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = Resnet3DBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch)
+
+ # upsampling
+ self.up_id = len(temporal_up_layers)
+ self.video_frame_num = 1
+ self.cur_video_frame_num = self.video_frame_num // 2 ** self.up_id + 1
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ Resnet3DBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch))
+ block_in = block_out
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ if i_level in temporal_up_layers:
+ up.upsample = Upsample3D(block_in)
+ self.cur_video_frame_num = self.cur_video_frame_num * 2
+ else:
+ up.upsample = Upsample2D(block_in, resamp_with_conv)
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
+ self.conv_out = CausalConvAfterNorm(block_in, out_channels, kernel_size=3)
+
+ @torch.inference_mode()
+ def forward(self, z, is_init=True):
+ z = rearrange(z, "b t c h w -> b c t h w")
+
+ h = self.conv_in(z, is_init=is_init)
+ if self.version == 2:
+ shortcut = self.shortcut_in(z, is_init=is_init)
+ h = h.add_(shortcut)
+ shortcut = self.shortcut_unpathify(h, is_init=is_init)
+ h = self.conv_unpatchify(h, is_init=is_init)
+ h = h.add_(shortcut)
+
+ temb = None
+
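+ # The permute/contiguous round trip below keeps the logical b c t h w shape while forcing an effectively channels-last memory layout.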
+ h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
+ h = self.mid.block_1(h, temb, is_init=is_init)
+ h = self.mid.attn_1(h)
+ h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
+ h = self.mid.block_2(h, temb, is_init=is_init)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = h.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3)
+ h = self.up[i_level].block[i_block](h, temb, is_init=is_init)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ if isinstance(self.up[i_level].upsample, Upsample2D) or (hasattr(self.up[i_level].upsample, "module") and isinstance(self.up[i_level].upsample.module, Upsample2D)):
+ B = h.size(0)
+ h = h.permute(0,2,3,4,1).flatten(0,1)
+ h = self.up[i_level].upsample(h)
+ h = h.unflatten(0, (B, -1)).permute(0,4,1,2,3)
+ else:
+ h = self.up[i_level].upsample(h, is_init=is_init)
+
+ # end
+ h = h.permute(0,2,3,4,1) # b c l h w -> b l h w c
+ self.norm_out.to(dtype=h.dtype, device=h.device) # To be updated
+ h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
+ h = self.conv_out(h)
+ h = h.permute(0,4,1,2,3)
+
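+ # For the initial chunk, drop the leading (temporal_downsample - 1) frames, which appear to come from causal temporal padding.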
+ if is_init:
+ h = h[:, :, (self.temporal_downsample - 1):]
+ return h
+
+
+
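+# RMS normalization over the trailing normalized_shape dimensions, computed in float32 and cast back to the input dtype.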
+def rms_norm(input, normalized_shape, eps=1e-6):
+ dtype = input.dtype
+ input = input.to(torch.float32)
+ variance = input.pow(2).flatten(-len(normalized_shape)).mean(-1)[(...,) + (None,) * len(normalized_shape)]
+ input = input * torch.rsqrt(variance + eps)
+ return input.to(dtype)
+
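+# Diagonal Gaussian posterior: parameters holds (mean, logvar) concatenated along the channel dim (dim=-3); sample() uses the reparameterization trick mean + std * eps.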
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False, rms_norm_mean=False, only_return_mean=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=-3) #N,[X],C,H,W
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ self.deterministic = deterministic
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype)
+ if rms_norm_mean:
+ self.mean = rms_norm(self.mean, self.mean.size()[1:])
+ self.only_return_mean = only_return_mean
+
+ def sample(self, generator=None):
+ # make sure sample is on the same device
+ # as the parameters and has same dtype
+ sample = torch.randn(
+ self.mean.shape, generator=generator, device=self.parameters.device)
+ sample = sample.to(dtype=self.parameters.dtype)
+ x = self.mean + self.std * sample
+ if self.only_return_mean:
+ return self.mean
+ else:
+ return x
+
+
+class StepVideoVAE(nn.Module):
+ def __init__(self,
+ in_channels=3,
+ out_channels=3,
+ z_channels=64,
+ num_res_blocks=2,
+ model_path=None,
+ weight_dict={},
+ world_size=1,
+ version=2,
+ ):
+ super().__init__()
+
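+ # Each 17-frame video chunk maps to latent_len latent frames (3 for version 2, 5 otherwise).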
+ self.frame_len = 17
+ self.latent_len = 3 if version == 2 else 5
+
+ base_group_norm.spatial = True if version == 2 else False
+
+ self.encoder = VideoEncoder(
+ in_channels=in_channels,
+ z_channels=z_channels,
+ num_res_blocks=num_res_blocks,
+ version=version,
+ )
+
+ self.decoder = VideoDecoder(
+ z_channels=z_channels,
+ out_channels=out_channels,
+ num_res_blocks=num_res_blocks,
+ version=version,
+ )
+
+ if model_path is not None:
+ weight_dict = self.init_from_ckpt(model_path)
+ if len(weight_dict) != 0:
+ self.load_from_dict(weight_dict)
+ self.convert_channel_last()
+
+ self.world_size = world_size
+
+ def init_from_ckpt(self, model_path):
+ from safetensors import safe_open
+ p = {}
+ with safe_open(model_path, framework="pt", device="cpu") as f:
+ for k in f.keys():
+ tensor = f.get_tensor(k)
+ if k.startswith("decoder.conv_out."):
+ k = k.replace("decoder.conv_out.", "decoder.conv_out.conv.")
+ p[k] = tensor
+ return p
+
+ def load_from_dict(self, p):
+ self.load_state_dict(p)
+
+ def convert_channel_last(self):
+ #Conv2d NCHW->NHWC
+ pass
+
+ def naive_encode(self, x, is_init_image=True):
+ b, l, c, h, w = x.size()
+ x = rearrange(x, 'b l c h w -> b c l h w').contiguous()
+ z = self.encoder(x, l, True) # downsample [1, 4, 8, 16, 16]
+ return z
+
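+ # Encode the video in chunks of frame_len frames along the time dim, concatenate the chunk latents, and sample from the diagonal Gaussian posterior.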
+ @torch.inference_mode()
+ def encode(self, x):
+ # b (nc cf) c h w -> (b nc) cf c h w -> encode -> (b nc) cf c h w -> b (nc cf) c h w
+ chunks = list(x.split(self.frame_len, dim=1))
+ for i in range(len(chunks)):
+ chunks[i] = self.naive_encode(chunks[i], True)
+ z = torch.cat(chunks, dim=1)
+
+ posterior = DiagonalGaussianDistribution(z)
+ return posterior.sample()
+
+ def decode_naive(self, z, is_init=True):
+ z = z.to(next(self.decoder.parameters()).dtype)
+ dec = self.decoder(z, is_init)
+ return dec
+
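+ # Decode latent chunks of latent_len frames each; with world_size > 1, shard the chunks across ranks, decode locally, then all-gather and trim the padded tail.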
+ @torch.inference_mode()
+ def decode_original(self, z):
+ # b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w
+ chunks = list(z.split(self.latent_len, dim=1))
+
+ if self.world_size > 1:
+ chunks_total_num = len(chunks)
+ max_num_per_rank = (chunks_total_num + self.world_size - 1) // self.world_size
+ rank = torch.distributed.get_rank()
+ chunks_ = chunks[max_num_per_rank * rank : max_num_per_rank * (rank + 1)]
+ if len(chunks_) < max_num_per_rank:
+ chunks_.extend(chunks[:max_num_per_rank-len(chunks_)])
+ chunks = chunks_
+
+ for i in range(len(chunks)):
+ chunks[i] = self.decode_naive(chunks[i], True).permute(0,2,1,3,4)
+ x = torch.cat(chunks, dim=1)
+
+ if self.world_size > 1:
+ x_ = torch.empty([x.size(0), (self.world_size * max_num_per_rank) * self.frame_len, *x.shape[2:]], dtype=x.dtype, device=x.device)
+ torch.distributed.all_gather_into_tensor(x_, x)
+ x = x_[:, : chunks_total_num * self.frame_len]
+
+ x = self.mix(x)
+ return x
+
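+ # Cross-fade the last frame of each decoded chunk with the first frame of the next chunk to soften chunk-boundary seams.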
+ def mix(self, x, smooth_scale=0.6):
+ remain_scale = smooth_scale
+ mix_scale = 1. - remain_scale
+ front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len)
+ back = slice(self.frame_len, x.size(1), self.frame_len)
+ x[:, front], x[:, back] = (
+ x[:, front] * remain_scale + x[:, back] * mix_scale,
+ x[:, back] * remain_scale + x[:, front] * mix_scale
+ )
+ return x
+
+ def single_decode(self, hidden_states, device):
+ chunks = list(hidden_states.split(self.latent_len, dim=1))
+ for i in range(len(chunks)):
+ chunks[i] = self.decode_naive(chunks[i].to(device), True).permute(0,2,1,3,4).cpu()
+ x = torch.cat(chunks, dim=1)
+ return x
+
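+ # 1D feathering ramp: weights rise linearly from 1/border_width to 1 at any tile edge that does not lie on the full-frame boundary.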
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
+ x = torch.ones((length,))
+ if not left_bound:
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
+ if not right_bound:
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
+ return x
+
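+ # Combine the height and width ramps with an elementwise min to get a 2D blending mask broadcastable over (B, C, T, H, W).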
+ def build_mask(self, data, is_bound, border_width):
+ _, _, _, H, W = data.shape
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
+
+ h = repeat(h, "H -> H W", H=H, W=W)
+ w = repeat(w, "W -> H W", H=H, W=W)
+
+ mask = torch.stack([h, w]).min(dim=0).values
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
+ return mask
+
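+ # Tiled decoding: decode overlapping latent tiles (3 latent frames -> 17 output frames, 16x spatial upsampling), accumulate mask-weighted outputs, and normalize by the accumulated mask weights.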
+ def tiled_decode(self, hidden_states, device, tile_size=(34, 34), tile_stride=(16, 16)):
+ B, T, C, H, W = hidden_states.shape
+ size_h, size_w = tile_size
+ stride_h, stride_w = tile_stride
+
+ # Split tasks
+ tasks = []
+ for t in range(0, T, 3):
+ for h in range(0, H, stride_h):
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+ for w in range(0, W, stride_w):
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+ t_, h_, w_ = t + 3, h + size_h, w + size_w
+ tasks.append((t, t_, h, h_, w, w_))
+
+ # Run
+ data_device = "cpu"
+ computation_device = device
+
+ weight = torch.zeros((1, 1, T//3*17, H * 16, W * 16), dtype=hidden_states.dtype, device=data_device)
+ values = torch.zeros((B, 3, T//3*17, H * 16, W * 16), dtype=hidden_states.dtype, device=data_device)
+
+ for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
+ hidden_states_batch = hidden_states[:, t:t_, :, h:h_, w:w_].to(computation_device)
+ hidden_states_batch = self.decode_naive(hidden_states_batch, True).to(data_device)
+
+ mask = self.build_mask(
+ hidden_states_batch,
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
+ border_width=((size_h - stride_h) * 16, (size_w - stride_w) * 16)
+ ).to(dtype=hidden_states.dtype, device=data_device)
+
+ target_t = t // 3 * 17
+ target_h = h * 16
+ target_w = w * 16
+ values[
+ :,
+ :,
+ target_t: target_t + hidden_states_batch.shape[2],
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += hidden_states_batch * mask
+ weight[
+ :,
+ :,
+ target_t: target_t + hidden_states_batch.shape[2],
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += mask
+ return values / weight
+
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(16, 16), smooth_scale=0.6):
+ hidden_states = hidden_states.to("cpu")
+ if tiled:
+ video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
+ else:
+ video = self.single_decode(hidden_states, device)
+ video = self.mix(video, smooth_scale=smooth_scale)
+ return video
+
+ @staticmethod
+ def state_dict_converter():
+ return StepVideoVAEStateDictConverter()
+
+
+class StepVideoVAEStateDictConverter:
+ def __init__(self):
+ super().__init__()
+
+ def from_diffusers(self, state_dict):
+ return self.from_civitai(state_dict)
+
+ def from_civitai(self, state_dict):
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name.startswith("decoder.conv_out."):
+ name_ = name.replace("decoder.conv_out.", "decoder.conv_out.conv.")
+ else:
+ name_ = name
+ state_dict_[name_] = param
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/svd_image_encoder.py b/PusaV1/diffsynth/models/svd_image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ee79c863ec8da0185bf0e57ac0a286495b049ed
--- /dev/null
+++ b/PusaV1/diffsynth/models/svd_image_encoder.py
@@ -0,0 +1,505 @@
+import torch
+from .sd_text_encoder import CLIPEncoderLayer
+
+
+class CLIPVisionEmbeddings(torch.nn.Module):
+ def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_channels=3):
+ super().__init__()
+
+ # class_embeds (This is a fixed tensor)
+ self.class_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
+
+ # patch_embeds
+ self.patch_embedding = torch.nn.Conv2d(in_channels=num_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ # position_embeds (This is a fixed tensor)
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim))
+
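+ # Patchify the image with a strided conv, prepend the class token, and add the learned position embeddings.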
+ def forward(self, pixel_values):
+ batch_size = pixel_values.shape[0]
+ patch_embeds = self.patch_embedding(pixel_values)
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.repeat(batch_size, 1, 1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + self.position_embeds
+ return embeddings
+
+
+class SVDImageEncoder(torch.nn.Module):
+ def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
+ super().__init__()
+ self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
+ self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
+ self.encoders = torch.nn.ModuleList([
+ CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
+ for _ in range(num_encoder_layers)])
+ self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
+ self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)
+
+ def forward(self, pixel_values):
+ embeds = self.embeddings(pixel_values)
+ embeds = self.pre_layernorm(embeds)
+ for encoder_id, encoder in enumerate(self.encoders):
+ embeds = encoder(embeds)
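+ # Pool the class-token embedding (index 0), apply the final layer norm, and project to the conditioning dimension.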
+ embeds = self.post_layernorm(embeds[:, 0, :])
+ embeds = self.visual_projection(embeds)
+ return embeds
+
+ @staticmethod
+ def state_dict_converter():
+ return SVDImageEncoderStateDictConverter()
+
+
+class SVDImageEncoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "vision_model.embeddings.patch_embedding.weight": "embeddings.patch_embedding.weight",
+ "vision_model.embeddings.class_embedding": "embeddings.class_embedding",
+ "vision_model.embeddings.position_embedding.weight": "embeddings.position_embeds",
+ "vision_model.pre_layrnorm.weight": "pre_layernorm.weight",
+ "vision_model.pre_layrnorm.bias": "pre_layernorm.bias",
+ "vision_model.post_layernorm.weight": "post_layernorm.weight",
+ "vision_model.post_layernorm.bias": "post_layernorm.bias",
+ "visual_projection.weight": "visual_projection.weight"
+ }
+ attn_rename_dict = {
+ "self_attn.q_proj": "attn.to_q",
+ "self_attn.k_proj": "attn.to_k",
+ "self_attn.v_proj": "attn.to_v",
+ "self_attn.out_proj": "attn.to_out",
+ "layer_norm1": "layer_norm1",
+ "layer_norm2": "layer_norm2",
+ "mlp.fc1": "fc1",
+ "mlp.fc2": "fc2",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "vision_model.embeddings.class_embedding":
+ param = state_dict[name].view(1, 1, -1)
+ elif name == "vision_model.embeddings.position_embedding.weight":
+ param = state_dict[name].unsqueeze(0)
+ state_dict_[rename_dict[name]] = param
+ elif name.startswith("vision_model.encoder.layers."):
+ param = state_dict[name]
+ names = name.split(".")
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+ state_dict_[name_] = param
+ return state_dict_
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "conditioner.embedders.0.open_clip.model.visual.class_embedding": "embeddings.class_embedding",
+ "conditioner.embedders.0.open_clip.model.visual.conv1.weight": "embeddings.patch_embedding.weight",
+ "conditioner.embedders.0.open_clip.model.visual.ln_post.bias": "post_layernorm.bias",
+ "conditioner.embedders.0.open_clip.model.visual.ln_post.weight": "post_layernorm.weight",
+ "conditioner.embedders.0.open_clip.model.visual.ln_pre.bias": "pre_layernorm.bias",
+ "conditioner.embedders.0.open_clip.model.visual.ln_pre.weight": "pre_layernorm.weight",
+ "conditioner.embedders.0.open_clip.model.visual.positional_embedding": "embeddings.position_embeds",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
+ "conditioner.embedders.0.open_clip.model.visual.proj": "visual_projection.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if name == "conditioner.embedders.0.open_clip.model.visual.class_embedding":
+ param = param.reshape((1, 1, param.shape[0]))
+ elif name == "conditioner.embedders.0.open_clip.model.visual.positional_embedding":
+ param = param.reshape((1, param.shape[0], param.shape[1]))
+ elif name == "conditioner.embedders.0.open_clip.model.visual.proj":
+ param = param.T
+ if isinstance(rename_dict[name], str):
+ state_dict_[rename_dict[name]] = param
+ else:
+ length = param.shape[0] // 3
+ for i, rename in enumerate(rename_dict[name]):
+ state_dict_[rename] = param[i*length: i*length+length]
+ return state_dict_
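+
+
+# Illustrative sketch (not part of the original converter above): the fused
+# open_clip in_proj tensors are split into equal thirds, in to_q / to_k / to_v
+# order, exactly as the loop above does. The helper name and shapes here are
+# hypothetical and only demonstrate the slicing logic.
+def _split_fused_in_proj_example(in_proj_weight):
+ # e.g. in_proj_weight of shape (3*dim, dim) -> three (dim, dim) chunks
+ length = in_proj_weight.shape[0] // 3
+ return {
+ "to_q.weight": in_proj_weight[0 * length: 1 * length],
+ "to_k.weight": in_proj_weight[1 * length: 2 * length],
+ "to_v.weight": in_proj_weight[2 * length: 3 * length],
+ }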
diff --git a/PusaV1/diffsynth/models/svd_unet.py b/PusaV1/diffsynth/models/svd_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c540a926914eea4f827cfb1aa460a098c61025
--- /dev/null
+++ b/PusaV1/diffsynth/models/svd_unet.py
@@ -0,0 +1,2007 @@
+import torch, math
+from einops import rearrange, repeat
+from .sd_unet import Timesteps, PushBlock, PopBlock, Attention, GEGLU, ResnetBlock, AttentionBlock, DownSampler, UpSampler
+
+
+class TemporalResnetBlock(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
+ super().__init__()
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+ self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
+ self.nonlinearity = torch.nn.SiLU()
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ x = rearrange(hidden_states, "f c h w -> 1 c f h w")
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+ x = self.conv1(x)
+ if time_emb is not None:
+ emb = self.nonlinearity(time_emb)
+ emb = self.time_emb_proj(emb)
+ emb = repeat(emb, "b c -> b c f 1 1", f=hidden_states.shape[0])
+ x = x + emb
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+ x = self.conv2(x)
+ if self.conv_shortcut is not None:
+ hidden_states = self.conv_shortcut(hidden_states)
+ x = rearrange(x[0], "c f h w -> f c h w")
+ hidden_states = hidden_states + x
+ return hidden_states, time_emb, text_emb, res_stack
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+ computation_device = None,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent).to(timesteps.device)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
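+
+
+# Usage sketch (illustrative; nothing in the model code calls this). With a batch
+# of five integer timesteps and the 320-dim projection used by SVDUNet below, the
+# result is sin over the first half of the channels and cos over the second half.
+def _timestep_embedding_example():
+ ts = torch.arange(5, dtype=torch.float32) # shape (5,)
+ return get_timestep_embedding(ts, 320) # shape (5, 320)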
+
+
+class TemporalTimesteps(torch.nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.computation_device = computation_device
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ computation_device=self.computation_device,
+ )
+ return t_emb
+
+
+class TrainableTemporalTimesteps(torch.nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, num_frames: int):
+ super().__init__()
+ timesteps = PositionalID()(num_frames)
+ embeddings = get_timestep_embedding(timesteps, num_channels, flip_sin_to_cos, downscale_freq_shift)
+ self.embeddings = torch.nn.Parameter(embeddings)
+
+ def forward(self, timesteps):
+ t_emb = self.embeddings[timesteps]
+ return t_emb
+
+
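+# Descriptive note: with the defaults max_id=25 and repeat_length=20, frame ids
+# 0-24 map to themselves, while larger ids ping-pong over position ids 4..24, so
+# arbitrarily long clips reuse a bounded set of temporal positional embeddings in
+# TrainableTemporalTimesteps.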
+class PositionalID(torch.nn.Module):
+ def __init__(self, max_id=25, repeat_length=20):
+ super().__init__()
+ self.max_id = max_id
+ self.repeat_length = repeat_length
+
+ def frame_id_to_position_id(self, frame_id):
+ if frame_id < self.max_id:
+ position_id = frame_id
+ else:
+ position_id = (frame_id - self.max_id) % (self.repeat_length * 2)
+ if position_id < self.repeat_length:
+ position_id = self.max_id - 2 - position_id
+ else:
+ position_id = self.max_id - 2 * self.repeat_length + position_id
+ return position_id
+
+ def forward(self, num_frames, pivot_frame_id=0):
+ position_ids = [self.frame_id_to_position_id(abs(i-pivot_frame_id)) for i in range(num_frames)]
+ position_ids = torch.IntTensor(position_ids)
+ return position_ids
+
+
+class TemporalAttentionBlock(torch.nn.Module):
+
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None, add_positional_conv=None):
+ super().__init__()
+
+ self.positional_embedding_proj = torch.nn.Sequential(
+ torch.nn.Linear(in_channels, in_channels * 4),
+ torch.nn.SiLU(),
+ torch.nn.Linear(in_channels * 4, in_channels)
+ )
+ if add_positional_conv is not None:
+ self.positional_embedding = TrainableTemporalTimesteps(in_channels, True, 0, add_positional_conv)
+ self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1, padding_mode="reflect")
+ else:
+ self.positional_embedding = TemporalTimesteps(in_channels, True, 0)
+ self.positional_conv = None
+
+ self.norm_in = torch.nn.LayerNorm(in_channels)
+ self.act_fn_in = GEGLU(in_channels, in_channels * 4)
+ self.ff_in = torch.nn.Linear(in_channels * 4, in_channels)
+
+ self.norm1 = torch.nn.LayerNorm(in_channels)
+ self.attn1 = Attention(
+ q_dim=in_channels,
+ num_heads=num_attention_heads,
+ head_dim=attention_head_dim,
+ bias_out=True
+ )
+
+ self.norm2 = torch.nn.LayerNorm(in_channels)
+ self.attn2 = Attention(
+ q_dim=in_channels,
+ kv_dim=cross_attention_dim,
+ num_heads=num_attention_heads,
+ head_dim=attention_head_dim,
+ bias_out=True
+ )
+
+ self.norm_out = torch.nn.LayerNorm(in_channels)
+ self.act_fn_out = GEGLU(in_channels, in_channels * 4)
+ self.ff_out = torch.nn.Linear(in_channels * 4, in_channels)
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+
+ # hidden_states is stacked along the frame axis, so "batch" here is the number of frames
+ batch, inner_dim, height, width = hidden_states.shape
+ pos_emb = torch.arange(batch)
+ pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device)
+ pos_emb = self.positional_embedding_proj(pos_emb)
+
+ hidden_states = rearrange(hidden_states, "T C H W -> 1 C T H W") + rearrange(pos_emb, "T C -> 1 C T 1 1")
+ if self.positional_conv is not None:
+ hidden_states = self.positional_conv(hidden_states)
+ hidden_states = rearrange(hidden_states[0], "C T H W -> (H W) T C")
+
+ residual = hidden_states
+ hidden_states = self.norm_in(hidden_states)
+ hidden_states = self.act_fn_in(hidden_states)
+ hidden_states = self.ff_in(hidden_states)
+ hidden_states = hidden_states + residual
+
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+ hidden_states = attn_output + hidden_states
+
+ norm_hidden_states = self.norm2(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=text_emb.repeat(height * width, 1))
+ hidden_states = attn_output + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.act_fn_out(hidden_states)
+ hidden_states = self.ff_out(hidden_states)
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states.reshape(height, width, batch, inner_dim).permute(2, 3, 0, 1)
+
+ return hidden_states, time_emb, text_emb, res_stack
+
+
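+# Descriptive note: PopMixBlock blends the current (temporal) features with the
+# spatial features previously pushed onto res_stack, using a learned sigmoid-gated
+# mix factor; when a projection is configured, it also adds back the features
+# pushed before the preceding attention pair (their skipped proj_out residual).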
+class PopMixBlock(torch.nn.Module):
+ def __init__(self, in_channels=None):
+ super().__init__()
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5]))
+ self.need_proj = in_channels is not None
+ if self.need_proj:
+ self.proj = torch.nn.Linear(in_channels, in_channels)
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ res_hidden_states = res_stack.pop()
+ alpha = torch.sigmoid(self.mix_factor)
+ hidden_states = alpha * res_hidden_states + (1 - alpha) * hidden_states
+ if self.need_proj:
+ hidden_states = hidden_states.permute(0, 2, 3, 1)
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ res_hidden_states = res_stack.pop()
+ hidden_states = hidden_states + res_hidden_states
+ return hidden_states, time_emb, text_emb, res_stack
+
+
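+# Descriptive note: SVDUNet is expressed as a flat ModuleList in which PushBlock
+# and PopBlock manage the UNet skip connections via res_stack, and every spatial
+# ResnetBlock / AttentionBlock is paired with a temporal counterpart whose output
+# is merged back by a PopMixBlock.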
+class SVDUNet(torch.nn.Module):
+ def __init__(self, add_positional_conv=None):
+ super().__init__()
+ self.time_proj = Timesteps(320)
+ self.time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(320, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.add_time_proj = Timesteps(256)
+ self.add_time_embedding = torch.nn.Sequential(
+ torch.nn.Linear(768, 1280),
+ torch.nn.SiLU(),
+ torch.nn.Linear(1280, 1280)
+ )
+ self.conv_in = torch.nn.Conv2d(8, 320, kernel_size=3, padding=1)
+
+ self.blocks = torch.nn.ModuleList([
+ # CrossAttnDownBlockSpatioTemporal
+ ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
+ ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320), PushBlock(),
+ DownSampler(320), PushBlock(),
+ # CrossAttnDownBlockSpatioTemporal
+ ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
+ ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640), PushBlock(),
+ DownSampler(640), PushBlock(),
+ # CrossAttnDownBlockSpatioTemporal
+ ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
+ ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280), PushBlock(),
+ DownSampler(1280), PushBlock(),
+ # DownBlockSpatioTemporal
+ ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
+ ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
+ # UNetMidBlockSpatioTemporal
+ ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(),
+ AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+ ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
+ # UpBlockSpatioTemporal
+ PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
+ PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
+ PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(),
+ UpSampler(1280),
+ # CrossAttnUpBlockSpatioTemporal
+ PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+ PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+ PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024, add_positional_conv), PopMixBlock(1280),
+ UpSampler(1280),
+ # CrossAttnUpBlockSpatioTemporal
+ PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
+ PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
+ PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024, add_positional_conv), PopMixBlock(640),
+ UpSampler(640),
+ # CrossAttnUpBlockSpatioTemporal
+ PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
+ PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
+ PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(),
+ AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024, add_positional_conv), PopMixBlock(320),
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+
+
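+ # Descriptive note: build_mask produces per-pixel blending weights that ramp
+ # linearly from a tile's border towards its interior (and stay flat on sides
+ # that touch the sample boundary), so overlapping tiles in tiled_forward are
+ # feather-blended rather than leaving visible seams.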
+ def build_mask(self, data, is_bound):
+ T, C, H, W = data.shape
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
+ border_width = (T + H + W) // 6
+ pad = torch.ones_like(t) * border_width
+ mask = torch.stack([
+ pad if is_bound[0] else t + 1,
+ pad if is_bound[1] else T - t,
+ pad if is_bound[2] else h + 1,
+ pad if is_bound[3] else H - h,
+ pad if is_bound[4] else w + 1,
+ pad if is_bound[5] else W - w
+ ]).min(dim=0).values
+ mask = mask.clip(1, border_width)
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
+ mask = rearrange(mask, "T H W -> T 1 H W")
+ return mask
+
+
+ def tiled_forward(
+ self, sample, timestep, encoder_hidden_states, add_time_id,
+ batch_time=25, batch_height=128, batch_width=128,
+ stride_time=5, stride_height=64, stride_width=64,
+ progress_bar=lambda x:x
+ ):
+ data_device = sample.device
+ computation_device = self.conv_in.weight.device
+ torch_dtype = sample.dtype
+ T, C, H, W = sample.shape
+
+ weight = torch.zeros((T, 1, H, W), dtype=torch_dtype, device=data_device)
+ values = torch.zeros((T, 4, H, W), dtype=torch_dtype, device=data_device)
+
+ # Split tasks
+ tasks = []
+ for t in range(0, T, stride_time):
+ for h in range(0, H, stride_height):
+ for w in range(0, W, stride_width):
+ if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\
+ or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\
+ or (w-stride_width >= 0 and w-stride_width+batch_width >= W):
+ continue
+ tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width))
+
+ # Run
+ for tl, tr, hl, hr, wl, wr in progress_bar(tasks):
+ sample_batch = sample[tl:tr, :, hl:hr, wl:wr].to(computation_device)
+ sample_batch = self.forward(sample_batch, timestep, encoder_hidden_states, add_time_id).to(data_device)
+ mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W))
+ values[tl:tr, :, hl:hr, wl:wr] += sample_batch * mask
+ weight[tl:tr, :, hl:hr, wl:wr] += mask
+ values /= weight
+ return values
+
+
+ def forward(self, sample, timestep, encoder_hidden_states, add_time_id, use_gradient_checkpointing=False, **kwargs):
+ # 1. time
+ timestep = torch.tensor((timestep,)).to(sample.device)
+ t_emb = self.time_proj(timestep).to(sample.dtype)
+ t_emb = self.time_embedding(t_emb)
+
+ add_embeds = self.add_time_proj(add_time_id.flatten()).to(sample.dtype)
+ add_embeds = add_embeds.reshape((-1, 768))
+ add_embeds = self.add_time_embedding(add_embeds)
+
+ time_emb = t_emb + add_embeds
+
+ # 2. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = self.conv_in(sample)
+ text_emb = encoder_hidden_states
+ res_stack = [hidden_states]
+
+ # 3. blocks
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ for i, block in enumerate(self.blocks):
+ if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock) or isinstance(block, PopMixBlock)):
+ hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states, time_emb, text_emb, res_stack,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 4. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+ @staticmethod
+ def state_dict_converter():
+ return SVDUNetStateDictConverter()
+
+
+
+class SVDUNetStateDictConverter:
+ def __init__(self):
+ pass
+
+ def get_block_name(self, names):
+ if names[0] in ["down_blocks", "mid_block", "up_blocks"]:
+ if names[4] in ["norm", "proj_in"]:
+ return ".".join(names[:4] + ["transformer_blocks"])
+ elif names[4] in ["time_pos_embed"]:
+ return ".".join(names[:4] + ["temporal_transformer_blocks"])
+ elif names[4] in ["proj_out"]:
+ return ".".join(names[:4] + ["time_mixer"])
+ else:
+ return ".".join(names[:5])
+ return ""
+
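+ # Descriptive note: from_diffusers maps Diffusers' hierarchical module names
+ # onto the flat blocks list above; blocks_rename_dict assigns each named
+ # sub-block its index in self.blocks, and parameter suffixes inside attention
+ # blocks are remapped via the suffix_dict tables below.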
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "time_embedding.linear_1": "time_embedding.0",
+ "time_embedding.linear_2": "time_embedding.2",
+ "add_embedding.linear_1": "add_time_embedding.0",
+ "add_embedding.linear_2": "add_time_embedding.2",
+ "conv_in": "conv_in",
+ "conv_norm_out": "conv_norm_out",
+ "conv_out": "conv_out",
+ }
+ blocks_rename_dict = [
+ "down_blocks.0.resnets.0.spatial_res_block", None, "down_blocks.0.resnets.0.temporal_res_block", "down_blocks.0.resnets.0.time_mixer", None,
+ "down_blocks.0.attentions.0.transformer_blocks", None, "down_blocks.0.attentions.0.temporal_transformer_blocks", "down_blocks.0.attentions.0.time_mixer", None,
+ "down_blocks.0.resnets.1.spatial_res_block", None, "down_blocks.0.resnets.1.temporal_res_block", "down_blocks.0.resnets.1.time_mixer", None,
+ "down_blocks.0.attentions.1.transformer_blocks", None, "down_blocks.0.attentions.1.temporal_transformer_blocks", "down_blocks.0.attentions.1.time_mixer", None,
+ "down_blocks.0.downsamplers.0.conv", None,
+ "down_blocks.1.resnets.0.spatial_res_block", None, "down_blocks.1.resnets.0.temporal_res_block", "down_blocks.1.resnets.0.time_mixer", None,
+ "down_blocks.1.attentions.0.transformer_blocks", None, "down_blocks.1.attentions.0.temporal_transformer_blocks", "down_blocks.1.attentions.0.time_mixer", None,
+ "down_blocks.1.resnets.1.spatial_res_block", None, "down_blocks.1.resnets.1.temporal_res_block", "down_blocks.1.resnets.1.time_mixer", None,
+ "down_blocks.1.attentions.1.transformer_blocks", None, "down_blocks.1.attentions.1.temporal_transformer_blocks", "down_blocks.1.attentions.1.time_mixer", None,
+ "down_blocks.1.downsamplers.0.conv", None,
+ "down_blocks.2.resnets.0.spatial_res_block", None, "down_blocks.2.resnets.0.temporal_res_block", "down_blocks.2.resnets.0.time_mixer", None,
+ "down_blocks.2.attentions.0.transformer_blocks", None, "down_blocks.2.attentions.0.temporal_transformer_blocks", "down_blocks.2.attentions.0.time_mixer", None,
+ "down_blocks.2.resnets.1.spatial_res_block", None, "down_blocks.2.resnets.1.temporal_res_block", "down_blocks.2.resnets.1.time_mixer", None,
+ "down_blocks.2.attentions.1.transformer_blocks", None, "down_blocks.2.attentions.1.temporal_transformer_blocks", "down_blocks.2.attentions.1.time_mixer", None,
+ "down_blocks.2.downsamplers.0.conv", None,
+ "down_blocks.3.resnets.0.spatial_res_block", None, "down_blocks.3.resnets.0.temporal_res_block", "down_blocks.3.resnets.0.time_mixer", None,
+ "down_blocks.3.resnets.1.spatial_res_block", None, "down_blocks.3.resnets.1.temporal_res_block", "down_blocks.3.resnets.1.time_mixer", None,
+ "mid_block.mid_block.resnets.0.spatial_res_block", None, "mid_block.mid_block.resnets.0.temporal_res_block", "mid_block.mid_block.resnets.0.time_mixer", None,
+ "mid_block.mid_block.attentions.0.transformer_blocks", None, "mid_block.mid_block.attentions.0.temporal_transformer_blocks", "mid_block.mid_block.attentions.0.time_mixer",
+ "mid_block.mid_block.resnets.1.spatial_res_block", None, "mid_block.mid_block.resnets.1.temporal_res_block", "mid_block.mid_block.resnets.1.time_mixer",
+ None, "up_blocks.0.resnets.0.spatial_res_block", None, "up_blocks.0.resnets.0.temporal_res_block", "up_blocks.0.resnets.0.time_mixer",
+ None, "up_blocks.0.resnets.1.spatial_res_block", None, "up_blocks.0.resnets.1.temporal_res_block", "up_blocks.0.resnets.1.time_mixer",
+ None, "up_blocks.0.resnets.2.spatial_res_block", None, "up_blocks.0.resnets.2.temporal_res_block", "up_blocks.0.resnets.2.time_mixer",
+ "up_blocks.0.upsamplers.0.conv",
+ None, "up_blocks.1.resnets.0.spatial_res_block", None, "up_blocks.1.resnets.0.temporal_res_block", "up_blocks.1.resnets.0.time_mixer", None,
+ "up_blocks.1.attentions.0.transformer_blocks", None, "up_blocks.1.attentions.0.temporal_transformer_blocks", "up_blocks.1.attentions.0.time_mixer",
+ None, "up_blocks.1.resnets.1.spatial_res_block", None, "up_blocks.1.resnets.1.temporal_res_block", "up_blocks.1.resnets.1.time_mixer", None,
+ "up_blocks.1.attentions.1.transformer_blocks", None, "up_blocks.1.attentions.1.temporal_transformer_blocks", "up_blocks.1.attentions.1.time_mixer",
+ None, "up_blocks.1.resnets.2.spatial_res_block", None, "up_blocks.1.resnets.2.temporal_res_block", "up_blocks.1.resnets.2.time_mixer", None,
+ "up_blocks.1.attentions.2.transformer_blocks", None, "up_blocks.1.attentions.2.temporal_transformer_blocks", "up_blocks.1.attentions.2.time_mixer",
+ "up_blocks.1.upsamplers.0.conv",
+ None, "up_blocks.2.resnets.0.spatial_res_block", None, "up_blocks.2.resnets.0.temporal_res_block", "up_blocks.2.resnets.0.time_mixer", None,
+ "up_blocks.2.attentions.0.transformer_blocks", None, "up_blocks.2.attentions.0.temporal_transformer_blocks", "up_blocks.2.attentions.0.time_mixer",
+ None, "up_blocks.2.resnets.1.spatial_res_block", None, "up_blocks.2.resnets.1.temporal_res_block", "up_blocks.2.resnets.1.time_mixer", None,
+ "up_blocks.2.attentions.1.transformer_blocks", None, "up_blocks.2.attentions.1.temporal_transformer_blocks", "up_blocks.2.attentions.1.time_mixer",
+ None, "up_blocks.2.resnets.2.spatial_res_block", None, "up_blocks.2.resnets.2.temporal_res_block", "up_blocks.2.resnets.2.time_mixer", None,
+ "up_blocks.2.attentions.2.transformer_blocks", None, "up_blocks.2.attentions.2.temporal_transformer_blocks", "up_blocks.2.attentions.2.time_mixer",
+ "up_blocks.2.upsamplers.0.conv",
+ None, "up_blocks.3.resnets.0.spatial_res_block", None, "up_blocks.3.resnets.0.temporal_res_block", "up_blocks.3.resnets.0.time_mixer", None,
+ "up_blocks.3.attentions.0.transformer_blocks", None, "up_blocks.3.attentions.0.temporal_transformer_blocks", "up_blocks.3.attentions.0.time_mixer",
+ None, "up_blocks.3.resnets.1.spatial_res_block", None, "up_blocks.3.resnets.1.temporal_res_block", "up_blocks.3.resnets.1.time_mixer", None,
+ "up_blocks.3.attentions.1.transformer_blocks", None, "up_blocks.3.attentions.1.temporal_transformer_blocks", "up_blocks.3.attentions.1.time_mixer",
+ None, "up_blocks.3.resnets.2.spatial_res_block", None, "up_blocks.3.resnets.2.temporal_res_block", "up_blocks.3.resnets.2.time_mixer", None,
+ "up_blocks.3.attentions.2.transformer_blocks", None, "up_blocks.3.attentions.2.temporal_transformer_blocks", "up_blocks.3.attentions.2.time_mixer",
+ ]
+ blocks_rename_dict = {i:j for j,i in enumerate(blocks_rename_dict) if i is not None}
+ state_dict_ = {}
+ for name, param in sorted(state_dict.items()):
+ names = name.split(".")
+ if names[0] == "mid_block":
+ names = ["mid_block"] + names
+ if names[-1] in ["weight", "bias"]:
+ name_prefix = ".".join(names[:-1])
+ if name_prefix in rename_dict:
+ state_dict_[rename_dict[name_prefix] + "." + names[-1]] = param
+ else:
+ block_name = self.get_block_name(names)
+ if "resnets" in block_name and block_name in blocks_rename_dict:
+ rename = ".".join(["blocks", str(blocks_rename_dict[block_name])] + names[5:])
+ state_dict_[rename] = param
+ elif ("downsamplers" in block_name or "upsamplers" in block_name) and block_name in blocks_rename_dict:
+ rename = ".".join(["blocks", str(blocks_rename_dict[block_name])] + names[-2:])
+ state_dict_[rename] = param
+ elif "attentions" in block_name and block_name in blocks_rename_dict:
+ attention_id = names[5]
+ if "transformer_blocks" in names:
+ suffix_dict = {
+ "attn1.to_out.0": "attn1.to_out",
+ "attn2.to_out.0": "attn2.to_out",
+ "ff.net.0.proj": "act_fn.proj",
+ "ff.net.2": "ff",
+ }
+ suffix = ".".join(names[6:-1])
+ suffix = suffix_dict.get(suffix, suffix)
+ rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), "transformer_blocks", attention_id, suffix, names[-1]])
+ elif "temporal_transformer_blocks" in names:
+ suffix_dict = {
+ "attn1.to_out.0": "attn1.to_out",
+ "attn2.to_out.0": "attn2.to_out",
+ "ff_in.net.0.proj": "act_fn_in.proj",
+ "ff_in.net.2": "ff_in",
+ "ff.net.0.proj": "act_fn_out.proj",
+ "ff.net.2": "ff_out",
+ "norm3": "norm_out",
+ }
+ suffix = ".".join(names[6:-1])
+ suffix = suffix_dict.get(suffix, suffix)
+ rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), suffix, names[-1]])
+ elif "time_mixer" in block_name:
+ rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), "proj", names[-1]])
+ else:
+ suffix_dict = {
+ "linear_1": "positional_embedding_proj.0",
+ "linear_2": "positional_embedding_proj.2",
+ }
+ suffix = names[-2]
+ suffix = suffix_dict.get(suffix, suffix)
+ rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), suffix, names[-1]])
+ state_dict_[rename] = param
+ else:
+ # report any parameter that does not match a known block, so unconverted weights are easy to spot
+ print(name)
+ else:
+ block_name = self.get_block_name(names)
+ if len(block_name)>0 and block_name in blocks_rename_dict:
+ rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), names[-1]])
+ state_dict_[rename] = param
+ return state_dict_
+
+
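+ # Descriptive note: from_civitai is a fully enumerated mapping from original
+ # SVD checkpoint names under model.diffusion_model.* to this module's flat
+ # block indices; each spatial layer, temporal layer, and mix factor has an
+ # explicit entry.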
+ def from_civitai(self, state_dict, add_positional_conv=None):
+ rename_dict = {
+ "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
+ "model.diffusion_model.input_blocks.1.0.time_mixer.mix_factor": "blocks.3.mix_factor",
+ "model.diffusion_model.input_blocks.1.0.time_stack.emb_layers.1.bias": "blocks.2.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.time_stack.emb_layers.1.weight": "blocks.2.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.0.bias": "blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.0.weight": "blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.2.bias": "blocks.2.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.time_stack.in_layers.2.weight": "blocks.2.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.0.bias": "blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.0.weight": "blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.3.bias": "blocks.2.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.time_stack.out_layers.3.weight": "blocks.2.conv2.weight",
+ "model.diffusion_model.input_blocks.1.1.norm.bias": "blocks.5.norm.bias",
+ "model.diffusion_model.input_blocks.1.1.norm.weight": "blocks.5.norm.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_in.bias": "blocks.5.proj_in.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_in.weight": "blocks.5.proj_in.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_out.bias": "blocks.8.proj.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_out.weight": "blocks.8.proj.weight",
+ "model.diffusion_model.input_blocks.1.1.time_mixer.mix_factor": "blocks.8.mix_factor",
+ "model.diffusion_model.input_blocks.1.1.time_pos_embed.0.bias": "blocks.7.positional_embedding_proj.0.bias",
+ "model.diffusion_model.input_blocks.1.1.time_pos_embed.0.weight": "blocks.7.positional_embedding_proj.0.weight",
+ "model.diffusion_model.input_blocks.1.1.time_pos_embed.2.bias": "blocks.7.positional_embedding_proj.2.bias",
+ "model.diffusion_model.input_blocks.1.1.time_pos_embed.2.weight": "blocks.7.positional_embedding_proj.2.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_k.weight": "blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_out.0.bias": "blocks.7.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_out.0.weight": "blocks.7.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_q.weight": "blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn1.to_v.weight": "blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_k.weight": "blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_out.0.bias": "blocks.7.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_out.0.weight": "blocks.7.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_q.weight": "blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.attn2.to_v.weight": "blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.0.proj.bias": "blocks.7.act_fn_out.proj.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.0.proj.weight": "blocks.7.act_fn_out.proj.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.2.bias": "blocks.7.ff_out.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff.net.2.weight": "blocks.7.ff_out.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.7.act_fn_in.proj.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.7.act_fn_in.proj.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.2.bias": "blocks.7.ff_in.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.ff_in.net.2.weight": "blocks.7.ff_in.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm1.bias": "blocks.7.norm1.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm1.weight": "blocks.7.norm1.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm2.bias": "blocks.7.norm2.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm2.weight": "blocks.7.norm2.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm3.bias": "blocks.7.norm_out.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm3.weight": "blocks.7.norm_out.weight",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm_in.bias": "blocks.7.norm_in.bias",
+ "model.diffusion_model.input_blocks.1.1.time_stack.0.norm_in.weight": "blocks.7.norm_in.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.5.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.5.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.5.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.5.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.5.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.5.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.5.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.5.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.5.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.5.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.5.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.5.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.5.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.5.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.5.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.5.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.5.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.5.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.5.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.5.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "blocks.66.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "blocks.66.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "blocks.66.norm1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "blocks.66.norm1.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "blocks.66.conv1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "blocks.66.conv1.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "blocks.66.norm2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "blocks.66.norm2.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "blocks.66.conv2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "blocks.66.conv2.weight",
+ "model.diffusion_model.input_blocks.10.0.time_mixer.mix_factor": "blocks.69.mix_factor",
+ "model.diffusion_model.input_blocks.10.0.time_stack.emb_layers.1.bias": "blocks.68.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.10.0.time_stack.emb_layers.1.weight": "blocks.68.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.0.bias": "blocks.68.norm1.bias",
+ "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.0.weight": "blocks.68.norm1.weight",
+ "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.2.bias": "blocks.68.conv1.bias",
+ "model.diffusion_model.input_blocks.10.0.time_stack.in_layers.2.weight": "blocks.68.conv1.weight",
+ "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.0.bias": "blocks.68.norm2.bias",
+ "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.0.weight": "blocks.68.norm2.weight",
+ "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.3.bias": "blocks.68.conv2.bias",
+ "model.diffusion_model.input_blocks.10.0.time_stack.out_layers.3.weight": "blocks.68.conv2.weight",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "blocks.71.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "blocks.71.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "blocks.71.norm1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "blocks.71.norm1.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "blocks.71.conv1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "blocks.71.conv1.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "blocks.71.norm2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "blocks.71.norm2.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "blocks.71.conv2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "blocks.71.conv2.weight",
+ "model.diffusion_model.input_blocks.11.0.time_mixer.mix_factor": "blocks.74.mix_factor",
+ "model.diffusion_model.input_blocks.11.0.time_stack.emb_layers.1.bias": "blocks.73.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.11.0.time_stack.emb_layers.1.weight": "blocks.73.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.0.bias": "blocks.73.norm1.bias",
+ "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.0.weight": "blocks.73.norm1.weight",
+ "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.2.bias": "blocks.73.conv1.bias",
+ "model.diffusion_model.input_blocks.11.0.time_stack.in_layers.2.weight": "blocks.73.conv1.weight",
+ "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.0.bias": "blocks.73.norm2.bias",
+ "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.0.weight": "blocks.73.norm2.weight",
+ "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.3.bias": "blocks.73.conv2.bias",
+ "model.diffusion_model.input_blocks.11.0.time_stack.out_layers.3.weight": "blocks.73.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "blocks.10.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "blocks.10.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "blocks.10.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "blocks.10.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "blocks.10.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "blocks.10.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "blocks.10.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "blocks.10.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "blocks.10.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "blocks.10.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.time_mixer.mix_factor": "blocks.13.mix_factor",
+ "model.diffusion_model.input_blocks.2.0.time_stack.emb_layers.1.bias": "blocks.12.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.time_stack.emb_layers.1.weight": "blocks.12.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.0.bias": "blocks.12.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.0.weight": "blocks.12.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.2.bias": "blocks.12.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.time_stack.in_layers.2.weight": "blocks.12.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.0.bias": "blocks.12.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.0.weight": "blocks.12.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.3.bias": "blocks.12.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.time_stack.out_layers.3.weight": "blocks.12.conv2.weight",
+ "model.diffusion_model.input_blocks.2.1.norm.bias": "blocks.15.norm.bias",
+ "model.diffusion_model.input_blocks.2.1.norm.weight": "blocks.15.norm.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_in.bias": "blocks.15.proj_in.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_in.weight": "blocks.15.proj_in.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_out.bias": "blocks.18.proj.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_out.weight": "blocks.18.proj.weight",
+ "model.diffusion_model.input_blocks.2.1.time_mixer.mix_factor": "blocks.18.mix_factor",
+ "model.diffusion_model.input_blocks.2.1.time_pos_embed.0.bias": "blocks.17.positional_embedding_proj.0.bias",
+ "model.diffusion_model.input_blocks.2.1.time_pos_embed.0.weight": "blocks.17.positional_embedding_proj.0.weight",
+ "model.diffusion_model.input_blocks.2.1.time_pos_embed.2.bias": "blocks.17.positional_embedding_proj.2.bias",
+ "model.diffusion_model.input_blocks.2.1.time_pos_embed.2.weight": "blocks.17.positional_embedding_proj.2.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_k.weight": "blocks.17.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_out.0.bias": "blocks.17.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_out.0.weight": "blocks.17.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_q.weight": "blocks.17.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn1.to_v.weight": "blocks.17.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_k.weight": "blocks.17.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_out.0.bias": "blocks.17.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_out.0.weight": "blocks.17.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_q.weight": "blocks.17.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.attn2.to_v.weight": "blocks.17.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.0.proj.bias": "blocks.17.act_fn_out.proj.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.0.proj.weight": "blocks.17.act_fn_out.proj.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.2.bias": "blocks.17.ff_out.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff.net.2.weight": "blocks.17.ff_out.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.17.act_fn_in.proj.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.17.act_fn_in.proj.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.2.bias": "blocks.17.ff_in.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.ff_in.net.2.weight": "blocks.17.ff_in.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm1.bias": "blocks.17.norm1.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm1.weight": "blocks.17.norm1.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm2.bias": "blocks.17.norm2.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm2.weight": "blocks.17.norm2.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm3.bias": "blocks.17.norm_out.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm3.weight": "blocks.17.norm_out.weight",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm_in.bias": "blocks.17.norm_in.bias",
+ "model.diffusion_model.input_blocks.2.1.time_stack.0.norm_in.weight": "blocks.17.norm_in.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.15.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.15.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.15.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.15.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.15.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.15.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.15.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.15.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.15.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.15.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.15.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.15.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.15.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.15.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.15.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.15.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.15.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.15.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.15.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.15.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "blocks.20.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "blocks.20.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "blocks.22.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "blocks.22.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "blocks.22.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "blocks.22.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "blocks.22.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "blocks.22.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "blocks.22.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "blocks.22.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "blocks.22.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "blocks.22.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "blocks.22.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "blocks.22.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.0.time_mixer.mix_factor": "blocks.25.mix_factor",
+ "model.diffusion_model.input_blocks.4.0.time_stack.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.time_stack.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.0.bias": "blocks.24.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.0.weight": "blocks.24.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.2.bias": "blocks.24.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.time_stack.in_layers.2.weight": "blocks.24.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.0.bias": "blocks.24.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.0.weight": "blocks.24.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.3.bias": "blocks.24.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.time_stack.out_layers.3.weight": "blocks.24.conv2.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "blocks.27.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "blocks.27.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "blocks.27.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "blocks.27.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "blocks.30.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "blocks.30.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.time_mixer.mix_factor": "blocks.30.mix_factor",
+ "model.diffusion_model.input_blocks.4.1.time_pos_embed.0.bias": "blocks.29.positional_embedding_proj.0.bias",
+ "model.diffusion_model.input_blocks.4.1.time_pos_embed.0.weight": "blocks.29.positional_embedding_proj.0.weight",
+ "model.diffusion_model.input_blocks.4.1.time_pos_embed.2.bias": "blocks.29.positional_embedding_proj.2.bias",
+ "model.diffusion_model.input_blocks.4.1.time_pos_embed.2.weight": "blocks.29.positional_embedding_proj.2.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_k.weight": "blocks.29.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_out.0.bias": "blocks.29.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_out.0.weight": "blocks.29.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_q.weight": "blocks.29.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn1.to_v.weight": "blocks.29.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_k.weight": "blocks.29.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_out.0.bias": "blocks.29.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_out.0.weight": "blocks.29.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_q.weight": "blocks.29.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.attn2.to_v.weight": "blocks.29.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.0.proj.bias": "blocks.29.act_fn_out.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.0.proj.weight": "blocks.29.act_fn_out.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.2.bias": "blocks.29.ff_out.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff.net.2.weight": "blocks.29.ff_out.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.29.act_fn_in.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.29.act_fn_in.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.2.bias": "blocks.29.ff_in.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.ff_in.net.2.weight": "blocks.29.ff_in.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm1.bias": "blocks.29.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm1.weight": "blocks.29.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm2.bias": "blocks.29.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm2.weight": "blocks.29.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm3.bias": "blocks.29.norm_out.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm3.weight": "blocks.29.norm_out.weight",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm_in.bias": "blocks.29.norm_in.bias",
+ "model.diffusion_model.input_blocks.4.1.time_stack.0.norm_in.weight": "blocks.29.norm_in.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.27.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.27.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.27.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.27.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.27.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.27.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.27.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.27.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.27.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.27.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.27.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.27.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.27.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.27.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.27.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.27.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.27.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.27.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.27.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.27.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "blocks.32.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "blocks.32.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "blocks.32.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "blocks.32.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "blocks.32.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "blocks.32.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "blocks.32.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "blocks.32.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "blocks.32.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "blocks.32.conv2.weight",
+ "model.diffusion_model.input_blocks.5.0.time_mixer.mix_factor": "blocks.35.mix_factor",
+ "model.diffusion_model.input_blocks.5.0.time_stack.emb_layers.1.bias": "blocks.34.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.time_stack.emb_layers.1.weight": "blocks.34.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.0.bias": "blocks.34.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.0.weight": "blocks.34.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.2.bias": "blocks.34.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.time_stack.in_layers.2.weight": "blocks.34.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.0.bias": "blocks.34.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.0.weight": "blocks.34.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.3.bias": "blocks.34.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.time_stack.out_layers.3.weight": "blocks.34.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "blocks.37.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "blocks.37.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "blocks.37.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "blocks.37.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "blocks.40.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "blocks.40.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.time_mixer.mix_factor": "blocks.40.mix_factor",
+ "model.diffusion_model.input_blocks.5.1.time_pos_embed.0.bias": "blocks.39.positional_embedding_proj.0.bias",
+ "model.diffusion_model.input_blocks.5.1.time_pos_embed.0.weight": "blocks.39.positional_embedding_proj.0.weight",
+ "model.diffusion_model.input_blocks.5.1.time_pos_embed.2.bias": "blocks.39.positional_embedding_proj.2.bias",
+ "model.diffusion_model.input_blocks.5.1.time_pos_embed.2.weight": "blocks.39.positional_embedding_proj.2.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_k.weight": "blocks.39.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_out.0.bias": "blocks.39.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_out.0.weight": "blocks.39.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_q.weight": "blocks.39.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn1.to_v.weight": "blocks.39.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_k.weight": "blocks.39.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_out.0.bias": "blocks.39.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_out.0.weight": "blocks.39.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_q.weight": "blocks.39.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.attn2.to_v.weight": "blocks.39.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.0.proj.bias": "blocks.39.act_fn_out.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.0.proj.weight": "blocks.39.act_fn_out.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.2.bias": "blocks.39.ff_out.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff.net.2.weight": "blocks.39.ff_out.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.39.act_fn_in.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.39.act_fn_in.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.2.bias": "blocks.39.ff_in.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.ff_in.net.2.weight": "blocks.39.ff_in.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm1.bias": "blocks.39.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm1.weight": "blocks.39.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm2.bias": "blocks.39.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm2.weight": "blocks.39.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm3.bias": "blocks.39.norm_out.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm3.weight": "blocks.39.norm_out.weight",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm_in.bias": "blocks.39.norm_in.bias",
+ "model.diffusion_model.input_blocks.5.1.time_stack.0.norm_in.weight": "blocks.39.norm_in.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.37.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.37.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.37.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.37.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.37.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.37.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.37.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.37.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.37.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.37.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.37.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.37.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.37.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.37.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.37.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.37.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.37.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.37.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.37.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.37.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "blocks.42.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "blocks.42.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "blocks.44.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "blocks.44.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "blocks.44.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "blocks.44.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "blocks.44.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "blocks.44.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "blocks.44.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "blocks.44.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "blocks.44.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "blocks.44.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "blocks.44.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "blocks.44.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.0.time_mixer.mix_factor": "blocks.47.mix_factor",
+ "model.diffusion_model.input_blocks.7.0.time_stack.emb_layers.1.bias": "blocks.46.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.time_stack.emb_layers.1.weight": "blocks.46.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.time_stack.in_layers.0.bias": "blocks.46.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.time_stack.in_layers.0.weight": "blocks.46.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.time_stack.in_layers.2.bias": "blocks.46.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.time_stack.in_layers.2.weight": "blocks.46.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.0.bias": "blocks.46.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.0.weight": "blocks.46.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.3.bias": "blocks.46.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.time_stack.out_layers.3.weight": "blocks.46.conv2.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "blocks.49.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "blocks.49.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "blocks.49.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "blocks.49.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "blocks.52.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "blocks.52.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.time_mixer.mix_factor": "blocks.52.mix_factor",
+ "model.diffusion_model.input_blocks.7.1.time_pos_embed.0.bias": "blocks.51.positional_embedding_proj.0.bias",
+ "model.diffusion_model.input_blocks.7.1.time_pos_embed.0.weight": "blocks.51.positional_embedding_proj.0.weight",
+ "model.diffusion_model.input_blocks.7.1.time_pos_embed.2.bias": "blocks.51.positional_embedding_proj.2.bias",
+ "model.diffusion_model.input_blocks.7.1.time_pos_embed.2.weight": "blocks.51.positional_embedding_proj.2.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_k.weight": "blocks.51.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_out.0.bias": "blocks.51.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_out.0.weight": "blocks.51.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_q.weight": "blocks.51.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn1.to_v.weight": "blocks.51.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_k.weight": "blocks.51.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_out.0.bias": "blocks.51.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_out.0.weight": "blocks.51.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_q.weight": "blocks.51.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.attn2.to_v.weight": "blocks.51.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.0.proj.bias": "blocks.51.act_fn_out.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.0.proj.weight": "blocks.51.act_fn_out.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.2.bias": "blocks.51.ff_out.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff.net.2.weight": "blocks.51.ff_out.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.51.act_fn_in.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.51.act_fn_in.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.2.bias": "blocks.51.ff_in.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.ff_in.net.2.weight": "blocks.51.ff_in.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm1.bias": "blocks.51.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm1.weight": "blocks.51.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm2.bias": "blocks.51.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm2.weight": "blocks.51.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm3.bias": "blocks.51.norm_out.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm3.weight": "blocks.51.norm_out.weight",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm_in.bias": "blocks.51.norm_in.bias",
+ "model.diffusion_model.input_blocks.7.1.time_stack.0.norm_in.weight": "blocks.51.norm_in.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.49.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.49.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.49.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.49.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.49.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.49.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.49.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.49.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.49.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.49.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.49.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.49.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.49.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.49.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.49.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.49.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.49.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.49.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.49.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.49.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "blocks.54.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "blocks.54.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "blocks.54.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "blocks.54.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "blocks.54.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "blocks.54.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "blocks.54.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "blocks.54.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "blocks.54.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "blocks.54.conv2.weight",
+ "model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor": "blocks.57.mix_factor",
+ "model.diffusion_model.input_blocks.8.0.time_stack.emb_layers.1.bias": "blocks.56.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.time_stack.emb_layers.1.weight": "blocks.56.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.0.bias": "blocks.56.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.0.weight": "blocks.56.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.2.bias": "blocks.56.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.time_stack.in_layers.2.weight": "blocks.56.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.0.bias": "blocks.56.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.0.weight": "blocks.56.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.3.bias": "blocks.56.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.time_stack.out_layers.3.weight": "blocks.56.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "blocks.59.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "blocks.59.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "blocks.59.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "blocks.59.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "blocks.62.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "blocks.62.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.time_mixer.mix_factor": "blocks.62.mix_factor",
+ "model.diffusion_model.input_blocks.8.1.time_pos_embed.0.bias": "blocks.61.positional_embedding_proj.0.bias",
+ "model.diffusion_model.input_blocks.8.1.time_pos_embed.0.weight": "blocks.61.positional_embedding_proj.0.weight",
+ "model.diffusion_model.input_blocks.8.1.time_pos_embed.2.bias": "blocks.61.positional_embedding_proj.2.bias",
+ "model.diffusion_model.input_blocks.8.1.time_pos_embed.2.weight": "blocks.61.positional_embedding_proj.2.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_k.weight": "blocks.61.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_out.0.bias": "blocks.61.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_out.0.weight": "blocks.61.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_q.weight": "blocks.61.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn1.to_v.weight": "blocks.61.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_k.weight": "blocks.61.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_out.0.bias": "blocks.61.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_out.0.weight": "blocks.61.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_q.weight": "blocks.61.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.attn2.to_v.weight": "blocks.61.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.0.proj.bias": "blocks.61.act_fn_out.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.0.proj.weight": "blocks.61.act_fn_out.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.2.bias": "blocks.61.ff_out.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff.net.2.weight": "blocks.61.ff_out.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.61.act_fn_in.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.61.act_fn_in.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.2.bias": "blocks.61.ff_in.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.ff_in.net.2.weight": "blocks.61.ff_in.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm1.bias": "blocks.61.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm1.weight": "blocks.61.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm2.bias": "blocks.61.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm2.weight": "blocks.61.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm3.bias": "blocks.61.norm_out.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm3.weight": "blocks.61.norm_out.weight",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm_in.bias": "blocks.61.norm_in.bias",
+ "model.diffusion_model.input_blocks.8.1.time_stack.0.norm_in.weight": "blocks.61.norm_in.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.59.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.59.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.59.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.59.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.59.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.59.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.59.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.59.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.59.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.59.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.59.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.59.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.59.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.59.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.59.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.59.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.59.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.59.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.59.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.59.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.9.0.op.bias": "blocks.64.conv.bias",
+ "model.diffusion_model.input_blocks.9.0.op.weight": "blocks.64.conv.weight",
+ "model.diffusion_model.label_emb.0.0.bias": "add_time_embedding.0.bias",
+ "model.diffusion_model.label_emb.0.0.weight": "add_time_embedding.0.weight",
+ "model.diffusion_model.label_emb.0.2.bias": "add_time_embedding.2.bias",
+ "model.diffusion_model.label_emb.0.2.weight": "add_time_embedding.2.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "blocks.76.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "blocks.76.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "blocks.76.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "blocks.76.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "blocks.76.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "blocks.76.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "blocks.76.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "blocks.76.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "blocks.76.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "blocks.76.conv2.weight",
+ "model.diffusion_model.middle_block.0.time_mixer.mix_factor": "blocks.79.mix_factor",
+ "model.diffusion_model.middle_block.0.time_stack.emb_layers.1.bias": "blocks.78.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.time_stack.emb_layers.1.weight": "blocks.78.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.time_stack.in_layers.0.bias": "blocks.78.norm1.bias",
+ "model.diffusion_model.middle_block.0.time_stack.in_layers.0.weight": "blocks.78.norm1.weight",
+ "model.diffusion_model.middle_block.0.time_stack.in_layers.2.bias": "blocks.78.conv1.bias",
+ "model.diffusion_model.middle_block.0.time_stack.in_layers.2.weight": "blocks.78.conv1.weight",
+ "model.diffusion_model.middle_block.0.time_stack.out_layers.0.bias": "blocks.78.norm2.bias",
+ "model.diffusion_model.middle_block.0.time_stack.out_layers.0.weight": "blocks.78.norm2.weight",
+ "model.diffusion_model.middle_block.0.time_stack.out_layers.3.bias": "blocks.78.conv2.bias",
+ "model.diffusion_model.middle_block.0.time_stack.out_layers.3.weight": "blocks.78.conv2.weight",
+ "model.diffusion_model.middle_block.1.norm.bias": "blocks.81.norm.bias",
+ "model.diffusion_model.middle_block.1.norm.weight": "blocks.81.norm.weight",
+ "model.diffusion_model.middle_block.1.proj_in.bias": "blocks.81.proj_in.bias",
+ "model.diffusion_model.middle_block.1.proj_in.weight": "blocks.81.proj_in.weight",
+ "model.diffusion_model.middle_block.1.proj_out.bias": "blocks.84.proj.bias",
+ "model.diffusion_model.middle_block.1.proj_out.weight": "blocks.84.proj.weight",
+ "model.diffusion_model.middle_block.1.time_mixer.mix_factor": "blocks.84.mix_factor",
+ "model.diffusion_model.middle_block.1.time_pos_embed.0.bias": "blocks.83.positional_embedding_proj.0.bias",
+ "model.diffusion_model.middle_block.1.time_pos_embed.0.weight": "blocks.83.positional_embedding_proj.0.weight",
+ "model.diffusion_model.middle_block.1.time_pos_embed.2.bias": "blocks.83.positional_embedding_proj.2.bias",
+ "model.diffusion_model.middle_block.1.time_pos_embed.2.weight": "blocks.83.positional_embedding_proj.2.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_k.weight": "blocks.83.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_out.0.bias": "blocks.83.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_out.0.weight": "blocks.83.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_q.weight": "blocks.83.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn1.to_v.weight": "blocks.83.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_k.weight": "blocks.83.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_out.0.bias": "blocks.83.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_out.0.weight": "blocks.83.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_q.weight": "blocks.83.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.attn2.to_v.weight": "blocks.83.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff.net.0.proj.bias": "blocks.83.act_fn_out.proj.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff.net.0.proj.weight": "blocks.83.act_fn_out.proj.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff.net.2.bias": "blocks.83.ff_out.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff.net.2.weight": "blocks.83.ff_out.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.83.act_fn_in.proj.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.83.act_fn_in.proj.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.2.bias": "blocks.83.ff_in.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.ff_in.net.2.weight": "blocks.83.ff_in.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm1.bias": "blocks.83.norm1.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm1.weight": "blocks.83.norm1.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm2.bias": "blocks.83.norm2.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm2.weight": "blocks.83.norm2.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm3.bias": "blocks.83.norm_out.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm3.weight": "blocks.83.norm_out.weight",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm_in.bias": "blocks.83.norm_in.bias",
+ "model.diffusion_model.middle_block.1.time_stack.0.norm_in.weight": "blocks.83.norm_in.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.81.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.81.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.81.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.81.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.81.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.81.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.81.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.81.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.81.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.81.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.81.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.81.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.81.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.81.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.81.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.81.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.81.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.81.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.81.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.81.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.middle_block.2.emb_layers.1.bias": "blocks.85.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "blocks.85.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.in_layers.0.bias": "blocks.85.norm1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.0.weight": "blocks.85.norm1.weight",
+ "model.diffusion_model.middle_block.2.in_layers.2.bias": "blocks.85.conv1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.2.weight": "blocks.85.conv1.weight",
+ "model.diffusion_model.middle_block.2.out_layers.0.bias": "blocks.85.norm2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.0.weight": "blocks.85.norm2.weight",
+ "model.diffusion_model.middle_block.2.out_layers.3.bias": "blocks.85.conv2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.3.weight": "blocks.85.conv2.weight",
+ "model.diffusion_model.middle_block.2.time_mixer.mix_factor": "blocks.88.mix_factor",
+ "model.diffusion_model.middle_block.2.time_stack.emb_layers.1.bias": "blocks.87.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.time_stack.emb_layers.1.weight": "blocks.87.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.time_stack.in_layers.0.bias": "blocks.87.norm1.bias",
+ "model.diffusion_model.middle_block.2.time_stack.in_layers.0.weight": "blocks.87.norm1.weight",
+ "model.diffusion_model.middle_block.2.time_stack.in_layers.2.bias": "blocks.87.conv1.bias",
+ "model.diffusion_model.middle_block.2.time_stack.in_layers.2.weight": "blocks.87.conv1.weight",
+ "model.diffusion_model.middle_block.2.time_stack.out_layers.0.bias": "blocks.87.norm2.bias",
+ "model.diffusion_model.middle_block.2.time_stack.out_layers.0.weight": "blocks.87.norm2.weight",
+ "model.diffusion_model.middle_block.2.time_stack.out_layers.3.bias": "blocks.87.conv2.bias",
+ "model.diffusion_model.middle_block.2.time_stack.out_layers.3.weight": "blocks.87.conv2.weight",
+ "model.diffusion_model.out.0.bias": "conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "conv_out.bias",
+ "model.diffusion_model.out.2.weight": "conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "blocks.90.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "blocks.90.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "blocks.90.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "blocks.90.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "blocks.90.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "blocks.90.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "blocks.90.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "blocks.90.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "blocks.90.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "blocks.90.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "blocks.90.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "blocks.90.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.0.0.time_mixer.mix_factor": "blocks.93.mix_factor",
+ "model.diffusion_model.output_blocks.0.0.time_stack.emb_layers.1.bias": "blocks.92.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.time_stack.emb_layers.1.weight": "blocks.92.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.0.bias": "blocks.92.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.0.weight": "blocks.92.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.2.bias": "blocks.92.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.time_stack.in_layers.2.weight": "blocks.92.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.time_stack.out_layers.0.bias": "blocks.92.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.time_stack.out_layers.0.weight": "blocks.92.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.time_stack.out_layers.3.bias": "blocks.92.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.time_stack.out_layers.3.weight": "blocks.92.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "blocks.95.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "blocks.95.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "blocks.95.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "blocks.95.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "blocks.95.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "blocks.95.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "blocks.95.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "blocks.95.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "blocks.95.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "blocks.95.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "blocks.95.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "blocks.95.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.0.time_mixer.mix_factor": "blocks.98.mix_factor",
+ "model.diffusion_model.output_blocks.1.0.time_stack.emb_layers.1.bias": "blocks.97.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.time_stack.emb_layers.1.weight": "blocks.97.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.0.bias": "blocks.97.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.0.weight": "blocks.97.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.2.bias": "blocks.97.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.time_stack.in_layers.2.weight": "blocks.97.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.0.bias": "blocks.97.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.0.weight": "blocks.97.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.3.bias": "blocks.97.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.time_stack.out_layers.3.weight": "blocks.97.conv2.weight",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "blocks.178.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "blocks.178.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "blocks.178.norm1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "blocks.178.norm1.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "blocks.178.conv1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "blocks.178.conv1.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "blocks.178.norm2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "blocks.178.norm2.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "blocks.178.conv2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "blocks.178.conv2.weight",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "blocks.178.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "blocks.178.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.0.time_mixer.mix_factor": "blocks.181.mix_factor",
+ "model.diffusion_model.output_blocks.10.0.time_stack.emb_layers.1.bias": "blocks.180.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.10.0.time_stack.emb_layers.1.weight": "blocks.180.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.0.bias": "blocks.180.norm1.bias",
+ "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.0.weight": "blocks.180.norm1.weight",
+ "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.2.bias": "blocks.180.conv1.bias",
+ "model.diffusion_model.output_blocks.10.0.time_stack.in_layers.2.weight": "blocks.180.conv1.weight",
+ "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.0.bias": "blocks.180.norm2.bias",
+ "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.0.weight": "blocks.180.norm2.weight",
+ "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.3.bias": "blocks.180.conv2.bias",
+ "model.diffusion_model.output_blocks.10.0.time_stack.out_layers.3.weight": "blocks.180.conv2.weight",
+ "model.diffusion_model.output_blocks.10.1.norm.bias": "blocks.183.norm.bias",
+ "model.diffusion_model.output_blocks.10.1.norm.weight": "blocks.183.norm.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_in.bias": "blocks.183.proj_in.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_in.weight": "blocks.183.proj_in.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_out.bias": "blocks.186.proj.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_out.weight": "blocks.186.proj.weight",
+ "model.diffusion_model.output_blocks.10.1.time_mixer.mix_factor": "blocks.186.mix_factor",
+ "model.diffusion_model.output_blocks.10.1.time_pos_embed.0.bias": "blocks.185.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.10.1.time_pos_embed.0.weight": "blocks.185.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.10.1.time_pos_embed.2.bias": "blocks.185.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.10.1.time_pos_embed.2.weight": "blocks.185.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_k.weight": "blocks.185.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_out.0.bias": "blocks.185.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_out.0.weight": "blocks.185.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_q.weight": "blocks.185.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn1.to_v.weight": "blocks.185.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_k.weight": "blocks.185.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_out.0.bias": "blocks.185.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_out.0.weight": "blocks.185.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_q.weight": "blocks.185.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.attn2.to_v.weight": "blocks.185.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.0.proj.bias": "blocks.185.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.0.proj.weight": "blocks.185.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.2.bias": "blocks.185.ff_out.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff.net.2.weight": "blocks.185.ff_out.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.185.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.185.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.2.bias": "blocks.185.ff_in.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.ff_in.net.2.weight": "blocks.185.ff_in.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm1.bias": "blocks.185.norm1.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm1.weight": "blocks.185.norm1.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm2.bias": "blocks.185.norm2.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm2.weight": "blocks.185.norm2.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm3.bias": "blocks.185.norm_out.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm3.weight": "blocks.185.norm_out.weight",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm_in.bias": "blocks.185.norm_in.bias",
+ "model.diffusion_model.output_blocks.10.1.time_stack.0.norm_in.weight": "blocks.185.norm_in.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "blocks.183.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.183.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.183.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "blocks.183.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "blocks.183.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "blocks.183.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.183.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.183.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "blocks.183.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "blocks.183.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.183.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.183.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "blocks.183.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "blocks.183.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "blocks.183.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "blocks.183.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "blocks.183.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "blocks.183.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "blocks.183.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "blocks.183.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "blocks.188.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "blocks.188.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "blocks.188.norm1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "blocks.188.norm1.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "blocks.188.conv1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "blocks.188.conv1.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "blocks.188.norm2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "blocks.188.norm2.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "blocks.188.conv2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "blocks.188.conv2.weight",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "blocks.188.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "blocks.188.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.11.0.time_mixer.mix_factor": "blocks.191.mix_factor",
+ "model.diffusion_model.output_blocks.11.0.time_stack.emb_layers.1.bias": "blocks.190.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.11.0.time_stack.emb_layers.1.weight": "blocks.190.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.0.bias": "blocks.190.norm1.bias",
+ "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.0.weight": "blocks.190.norm1.weight",
+ "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.2.bias": "blocks.190.conv1.bias",
+ "model.diffusion_model.output_blocks.11.0.time_stack.in_layers.2.weight": "blocks.190.conv1.weight",
+ "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.0.bias": "blocks.190.norm2.bias",
+ "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.0.weight": "blocks.190.norm2.weight",
+ "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.3.bias": "blocks.190.conv2.bias",
+ "model.diffusion_model.output_blocks.11.0.time_stack.out_layers.3.weight": "blocks.190.conv2.weight",
+ "model.diffusion_model.output_blocks.11.1.norm.bias": "blocks.193.norm.bias",
+ "model.diffusion_model.output_blocks.11.1.norm.weight": "blocks.193.norm.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_in.bias": "blocks.193.proj_in.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_in.weight": "blocks.193.proj_in.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_out.bias": "blocks.196.proj.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_out.weight": "blocks.196.proj.weight",
+ "model.diffusion_model.output_blocks.11.1.time_mixer.mix_factor": "blocks.196.mix_factor",
+ "model.diffusion_model.output_blocks.11.1.time_pos_embed.0.bias": "blocks.195.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.11.1.time_pos_embed.0.weight": "blocks.195.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.11.1.time_pos_embed.2.bias": "blocks.195.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.11.1.time_pos_embed.2.weight": "blocks.195.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_k.weight": "blocks.195.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_out.0.bias": "blocks.195.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_out.0.weight": "blocks.195.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_q.weight": "blocks.195.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn1.to_v.weight": "blocks.195.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_k.weight": "blocks.195.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_out.0.bias": "blocks.195.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_out.0.weight": "blocks.195.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_q.weight": "blocks.195.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.attn2.to_v.weight": "blocks.195.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.0.proj.bias": "blocks.195.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.0.proj.weight": "blocks.195.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.2.bias": "blocks.195.ff_out.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff.net.2.weight": "blocks.195.ff_out.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.195.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.195.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.2.bias": "blocks.195.ff_in.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.ff_in.net.2.weight": "blocks.195.ff_in.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm1.bias": "blocks.195.norm1.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm1.weight": "blocks.195.norm1.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm2.bias": "blocks.195.norm2.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm2.weight": "blocks.195.norm2.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm3.bias": "blocks.195.norm_out.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm3.weight": "blocks.195.norm_out.weight",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm_in.bias": "blocks.195.norm_in.bias",
+ "model.diffusion_model.output_blocks.11.1.time_stack.0.norm_in.weight": "blocks.195.norm_in.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "blocks.193.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.193.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.193.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "blocks.193.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "blocks.193.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "blocks.193.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.193.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.193.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "blocks.193.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "blocks.193.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.193.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.193.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "blocks.193.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "blocks.193.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "blocks.193.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "blocks.193.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "blocks.193.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "blocks.193.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "blocks.193.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "blocks.193.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "blocks.100.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "blocks.100.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "blocks.100.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "blocks.100.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "blocks.100.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "blocks.100.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "blocks.100.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "blocks.100.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "blocks.100.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "blocks.100.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "blocks.100.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "blocks.100.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.0.time_mixer.mix_factor": "blocks.103.mix_factor",
+ "model.diffusion_model.output_blocks.2.0.time_stack.emb_layers.1.bias": "blocks.102.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.time_stack.emb_layers.1.weight": "blocks.102.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.0.bias": "blocks.102.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.0.weight": "blocks.102.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.2.bias": "blocks.102.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.time_stack.in_layers.2.weight": "blocks.102.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.0.bias": "blocks.102.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.0.weight": "blocks.102.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.3.bias": "blocks.102.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.time_stack.out_layers.3.weight": "blocks.102.conv2.weight",
+ "model.diffusion_model.output_blocks.2.1.conv.bias": "blocks.104.conv.bias",
+ "model.diffusion_model.output_blocks.2.1.conv.weight": "blocks.104.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "blocks.106.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "blocks.106.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "blocks.106.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "blocks.106.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "blocks.106.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "blocks.106.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "blocks.106.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "blocks.106.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "blocks.106.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "blocks.106.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "blocks.106.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "blocks.106.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.0.time_mixer.mix_factor": "blocks.109.mix_factor",
+ "model.diffusion_model.output_blocks.3.0.time_stack.emb_layers.1.bias": "blocks.108.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.time_stack.emb_layers.1.weight": "blocks.108.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.0.bias": "blocks.108.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.0.weight": "blocks.108.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.2.bias": "blocks.108.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.time_stack.in_layers.2.weight": "blocks.108.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.0.bias": "blocks.108.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.0.weight": "blocks.108.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.3.bias": "blocks.108.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.time_stack.out_layers.3.weight": "blocks.108.conv2.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "blocks.111.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "blocks.111.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "blocks.111.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "blocks.111.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "blocks.114.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "blocks.114.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.time_mixer.mix_factor": "blocks.114.mix_factor",
+ "model.diffusion_model.output_blocks.3.1.time_pos_embed.0.bias": "blocks.113.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.3.1.time_pos_embed.0.weight": "blocks.113.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.3.1.time_pos_embed.2.bias": "blocks.113.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.3.1.time_pos_embed.2.weight": "blocks.113.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_k.weight": "blocks.113.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_out.0.bias": "blocks.113.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_out.0.weight": "blocks.113.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_q.weight": "blocks.113.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn1.to_v.weight": "blocks.113.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_k.weight": "blocks.113.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_out.0.bias": "blocks.113.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_out.0.weight": "blocks.113.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_q.weight": "blocks.113.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.attn2.to_v.weight": "blocks.113.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.0.proj.bias": "blocks.113.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.0.proj.weight": "blocks.113.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.2.bias": "blocks.113.ff_out.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff.net.2.weight": "blocks.113.ff_out.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.113.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.113.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.2.bias": "blocks.113.ff_in.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.ff_in.net.2.weight": "blocks.113.ff_in.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm1.bias": "blocks.113.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm1.weight": "blocks.113.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm2.bias": "blocks.113.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm2.weight": "blocks.113.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm3.bias": "blocks.113.norm_out.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm3.weight": "blocks.113.norm_out.weight",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm_in.bias": "blocks.113.norm_in.bias",
+ "model.diffusion_model.output_blocks.3.1.time_stack.0.norm_in.weight": "blocks.113.norm_in.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "blocks.111.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.111.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.111.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "blocks.111.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "blocks.111.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "blocks.111.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.111.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.111.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "blocks.111.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "blocks.111.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.111.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.111.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "blocks.111.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "blocks.111.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "blocks.111.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "blocks.111.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "blocks.111.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "blocks.111.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "blocks.111.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "blocks.111.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "blocks.116.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "blocks.116.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "blocks.116.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "blocks.116.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "blocks.116.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "blocks.116.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "blocks.116.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "blocks.116.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "blocks.116.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "blocks.116.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "blocks.116.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "blocks.116.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.0.time_mixer.mix_factor": "blocks.119.mix_factor",
+ "model.diffusion_model.output_blocks.4.0.time_stack.emb_layers.1.bias": "blocks.118.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.time_stack.emb_layers.1.weight": "blocks.118.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.0.bias": "blocks.118.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.0.weight": "blocks.118.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.2.bias": "blocks.118.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.time_stack.in_layers.2.weight": "blocks.118.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.0.bias": "blocks.118.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.0.weight": "blocks.118.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.3.bias": "blocks.118.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.time_stack.out_layers.3.weight": "blocks.118.conv2.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "blocks.121.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "blocks.121.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "blocks.121.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "blocks.121.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "blocks.124.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "blocks.124.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.time_mixer.mix_factor": "blocks.124.mix_factor",
+ "model.diffusion_model.output_blocks.4.1.time_pos_embed.0.bias": "blocks.123.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.4.1.time_pos_embed.0.weight": "blocks.123.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.4.1.time_pos_embed.2.bias": "blocks.123.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.4.1.time_pos_embed.2.weight": "blocks.123.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_k.weight": "blocks.123.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_out.0.bias": "blocks.123.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_out.0.weight": "blocks.123.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_q.weight": "blocks.123.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn1.to_v.weight": "blocks.123.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_k.weight": "blocks.123.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_out.0.bias": "blocks.123.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_out.0.weight": "blocks.123.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_q.weight": "blocks.123.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.attn2.to_v.weight": "blocks.123.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.0.proj.bias": "blocks.123.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.0.proj.weight": "blocks.123.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.2.bias": "blocks.123.ff_out.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff.net.2.weight": "blocks.123.ff_out.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.123.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.123.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.2.bias": "blocks.123.ff_in.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.ff_in.net.2.weight": "blocks.123.ff_in.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm1.bias": "blocks.123.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm1.weight": "blocks.123.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm2.bias": "blocks.123.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm2.weight": "blocks.123.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm3.bias": "blocks.123.norm_out.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm3.weight": "blocks.123.norm_out.weight",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm_in.bias": "blocks.123.norm_in.bias",
+ "model.diffusion_model.output_blocks.4.1.time_stack.0.norm_in.weight": "blocks.123.norm_in.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.121.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.121.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.121.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.121.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.121.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.121.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.121.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.121.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.121.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.121.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.121.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.121.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.121.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.121.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.121.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.121.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.121.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.121.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.121.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.121.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "blocks.126.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "blocks.126.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "blocks.126.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "blocks.126.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "blocks.126.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "blocks.126.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "blocks.126.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "blocks.126.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "blocks.126.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "blocks.126.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "blocks.126.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "blocks.126.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.0.time_mixer.mix_factor": "blocks.129.mix_factor",
+ "model.diffusion_model.output_blocks.5.0.time_stack.emb_layers.1.bias": "blocks.128.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.time_stack.emb_layers.1.weight": "blocks.128.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.0.bias": "blocks.128.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.0.weight": "blocks.128.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.2.bias": "blocks.128.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.time_stack.in_layers.2.weight": "blocks.128.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.0.bias": "blocks.128.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.0.weight": "blocks.128.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.3.bias": "blocks.128.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.time_stack.out_layers.3.weight": "blocks.128.conv2.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "blocks.131.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "blocks.131.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "blocks.131.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "blocks.131.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "blocks.134.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "blocks.134.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.time_mixer.mix_factor": "blocks.134.mix_factor",
+ "model.diffusion_model.output_blocks.5.1.time_pos_embed.0.bias": "blocks.133.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.5.1.time_pos_embed.0.weight": "blocks.133.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.5.1.time_pos_embed.2.bias": "blocks.133.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.5.1.time_pos_embed.2.weight": "blocks.133.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_k.weight": "blocks.133.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_out.0.bias": "blocks.133.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_out.0.weight": "blocks.133.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_q.weight": "blocks.133.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn1.to_v.weight": "blocks.133.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_k.weight": "blocks.133.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_out.0.bias": "blocks.133.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_out.0.weight": "blocks.133.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_q.weight": "blocks.133.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.attn2.to_v.weight": "blocks.133.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.0.proj.bias": "blocks.133.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.0.proj.weight": "blocks.133.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.2.bias": "blocks.133.ff_out.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff.net.2.weight": "blocks.133.ff_out.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.133.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.133.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.2.bias": "blocks.133.ff_in.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.ff_in.net.2.weight": "blocks.133.ff_in.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm1.bias": "blocks.133.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm1.weight": "blocks.133.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm2.bias": "blocks.133.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm2.weight": "blocks.133.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm3.bias": "blocks.133.norm_out.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm3.weight": "blocks.133.norm_out.weight",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm_in.bias": "blocks.133.norm_in.bias",
+ "model.diffusion_model.output_blocks.5.1.time_stack.0.norm_in.weight": "blocks.133.norm_in.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.131.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.131.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.131.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.131.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.131.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.131.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.131.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.131.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.131.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.131.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.131.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.131.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.131.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.131.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.131.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.131.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.131.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.131.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.131.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.131.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "blocks.135.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "blocks.135.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "blocks.137.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "blocks.137.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "blocks.137.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "blocks.137.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "blocks.137.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "blocks.137.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "blocks.137.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "blocks.137.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "blocks.137.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "blocks.137.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "blocks.137.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "blocks.137.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.6.0.time_mixer.mix_factor": "blocks.140.mix_factor",
+ "model.diffusion_model.output_blocks.6.0.time_stack.emb_layers.1.bias": "blocks.139.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.time_stack.emb_layers.1.weight": "blocks.139.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.0.bias": "blocks.139.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.0.weight": "blocks.139.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.2.bias": "blocks.139.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.time_stack.in_layers.2.weight": "blocks.139.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.0.bias": "blocks.139.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.0.weight": "blocks.139.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.3.bias": "blocks.139.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.time_stack.out_layers.3.weight": "blocks.139.conv2.weight",
+ "model.diffusion_model.output_blocks.6.1.norm.bias": "blocks.142.norm.bias",
+ "model.diffusion_model.output_blocks.6.1.norm.weight": "blocks.142.norm.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_in.bias": "blocks.142.proj_in.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_in.weight": "blocks.142.proj_in.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_out.bias": "blocks.145.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_out.weight": "blocks.145.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.time_mixer.mix_factor": "blocks.145.mix_factor",
+ "model.diffusion_model.output_blocks.6.1.time_pos_embed.0.bias": "blocks.144.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.6.1.time_pos_embed.0.weight": "blocks.144.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.6.1.time_pos_embed.2.bias": "blocks.144.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.6.1.time_pos_embed.2.weight": "blocks.144.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_k.weight": "blocks.144.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_out.0.bias": "blocks.144.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_out.0.weight": "blocks.144.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_q.weight": "blocks.144.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn1.to_v.weight": "blocks.144.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_k.weight": "blocks.144.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_out.0.bias": "blocks.144.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_out.0.weight": "blocks.144.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_q.weight": "blocks.144.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.attn2.to_v.weight": "blocks.144.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.0.proj.bias": "blocks.144.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.0.proj.weight": "blocks.144.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.2.bias": "blocks.144.ff_out.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff.net.2.weight": "blocks.144.ff_out.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.144.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.144.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.2.bias": "blocks.144.ff_in.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.ff_in.net.2.weight": "blocks.144.ff_in.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm1.bias": "blocks.144.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm1.weight": "blocks.144.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm2.bias": "blocks.144.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm2.weight": "blocks.144.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm3.bias": "blocks.144.norm_out.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm3.weight": "blocks.144.norm_out.weight",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm_in.bias": "blocks.144.norm_in.bias",
+ "model.diffusion_model.output_blocks.6.1.time_stack.0.norm_in.weight": "blocks.144.norm_in.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "blocks.142.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.142.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.142.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "blocks.142.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "blocks.142.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "blocks.142.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.142.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.142.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "blocks.142.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "blocks.142.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.142.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.142.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "blocks.142.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "blocks.142.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "blocks.142.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "blocks.142.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "blocks.142.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "blocks.142.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "blocks.142.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "blocks.142.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "blocks.147.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "blocks.147.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "blocks.147.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "blocks.147.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "blocks.147.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "blocks.147.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "blocks.147.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "blocks.147.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "blocks.147.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "blocks.147.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "blocks.147.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "blocks.147.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.0.time_mixer.mix_factor": "blocks.150.mix_factor",
+ "model.diffusion_model.output_blocks.7.0.time_stack.emb_layers.1.bias": "blocks.149.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.time_stack.emb_layers.1.weight": "blocks.149.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.0.bias": "blocks.149.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.0.weight": "blocks.149.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.2.bias": "blocks.149.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.time_stack.in_layers.2.weight": "blocks.149.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.0.bias": "blocks.149.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.0.weight": "blocks.149.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.3.bias": "blocks.149.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.time_stack.out_layers.3.weight": "blocks.149.conv2.weight",
+ "model.diffusion_model.output_blocks.7.1.norm.bias": "blocks.152.norm.bias",
+ "model.diffusion_model.output_blocks.7.1.norm.weight": "blocks.152.norm.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_in.bias": "blocks.152.proj_in.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_in.weight": "blocks.152.proj_in.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_out.bias": "blocks.155.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_out.weight": "blocks.155.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.time_mixer.mix_factor": "blocks.155.mix_factor",
+ "model.diffusion_model.output_blocks.7.1.time_pos_embed.0.bias": "blocks.154.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.7.1.time_pos_embed.0.weight": "blocks.154.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.7.1.time_pos_embed.2.bias": "blocks.154.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.7.1.time_pos_embed.2.weight": "blocks.154.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_k.weight": "blocks.154.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_out.0.bias": "blocks.154.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_out.0.weight": "blocks.154.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_q.weight": "blocks.154.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn1.to_v.weight": "blocks.154.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_k.weight": "blocks.154.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_out.0.bias": "blocks.154.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_out.0.weight": "blocks.154.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_q.weight": "blocks.154.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.attn2.to_v.weight": "blocks.154.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.0.proj.bias": "blocks.154.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.0.proj.weight": "blocks.154.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.2.bias": "blocks.154.ff_out.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff.net.2.weight": "blocks.154.ff_out.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.154.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.154.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.2.bias": "blocks.154.ff_in.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.ff_in.net.2.weight": "blocks.154.ff_in.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm1.bias": "blocks.154.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm1.weight": "blocks.154.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm2.bias": "blocks.154.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm2.weight": "blocks.154.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm3.bias": "blocks.154.norm_out.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm3.weight": "blocks.154.norm_out.weight",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm_in.bias": "blocks.154.norm_in.bias",
+ "model.diffusion_model.output_blocks.7.1.time_stack.0.norm_in.weight": "blocks.154.norm_in.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.152.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.152.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.152.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.152.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.152.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.152.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.152.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.152.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.152.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.152.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.152.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.152.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.152.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.152.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.152.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.152.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.152.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.152.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.152.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.152.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "blocks.157.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "blocks.157.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "blocks.157.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "blocks.157.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "blocks.157.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "blocks.157.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "blocks.157.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "blocks.157.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "blocks.157.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "blocks.157.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "blocks.157.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "blocks.157.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.0.time_mixer.mix_factor": "blocks.160.mix_factor",
+ "model.diffusion_model.output_blocks.8.0.time_stack.emb_layers.1.bias": "blocks.159.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.time_stack.emb_layers.1.weight": "blocks.159.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.0.bias": "blocks.159.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.0.weight": "blocks.159.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.2.bias": "blocks.159.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.time_stack.in_layers.2.weight": "blocks.159.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.0.bias": "blocks.159.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.0.weight": "blocks.159.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.3.bias": "blocks.159.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.time_stack.out_layers.3.weight": "blocks.159.conv2.weight",
+ "model.diffusion_model.output_blocks.8.1.norm.bias": "blocks.162.norm.bias",
+ "model.diffusion_model.output_blocks.8.1.norm.weight": "blocks.162.norm.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_in.bias": "blocks.162.proj_in.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_in.weight": "blocks.162.proj_in.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_out.bias": "blocks.165.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_out.weight": "blocks.165.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.time_mixer.mix_factor": "blocks.165.mix_factor",
+ "model.diffusion_model.output_blocks.8.1.time_pos_embed.0.bias": "blocks.164.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.8.1.time_pos_embed.0.weight": "blocks.164.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.8.1.time_pos_embed.2.bias": "blocks.164.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.8.1.time_pos_embed.2.weight": "blocks.164.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_k.weight": "blocks.164.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_out.0.bias": "blocks.164.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_out.0.weight": "blocks.164.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_q.weight": "blocks.164.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn1.to_v.weight": "blocks.164.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_k.weight": "blocks.164.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_out.0.bias": "blocks.164.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_out.0.weight": "blocks.164.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_q.weight": "blocks.164.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.attn2.to_v.weight": "blocks.164.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.0.proj.bias": "blocks.164.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.0.proj.weight": "blocks.164.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.2.bias": "blocks.164.ff_out.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff.net.2.weight": "blocks.164.ff_out.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.164.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.164.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.2.bias": "blocks.164.ff_in.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.ff_in.net.2.weight": "blocks.164.ff_in.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm1.bias": "blocks.164.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm1.weight": "blocks.164.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm2.bias": "blocks.164.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm2.weight": "blocks.164.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm3.bias": "blocks.164.norm_out.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm3.weight": "blocks.164.norm_out.weight",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm_in.bias": "blocks.164.norm_in.bias",
+ "model.diffusion_model.output_blocks.8.1.time_stack.0.norm_in.weight": "blocks.164.norm_in.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.162.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.162.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.162.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.162.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.162.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.162.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.162.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.162.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.162.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.162.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.162.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.162.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.162.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.162.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.162.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.162.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.162.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.162.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.162.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.162.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.2.conv.bias": "blocks.166.conv.bias",
+ "model.diffusion_model.output_blocks.8.2.conv.weight": "blocks.166.conv.weight",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "blocks.168.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "blocks.168.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "blocks.168.norm1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "blocks.168.norm1.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "blocks.168.conv1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "blocks.168.conv1.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "blocks.168.norm2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "blocks.168.norm2.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "blocks.168.conv2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "blocks.168.conv2.weight",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "blocks.168.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "blocks.168.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.9.0.time_mixer.mix_factor": "blocks.171.mix_factor",
+ "model.diffusion_model.output_blocks.9.0.time_stack.emb_layers.1.bias": "blocks.170.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.9.0.time_stack.emb_layers.1.weight": "blocks.170.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.0.bias": "blocks.170.norm1.bias",
+ "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.0.weight": "blocks.170.norm1.weight",
+ "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.2.bias": "blocks.170.conv1.bias",
+ "model.diffusion_model.output_blocks.9.0.time_stack.in_layers.2.weight": "blocks.170.conv1.weight",
+ "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.0.bias": "blocks.170.norm2.bias",
+ "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.0.weight": "blocks.170.norm2.weight",
+ "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.3.bias": "blocks.170.conv2.bias",
+ "model.diffusion_model.output_blocks.9.0.time_stack.out_layers.3.weight": "blocks.170.conv2.weight",
+ "model.diffusion_model.output_blocks.9.1.norm.bias": "blocks.173.norm.bias",
+ "model.diffusion_model.output_blocks.9.1.norm.weight": "blocks.173.norm.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_in.bias": "blocks.173.proj_in.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_in.weight": "blocks.173.proj_in.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_out.bias": "blocks.176.proj.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_out.weight": "blocks.176.proj.weight",
+ "model.diffusion_model.output_blocks.9.1.time_mixer.mix_factor": "blocks.176.mix_factor",
+ "model.diffusion_model.output_blocks.9.1.time_pos_embed.0.bias": "blocks.175.positional_embedding_proj.0.bias",
+ "model.diffusion_model.output_blocks.9.1.time_pos_embed.0.weight": "blocks.175.positional_embedding_proj.0.weight",
+ "model.diffusion_model.output_blocks.9.1.time_pos_embed.2.bias": "blocks.175.positional_embedding_proj.2.bias",
+ "model.diffusion_model.output_blocks.9.1.time_pos_embed.2.weight": "blocks.175.positional_embedding_proj.2.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_k.weight": "blocks.175.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_out.0.bias": "blocks.175.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_out.0.weight": "blocks.175.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_q.weight": "blocks.175.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn1.to_v.weight": "blocks.175.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_k.weight": "blocks.175.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_out.0.bias": "blocks.175.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_out.0.weight": "blocks.175.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_q.weight": "blocks.175.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.attn2.to_v.weight": "blocks.175.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.0.proj.bias": "blocks.175.act_fn_out.proj.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.0.proj.weight": "blocks.175.act_fn_out.proj.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.2.bias": "blocks.175.ff_out.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff.net.2.weight": "blocks.175.ff_out.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.0.proj.bias": "blocks.175.act_fn_in.proj.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.0.proj.weight": "blocks.175.act_fn_in.proj.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.2.bias": "blocks.175.ff_in.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.ff_in.net.2.weight": "blocks.175.ff_in.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm1.bias": "blocks.175.norm1.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm1.weight": "blocks.175.norm1.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm2.bias": "blocks.175.norm2.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm2.weight": "blocks.175.norm2.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm3.bias": "blocks.175.norm_out.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm3.weight": "blocks.175.norm_out.weight",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.bias": "blocks.175.norm_in.bias",
+ "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight": "blocks.175.norm_in.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "blocks.173.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.173.transformer_blocks.0.attn1.to_out.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.173.transformer_blocks.0.attn1.to_out.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "blocks.173.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "blocks.173.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "blocks.173.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.173.transformer_blocks.0.attn2.to_out.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.173.transformer_blocks.0.attn2.to_out.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "blocks.173.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "blocks.173.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.173.transformer_blocks.0.act_fn.proj.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.173.transformer_blocks.0.act_fn.proj.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "blocks.173.transformer_blocks.0.ff.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "blocks.173.transformer_blocks.0.ff.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "blocks.173.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "blocks.173.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "blocks.173.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "blocks.173.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "blocks.173.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "blocks.173.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.time_embed.0.bias": "time_embedding.0.bias",
+ "model.diffusion_model.time_embed.0.weight": "time_embedding.0.weight",
+ "model.diffusion_model.time_embed.2.bias": "time_embedding.2.bias",
+ "model.diffusion_model.time_embed.2.weight": "time_embedding.2.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
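+ # The spatial transformer proj_in / proj_out weights are stored with trailing 1x1 spatial dims; squeeze them so they fit the Linear projections.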
+ if ".proj_in." in name or ".proj_out." in name:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
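+ # Optionally add weights for positional_conv layers that do not exist in the source checkpoint.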
+ if add_positional_conv is not None:
+ extra_names = [
+ "blocks.7.positional_conv", "blocks.17.positional_conv", "blocks.29.positional_conv", "blocks.39.positional_conv",
+ "blocks.51.positional_conv", "blocks.61.positional_conv", "blocks.83.positional_conv", "blocks.113.positional_conv",
+ "blocks.123.positional_conv", "blocks.133.positional_conv", "blocks.144.positional_conv", "blocks.154.positional_conv",
+ "blocks.164.positional_conv", "blocks.175.positional_conv", "blocks.185.positional_conv", "blocks.195.positional_conv",
+ ]
+ extra_channels = [320, 320, 640, 640, 1280, 1280, 1280, 1280, 1280, 1280, 640, 640, 640, 320, 320, 320]
+ for name, channels in zip(extra_names, extra_channels):
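+ # Identity-initialised 3x3x3 kernels: only the centre tap (1, 1, 1) passes the input through, so the new temporal convolutions start as a no-op.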
+ weight = torch.zeros((channels, channels, 3, 3, 3))
+ weight[:, :, 1, 1, 1] = torch.eye(channels)
+ bias = torch.zeros((channels,))
+ state_dict_[name + ".weight"] = weight
+ state_dict_[name + ".bias"] = bias
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/svd_vae_decoder.py b/PusaV1/diffsynth/models/svd_vae_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4815961136820dc1b863573a559076b68fd785a
--- /dev/null
+++ b/PusaV1/diffsynth/models/svd_vae_decoder.py
@@ -0,0 +1,578 @@
+import torch
+from .attention import Attention
+from .sd_unet import ResnetBlock, UpSampler
+from .tiler import TileWorker
+from einops import rearrange, repeat
+
+
+class VAEAttentionBlock(torch.nn.Module):
+
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.transformer_blocks = torch.nn.ModuleList([
+ Attention(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ bias_q=True,
+ bias_kv=True,
+ bias_out=True
+ )
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
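+ # Flatten the spatial grid into a token sequence, (B, C, H, W) -> (B, H*W, C), for self-attention.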
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = hidden_states + residual
+
+ return hidden_states, time_emb, text_emb, res_stack
+
+
+class TemporalResnetBlock(torch.nn.Module):
+
+ def __init__(self, in_channels, out_channels, groups=32, eps=1e-5):
+ super().__init__()
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
+ self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
+ self.nonlinearity = torch.nn.SiLU()
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5]))
+
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
+ x_spatial = hidden_states
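+ # Fold the frame axis into a temporal dimension, (T, C, H, W) -> (1, C, T, H, W), so the (3, 1, 1) kernels convolve across frames.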
+ x = rearrange(hidden_states, "T C H W -> 1 C T H W")
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+ x = self.conv2(x)
+ x_temporal = hidden_states + x[0].permute(1, 0, 2, 3)
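+ # Blend the temporal branch with the untouched spatial features through a learned sigmoid gate.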
+ alpha = torch.sigmoid(self.mix_factor)
+ hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial
+ return hidden_states, time_emb, text_emb, res_stack
+
+
+class SVDVAEDecoder(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.18215
+ self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
+
+ self.blocks = torch.nn.ModuleList([
+ # UNetMidBlock
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ # UpDecoderBlock
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ UpSampler(512),
+ # UpDecoderBlock
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ ResnetBlock(512, 512, eps=1e-6),
+ TemporalResnetBlock(512, 512, eps=1e-6),
+ UpSampler(512),
+ # UpDecoderBlock
+ ResnetBlock(512, 256, eps=1e-6),
+ TemporalResnetBlock(256, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ TemporalResnetBlock(256, 256, eps=1e-6),
+ ResnetBlock(256, 256, eps=1e-6),
+ TemporalResnetBlock(256, 256, eps=1e-6),
+ UpSampler(256),
+ # UpDecoderBlock
+ ResnetBlock(256, 128, eps=1e-6),
+ TemporalResnetBlock(128, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ TemporalResnetBlock(128, 128, eps=1e-6),
+ ResnetBlock(128, 128, eps=1e-6),
+ TemporalResnetBlock(128, 128, eps=1e-6),
+ ])
+
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
+ self.conv_act = torch.nn.SiLU()
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
+ self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
+
+
+ def forward(self, sample):
+ # 1. pre-process
+ hidden_states = rearrange(sample, "C T H W -> T C H W")
+ hidden_states = hidden_states / self.scaling_factor
+ hidden_states = self.conv_in(hidden_states)
+ time_emb, text_emb, res_stack = None, None, None
+
+ # 2. blocks
+ for i, block in enumerate(self.blocks):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+
+ # 3. output
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ hidden_states = rearrange(hidden_states, "T C H W -> C T H W")
+ hidden_states = self.time_conv_out(hidden_states)
+
+ return hidden_states
+
+
+ def build_mask(self, data, is_bound):
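+ # Feathering mask for tiled decoding: weights ramp down towards tile borders, except at true data boundaries flagged by is_bound, so overlapping tiles blend smoothly.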
+ _, T, H, W = data.shape
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
+ border_width = (T + H + W) // 6
+ pad = torch.ones_like(t) * border_width
+ mask = torch.stack([
+ pad if is_bound[0] else t + 1,
+ pad if is_bound[1] else T - t,
+ pad if is_bound[2] else h + 1,
+ pad if is_bound[3] else H - h,
+ pad if is_bound[4] else w + 1,
+ pad if is_bound[5] else W - w
+ ]).min(dim=0).values
+ mask = mask.clip(1, border_width)
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
+ mask = rearrange(mask, "T H W -> 1 T H W")
+ return mask
+
+
+ def decode_video(
+ self, sample,
+ batch_time=8, batch_height=128, batch_width=128,
+ stride_time=4, stride_height=32, stride_width=32,
+ progress_bar=lambda x: x
+ ):
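+ # Tiled decoding: split the latent video into overlapping (time, height, width) windows, decode each window independently, and blend the results with the feathering masks to avoid visible seams.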
+ sample = sample.permute(1, 0, 2, 3)
+ data_device = sample.device
+ computation_device = self.conv_in.weight.device
+ torch_dtype = sample.dtype
+ _, T, H, W = sample.shape
+
+ weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device)
+ values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device)
+
+ # Split tasks
+ tasks = []
+ for t in range(0, T, stride_time):
+ for h in range(0, H, stride_height):
+ for w in range(0, W, stride_width):
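+ # Skip windows whose predecessor along some axis already reaches the end of the data.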
+ if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\
+ or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\
+ or (w-stride_width >= 0 and w-stride_width+batch_width >= W):
+ continue
+ tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width))
+
+ # Run
+ for tl, tr, hl, hr, wl, wr in progress_bar(tasks):
+ sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device)
+ sample_batch = self.forward(sample_batch).to(data_device)
+ mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W))
+ values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask
+ weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask
+ values /= weight
+ return values
+
+
+ @staticmethod
+ def state_dict_converter():
+ return SVDVAEDecoderStateDictConverter()
+
+
+class SVDVAEDecoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ static_rename_dict = {
+ "decoder.conv_in": "conv_in",
+ "decoder.mid_block.attentions.0.group_norm": "blocks.2.norm",
+ "decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q",
+ "decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k",
+ "decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v",
+ "decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out",
+ "decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv",
+ "decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv",
+ "decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv",
+ "decoder.conv_norm_out": "conv_norm_out",
+ "decoder.conv_out": "conv_out",
+ "decoder.time_conv_out": "time_conv_out"
+ }
+ prefix_rename_dict = {
+ "decoder.mid_block.resnets.0.spatial_res_block": "blocks.0",
+ "decoder.mid_block.resnets.0.temporal_res_block": "blocks.1",
+ "decoder.mid_block.resnets.0.time_mixer": "blocks.1",
+ "decoder.mid_block.resnets.1.spatial_res_block": "blocks.3",
+ "decoder.mid_block.resnets.1.temporal_res_block": "blocks.4",
+ "decoder.mid_block.resnets.1.time_mixer": "blocks.4",
+
+ "decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5",
+ "decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6",
+ "decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6",
+ "decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7",
+ "decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8",
+ "decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8",
+ "decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9",
+ "decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10",
+ "decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10",
+
+ "decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12",
+ "decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13",
+ "decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13",
+ "decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14",
+ "decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15",
+ "decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15",
+ "decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16",
+ "decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17",
+ "decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17",
+
+ "decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19",
+ "decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20",
+ "decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20",
+ "decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21",
+ "decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22",
+ "decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22",
+ "decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23",
+ "decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24",
+ "decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24",
+
+ "decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26",
+ "decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27",
+ "decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27",
+ "decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28",
+ "decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29",
+ "decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29",
+ "decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30",
+ "decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31",
+ "decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31",
+ }
+ suffix_rename_dict = {
+ "norm1.weight": "norm1.weight",
+ "conv1.weight": "conv1.weight",
+ "norm2.weight": "norm2.weight",
+ "conv2.weight": "conv2.weight",
+ "conv_shortcut.weight": "conv_shortcut.weight",
+ "norm1.bias": "norm1.bias",
+ "conv1.bias": "conv1.bias",
+ "norm2.bias": "norm2.bias",
+ "conv2.bias": "conv2.bias",
+ "conv_shortcut.bias": "conv_shortcut.bias",
+ "mix_factor": "mix_factor",
+ }
+
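+ # Statically named tensors are copied directly; resnet and time-mixer tensors are renamed by combining a block prefix with a parameter suffix.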
+ state_dict_ = {}
+ for name in static_rename_dict:
+ state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + ".weight"]
+ state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"]
+ for prefix_name in prefix_rename_dict:
+ for suffix_name in suffix_rename_dict:
+ name = prefix_name + "." + suffix_name
+ name_ = prefix_rename_dict[prefix_name] + "." + suffix_rename_dict[suffix_name]
+ if name in state_dict:
+ state_dict_[name_] = state_dict[name]
+
+ return state_dict_
+
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
+ "first_stage_model.decoder.conv_out.time_mix_conv.bias": "time_conv_out.bias",
+ "first_stage_model.decoder.conv_out.time_mix_conv.weight": "time_conv_out.weight",
+ "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.2.transformer_blocks.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.2.transformer_blocks.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.2.norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.2.norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.2.transformer_blocks.0.to_out.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.2.transformer_blocks.0.to_out.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.2.transformer_blocks.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.2.transformer_blocks.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.2.transformer_blocks.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.2.transformer_blocks.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.mix_factor": "blocks.1.mix_factor",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.bias": "blocks.1.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.weight": "blocks.1.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.bias": "blocks.1.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.weight": "blocks.1.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.bias": "blocks.1.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.weight": "blocks.1.norm2.weight",
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.bias": "blocks.1.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.weight": "blocks.1.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.3.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.3.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.3.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.3.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.mix_factor": "blocks.4.mix_factor",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.3.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.3.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.3.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.3.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.bias": "blocks.4.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.weight": "blocks.4.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.bias": "blocks.4.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.weight": "blocks.4.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.bias": "blocks.4.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.weight": "blocks.4.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.bias": "blocks.4.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.weight": "blocks.4.conv2.weight",
+ "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.26.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.26.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.26.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.26.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.mix_factor": "blocks.27.mix_factor",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.26.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.26.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.26.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.26.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.26.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.26.norm2.weight",
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.bias": "blocks.27.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.weight": "blocks.27.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.bias": "blocks.27.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.weight": "blocks.27.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.bias": "blocks.27.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.weight": "blocks.27.norm2.weight",
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.bias": "blocks.27.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.weight": "blocks.27.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.28.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.28.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.28.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.28.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.mix_factor": "blocks.29.mix_factor",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.28.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.28.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.28.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.28.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.bias": "blocks.29.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.weight": "blocks.29.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.bias": "blocks.29.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.weight": "blocks.29.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.bias": "blocks.29.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.weight": "blocks.29.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.bias": "blocks.29.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.weight": "blocks.29.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.30.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.30.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.30.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.30.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.mix_factor": "blocks.31.mix_factor",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.30.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.30.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.30.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.30.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.bias": "blocks.31.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.weight": "blocks.31.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.bias": "blocks.31.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.weight": "blocks.31.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.bias": "blocks.31.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.weight": "blocks.31.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.bias": "blocks.31.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.weight": "blocks.31.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.19.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.19.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.19.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.19.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.mix_factor": "blocks.20.mix_factor",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.19.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.19.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.19.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.19.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.19.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.19.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.bias": "blocks.20.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.weight": "blocks.20.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.bias": "blocks.20.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.weight": "blocks.20.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.bias": "blocks.20.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.weight": "blocks.20.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.bias": "blocks.20.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.weight": "blocks.20.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.21.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.21.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.21.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.21.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.mix_factor": "blocks.22.mix_factor",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.21.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.21.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.21.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.21.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.bias": "blocks.22.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.weight": "blocks.22.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.bias": "blocks.22.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.weight": "blocks.22.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.bias": "blocks.22.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.weight": "blocks.22.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.bias": "blocks.22.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.weight": "blocks.22.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.23.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.23.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.23.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.23.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.mix_factor": "blocks.24.mix_factor",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.23.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.23.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.23.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.23.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.bias": "blocks.24.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.weight": "blocks.24.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.bias": "blocks.24.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.weight": "blocks.24.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.bias": "blocks.24.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.weight": "blocks.24.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.bias": "blocks.24.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.weight": "blocks.24.conv2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.25.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.25.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.12.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.12.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.12.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.12.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.mix_factor": "blocks.13.mix_factor",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.12.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.12.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.12.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.12.norm2.weight",
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.bias": "blocks.13.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.weight": "blocks.13.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.bias": "blocks.13.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.weight": "blocks.13.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.bias": "blocks.13.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.weight": "blocks.13.norm2.weight",
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.bias": "blocks.13.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.weight": "blocks.13.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.14.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.14.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.14.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.14.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.mix_factor": "blocks.15.mix_factor",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.14.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.14.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.14.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.14.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.bias": "blocks.15.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.weight": "blocks.15.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.bias": "blocks.15.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.weight": "blocks.15.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.bias": "blocks.15.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.weight": "blocks.15.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.bias": "blocks.15.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.weight": "blocks.15.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.16.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.16.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.16.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.16.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.mix_factor": "blocks.17.mix_factor",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.16.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.16.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.16.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.16.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.bias": "blocks.17.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.weight": "blocks.17.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.bias": "blocks.17.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.weight": "blocks.17.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.bias": "blocks.17.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.weight": "blocks.17.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.bias": "blocks.17.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.weight": "blocks.17.conv2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.18.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.18.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.5.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.5.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.5.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.5.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.mix_factor": "blocks.6.mix_factor",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.5.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.5.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.5.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.5.norm2.weight",
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.bias": "blocks.6.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.weight": "blocks.6.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.bias": "blocks.6.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.weight": "blocks.6.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.bias": "blocks.6.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.weight": "blocks.6.norm2.weight",
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.bias": "blocks.6.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.weight": "blocks.6.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.7.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.7.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.7.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.7.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.mix_factor": "blocks.8.mix_factor",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.7.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.7.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.7.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.7.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.bias": "blocks.8.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.weight": "blocks.8.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.bias": "blocks.8.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.weight": "blocks.8.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.bias": "blocks.8.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.weight": "blocks.8.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.bias": "blocks.8.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.weight": "blocks.8.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.9.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.9.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.9.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.9.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.mix_factor": "blocks.10.mix_factor",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.9.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.9.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.9.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.9.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.bias": "blocks.10.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.weight": "blocks.10.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.bias": "blocks.10.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.weight": "blocks.10.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.bias": "blocks.10.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.weight": "blocks.10.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.bias": "blocks.10.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.weight": "blocks.10.conv2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.11.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.11.conv.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if "blocks.2.transformer_blocks.0" in rename_dict[name]:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/svd_vae_encoder.py b/PusaV1/diffsynth/models/svd_vae_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..04a543a22c2794095d5f56089b2ca60d445fbc4e
--- /dev/null
+++ b/PusaV1/diffsynth/models/svd_vae_encoder.py
@@ -0,0 +1,139 @@
+from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
+
+
+class SVDVAEEncoder(SDVAEEncoder):
+ def __init__(self):
+ super().__init__()
+ self.scaling_factor = 0.13025
+
+ @staticmethod
+ def state_dict_converter():
+ return SVDVAEEncoderStateDictConverter()
+
+
+class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
+ def __init__(self):
+ super().__init__()
+
+ def from_diffusers(self, state_dict):
+ return super().from_diffusers(state_dict)
+
+ def from_civitai(self, state_dict):
+ rename_dict = {
+ "conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias",
+ "conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight",
+ "conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias",
+ "conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
+ "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
+ "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
+ "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
+ "conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias",
+ "conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight",
+ "conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias",
+ "conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight",
+ }
+ state_dict_ = {}
+ for name in state_dict:
+ if name in rename_dict:
+ param = state_dict[name]
+ if "transformer_blocks" in rename_dict[name]:
+ param = param.squeeze()
+ state_dict_[rename_dict[name]] = param
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/tiler.py b/PusaV1/diffsynth/models/tiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff5ebf2674b504f0b66a6ba7aba800e048f5099
--- /dev/null
+++ b/PusaV1/diffsynth/models/tiler.py
@@ -0,0 +1,234 @@
+import torch
+from einops import rearrange, repeat
+
+
+class TileWorker:
+ def __init__(self):
+ pass
+
+
+ def mask(self, height, width, border_width):
+        # Create a mask of shape (height, width).
+        # The central area is filled with 1, and values ramp down linearly towards the borders over border_width pixels, staying in (0, 1].
+ x = torch.arange(height).repeat(width, 1).T
+ y = torch.arange(width).repeat(height, 1)
+ mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
+ mask = (mask / border_width).clip(0, 1)
+ return mask
+
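+    # Illustrative example (an assumption, not from the original source): with
+    # height = width = 6 and border_width = 2, the outermost ring of the mask is 0.5
+    # and everything further inside is 1.0, e.g. row 1 is [0.5, 1, 1, 1, 1, 0.5].
+    # These weights let untile() blend overlapping tiles smoothly.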
+
+ def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
+ # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
+ batch_size, channel, _, _ = model_input.shape
+ model_input = model_input.to(device=tile_device, dtype=tile_dtype)
+ unfold_operator = torch.nn.Unfold(
+ kernel_size=(tile_size, tile_size),
+ stride=(tile_stride, tile_stride)
+ )
+ model_input = unfold_operator(model_input)
+ model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
+
+ return model_input
+
+
+ def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
+ # Call y=forward_fn(x) for each tile
+ tile_num = model_input.shape[-1]
+ model_output_stack = []
+
+ for tile_id in range(0, tile_num, tile_batch_size):
+
+ # process input
+ tile_id_ = min(tile_id + tile_batch_size, tile_num)
+ x = model_input[:, :, :, :, tile_id: tile_id_]
+ x = x.to(device=inference_device, dtype=inference_dtype)
+ x = rearrange(x, "b c h w n -> (n b) c h w")
+
+ # process output
+ y = forward_fn(x)
+ y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
+ y = y.to(device=tile_device, dtype=tile_dtype)
+ model_output_stack.append(y)
+
+ model_output = torch.concat(model_output_stack, dim=-1)
+ return model_output
+
+
+        # Determine the scale change applied by forward_fn.
+        # We assume the same scale factor for height and width.
+ # We only consider the same scale on height and width.
+ io_scale = model_output.shape[2] / tile_size
+ return io_scale
+
+
+ def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
+        # The inverse of tile(): blend the tiles back into a full feature map with a weighted average.
+ mask = self.mask(tile_size, tile_size, border_width)
+ mask = mask.to(device=tile_device, dtype=tile_dtype)
+ mask = rearrange(mask, "h w -> 1 1 h w 1")
+ model_output = model_output * mask
+
+ fold_operator = torch.nn.Fold(
+ output_size=(height, width),
+ kernel_size=(tile_size, tile_size),
+ stride=(tile_stride, tile_stride)
+ )
+ mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
+ model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
+ model_output = fold_operator(model_output) / fold_operator(mask)
+
+ return model_output
+
+
+ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
+ # Prepare
+ inference_device, inference_dtype = model_input.device, model_input.dtype
+ height, width = model_input.shape[2], model_input.shape[3]
+ border_width = int(tile_stride*0.5) if border_width is None else border_width
+
+ # tile
+ model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
+
+ # inference
+ model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
+
+ # resize
+ io_scale = self.io_scale(model_output, tile_size)
+ height, width = int(height*io_scale), int(width*io_scale)
+ tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
+ border_width = int(border_width*io_scale)
+
+ # untile
+ model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
+
+ # Done!
+ model_output = model_output.to(device=inference_device, dtype=inference_dtype)
+ return model_output
+
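+# Minimal usage sketch (illustrative only; the tensor size and the toy forward_fn are
+# assumptions, not part of the original module):
+#
+#   worker = TileWorker()
+#   x = torch.randn(1, 4, 128, 128)
+#   y = worker.tiled_forward(lambda t: t * 2, x, tile_size=64, tile_stride=32)
+#   assert y.shape == x.shape  # tiles are blended back together with the border mask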
+
+
+class FastTileWorker:
+ def __init__(self):
+ pass
+
+
+ def build_mask(self, data, is_bound):
+ _, _, H, W = data.shape
+ h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
+ w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
+ border_width = (H + W) // 4
+ pad = torch.ones_like(h) * border_width
+ mask = torch.stack([
+ pad if is_bound[0] else h + 1,
+ pad if is_bound[1] else H - h,
+ pad if is_bound[2] else w + 1,
+ pad if is_bound[3] else W - w
+ ]).min(dim=0).values
+ mask = mask.clip(1, border_width)
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
+ mask = rearrange(mask, "H W -> 1 H W")
+ return mask
+
+
+ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
+ # Prepare
+ B, C, H, W = model_input.shape
+ border_width = int(tile_stride*0.5) if border_width is None else border_width
+ weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
+ values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
+
+ # Split tasks
+ tasks = []
+ for h in range(0, H, tile_stride):
+ for w in range(0, W, tile_stride):
+ if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
+ continue
+ h_, w_ = h + tile_size, w + tile_size
+ if h_ > H: h, h_ = H - tile_size, H
+ if w_ > W: w, w_ = W - tile_size, W
+ tasks.append((h, h_, w, w_))
+
+ # Run
+ for hl, hr, wl, wr in tasks:
+ # Forward
+ hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)
+
+ mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
+ values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
+ weight[:, :, hl:hr, wl:wr] += mask
+ values /= weight
+ return values
+
+
+
+class TileWorker2Dto3D:
+ """
+    Process 3D tensors, tiling only over the two spatial dimensions (H, W).
+ """
+ def __init__(self):
+ pass
+
+
+ def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
+ border_width = (H + W) // 4 if border_width is None else border_width
+ pad = torch.ones_like(h) * border_width
+ mask = torch.stack([
+ pad if is_bound[0] else t + 1,
+ pad if is_bound[1] else T - t,
+ pad if is_bound[2] else h + 1,
+ pad if is_bound[3] else H - h,
+ pad if is_bound[4] else w + 1,
+ pad if is_bound[5] else W - w
+ ]).min(dim=0).values
+ mask = mask.clip(1, border_width)
+ mask = (mask / border_width).to(dtype=dtype, device=device)
+ mask = rearrange(mask, "T H W -> 1 1 T H W")
+ return mask
+
+
+ def tiled_forward(
+ self,
+ forward_fn,
+ model_input,
+ tile_size, tile_stride,
+ tile_device="cpu", tile_dtype=torch.float32,
+ computation_device="cuda", computation_dtype=torch.float32,
+ border_width=None, scales=[1, 1, 1, 1],
+ progress_bar=lambda x:x
+ ):
+ B, C, T, H, W = model_input.shape
+ scale_C, scale_T, scale_H, scale_W = scales
+ tile_size_H, tile_size_W = tile_size
+ tile_stride_H, tile_stride_W = tile_stride
+
+ value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
+ weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
+
+ # Split tasks
+ tasks = []
+ for h in range(0, H, tile_stride_H):
+ for w in range(0, W, tile_stride_W):
+ if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W):
+ continue
+ h_, w_ = h + tile_size_H, w + tile_size_W
+ if h_ > H: h, h_ = max(H - tile_size_H, 0), H
+ if w_ > W: w, w_ = max(W - tile_size_W, 0), W
+ tasks.append((h, h_, w, w_))
+
+ # Run
+ for hl, hr, wl, wr in progress_bar(tasks):
+ mask = self.build_mask(
+ int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W),
+ tile_dtype, tile_device,
+ is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
+ border_width=border_width
+ )
+ grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
+ grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
+ value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask
+ weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask
+ value = value / weight
+ return value
\ No newline at end of file
diff --git a/PusaV1/diffsynth/models/utils.py b/PusaV1/diffsynth/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..99f5dee14b4f4b8b422a5d7f3c2ce7da7e3c20d6
--- /dev/null
+++ b/PusaV1/diffsynth/models/utils.py
@@ -0,0 +1,182 @@
+import torch, os
+from safetensors import safe_open
+from contextlib import contextmanager
+import hashlib
+
+@contextmanager
+def init_weights_on_device(device: torch.device = torch.device("meta"), include_buffers: bool = False):
+
+ old_register_parameter = torch.nn.Module.register_parameter
+ if include_buffers:
+ old_register_buffer = torch.nn.Module.register_buffer
+
+ def register_empty_parameter(module, name, param):
+ old_register_parameter(module, name, param)
+ if param is not None:
+ param_cls = type(module._parameters[name])
+ kwargs = module._parameters[name].__dict__
+ kwargs["requires_grad"] = param.requires_grad
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+ def register_empty_buffer(module, name, buffer, persistent=True):
+ old_register_buffer(module, name, buffer, persistent=persistent)
+ if buffer is not None:
+ module._buffers[name] = module._buffers[name].to(device)
+
+ def patch_tensor_constructor(fn):
+ def wrapper(*args, **kwargs):
+ kwargs["device"] = device
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+ if include_buffers:
+ tensor_constructors_to_patch = {
+ torch_function_name: getattr(torch, torch_function_name)
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
+ }
+ else:
+ tensor_constructors_to_patch = {}
+
+ try:
+ torch.nn.Module.register_parameter = register_empty_parameter
+ if include_buffers:
+ torch.nn.Module.register_buffer = register_empty_buffer
+ for torch_function_name in tensor_constructors_to_patch.keys():
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+ yield
+ finally:
+ torch.nn.Module.register_parameter = old_register_parameter
+ if include_buffers:
+ torch.nn.Module.register_buffer = old_register_buffer
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+ setattr(torch, torch_function_name, old_torch_function)
+
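+# Usage sketch (illustrative; the file name is a placeholder): build a module without
+# allocating real weights, then load a checkpoint with assign=True.
+#
+#   with init_weights_on_device():            # parameters are created on the meta device
+#       module = torch.nn.Linear(4096, 4096)
+#   module.load_state_dict(load_state_dict("model.safetensors"), assign=True)
+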
+def load_state_dict_from_folder(file_path, torch_dtype=None):
+ state_dict = {}
+ for file_name in os.listdir(file_path):
+ if "." in file_name and file_name.split(".")[-1] in [
+ "safetensors", "bin", "ckpt", "pth", "pt"
+ ]:
+ state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
+ return state_dict
+
+
+def load_state_dict(file_path, torch_dtype=None):
+ if file_path.endswith(".safetensors"):
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
+ else:
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
+
+
+def load_state_dict_from_safetensors(file_path, torch_dtype=None):
+ state_dict = {}
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+ if torch_dtype is not None:
+ state_dict[k] = state_dict[k].to(torch_dtype)
+ return state_dict
+
+
+def load_state_dict_from_bin(file_path, torch_dtype=None):
+ state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
+ if torch_dtype is not None:
+ for i in state_dict:
+ if isinstance(state_dict[i], torch.Tensor):
+ state_dict[i] = state_dict[i].to(torch_dtype)
+ return state_dict
+
+
+def search_for_embeddings(state_dict):
+ embeddings = []
+ for k in state_dict:
+ if isinstance(state_dict[k], torch.Tensor):
+ embeddings.append(state_dict[k])
+ elif isinstance(state_dict[k], dict):
+ embeddings += search_for_embeddings(state_dict[k])
+ return embeddings
+
+
+def search_parameter(param, state_dict):
+ for name, param_ in state_dict.items():
+ if param.numel() == param_.numel():
+ if param.shape == param_.shape:
+ if torch.dist(param, param_) < 1e-3:
+ return name
+ else:
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
+ return name
+ return None
+
+
+def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
+ matched_keys = set()
+ with torch.no_grad():
+ for name in source_state_dict:
+ rename = search_parameter(source_state_dict[name], target_state_dict)
+ if rename is not None:
+ print(f'"{name}": "{rename}",')
+ matched_keys.add(rename)
+ elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
+ length = source_state_dict[name].shape[0] // 3
+ rename = []
+ for i in range(3):
+ rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
+ if None not in rename:
+ print(f'"{name}": {rename},')
+ for rename_ in rename:
+ matched_keys.add(rename_)
+ for name in target_state_dict:
+ if name not in matched_keys:
+ print("Cannot find", name, target_state_dict[name].shape)
+
+
+def search_for_files(folder, extensions):
+ files = []
+ if os.path.isdir(folder):
+ for file in sorted(os.listdir(folder)):
+ files += search_for_files(os.path.join(folder, file), extensions)
+ elif os.path.isfile(folder):
+ for extension in extensions:
+ if folder.endswith(extension):
+ files.append(folder)
+ break
+ return files
+
+
+def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
+ keys = []
+ for key, value in state_dict.items():
+ if isinstance(key, str):
+ if isinstance(value, torch.Tensor):
+ if with_shape:
+ shape = "_".join(map(str, list(value.shape)))
+ keys.append(key + ":" + shape)
+ keys.append(key)
+ elif isinstance(value, dict):
+ keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
+ keys.sort()
+ keys_str = ",".join(keys)
+ return keys_str
+
+
+def split_state_dict_with_prefix(state_dict):
+ keys = sorted([key for key in state_dict if isinstance(key, str)])
+ prefix_dict = {}
+ for key in keys:
+ prefix = key if "." not in key else key.split(".")[0]
+ if prefix not in prefix_dict:
+ prefix_dict[prefix] = []
+ prefix_dict[prefix].append(key)
+ state_dicts = []
+ for prefix, keys in prefix_dict.items():
+ sub_state_dict = {key: state_dict[key] for key in keys}
+ state_dicts.append(sub_state_dict)
+ return state_dicts
+
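+# For example (illustrative), the keys {"unet.a", "unet.b", "vae.c"} would be split into
+# two sub state dicts, one holding the "unet.*" entries and one holding the "vae.*" entry.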
+
+def hash_state_dict_keys(state_dict, with_shape=True):
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
+ keys_str = keys_str.encode(encoding="UTF-8")
+ return hashlib.md5(keys_str).hexdigest()
\ No newline at end of file
diff --git a/PusaV1/diffsynth/models/wan_video_dit.py b/PusaV1/diffsynth/models/wan_video_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..a36b206ca9257d4de79a7746661157c994fca67c
--- /dev/null
+++ b/PusaV1/diffsynth/models/wan_video_dit.py
@@ -0,0 +1,666 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import Tuple, Optional
+from einops import rearrange
+from .utils import hash_state_dict_keys
+try:
+ import flash_attn_interface
+ FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False
+
+try:
+ import flash_attn
+ FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_2_AVAILABLE = False
+
+try:
+ from sageattention import sageattn
+ SAGE_ATTN_AVAILABLE = True
+except ModuleNotFoundError:
+ SAGE_ATTN_AVAILABLE = False
+
+
+_VISUALIZE_ATTENTION_CONFIG = {
+ "enabled": False, "path": None, "step": 0, "block_name": "", "attn_type": "", "grid_size": None,
+}
+
+
+def _visualize_cross_attention_from_center(q, k, config):
+ try:
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import os
+ except ImportError:
+ print("Please install matplotlib and seaborn to visualize attention maps.")
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ return
+
+ f, h, w = config["grid_size"]
+ query_patch_idx_t = f // 2
+ query_patch_idx_h = h // 2
+ query_patch_idx_w = w // 2
+ query_patch_idx = query_patch_idx_t * (h * w) + query_patch_idx_h * w + query_patch_idx_w
+
+ b, n_heads, s_q, d_head = q.shape
+ if query_patch_idx >= s_q:
+ return
+
+ q_center = q[:, :, query_patch_idx:query_patch_idx+1, :]
+
+ attn_scores = torch.matmul(q_center, k.transpose(-2, -1)) / math.sqrt(d_head)
+ attn_weights = F.softmax(attn_scores, dim=-1)
+
+ token_attention = attn_weights.mean(dim=(0, 1)).squeeze(0).detach().float().cpu().numpy()
+
+ sub_type = config.get("sub_attn_type", "text")
+ path_prefix = os.path.join(config["path"], f'{config["block_name"]}_cross_attn_{sub_type}_step{config["step"]}')
+
+ plt.figure(figsize=(16, 2))
+ sns.heatmap(token_attention[None, :], cmap="viridis", cbar=True)
+ plt.title(f'Cross-Attention: {sub_type} (from center patch)\n{config["block_name"]}, step {config["step"]}')
+ plt.xlabel("Key token index")
+ plt.ylabel("Query patch")
+ plt.tight_layout()
+ plt.savefig(f"{path_prefix}_center_patch.png")
+ plt.close()
+
+def _visualize_frame_self_attention(q, k, config):
+ try:
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import os
+ except ImportError:
+ print("Please install matplotlib and seaborn to visualize attention maps.")
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ return
+
+ b, n_heads, s, d_head = q.shape
+ f, h, w = config["grid_size"]
+ s_frame = h * w
+ if s != f * h * w:
+ return
+
+ q_frames = q.view(b, n_heads, f, s_frame, d_head)
+ k_frames = k.view(b, n_heads, f, s_frame, d_head)
+
+    # Averaging the queries/keys within each frame before the dot product gives the same
+    # pre-softmax scores as computing all token-pair scores and then averaging them per frame pair.
+ q_frame_avg = q_frames.mean(dim=3)
+ k_frame_avg = k_frames.mean(dim=3)
+
+ frame_similarity_map = torch.einsum('bhid,bhjd->bhij', q_frame_avg, k_frame_avg) / math.sqrt(d_head)
+
+ frame_attention_map = F.softmax(frame_similarity_map, dim=-1)
+ frame_attention_map = frame_attention_map.mean(dim=(0,1)).detach().float().cpu().numpy()
+
+ path_prefix = os.path.join(config["path"], f'{config["block_name"]}_self_attn_step{config["step"]}')
+ plt.figure(figsize=(10, 8))
+ sns.heatmap(frame_attention_map, cmap="viridis", cbar=True, annot=True, fmt=".2f")
+ plt.title(f'Frame-to-Frame Self-Attention\n{config["block_name"]}, step {config["step"]}')
+ plt.xlabel("Key Frame Index")
+ plt.ylabel("Query Frame Index")
+ plt.tight_layout()
+ plt.savefig(f"{path_prefix}_frame_similarity.png")
+ plt.close()
+
+
+def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ config = _VISUALIZE_ATTENTION_CONFIG
+ with torch.no_grad():
+ q_vis = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k_vis = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+
+ if config['attn_type'] == 'self':
+ _visualize_frame_self_attention(q_vis, k_vis, config)
+ elif config['attn_type'] == 'cross':
+ _visualize_cross_attention_from_center(q_vis, k_vis, config)
+
+ if compatibility_mode:
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+ x = F.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+ elif FLASH_ATTN_3_AVAILABLE:
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+ x = flash_attn_interface.flash_attn_func(q, k, v)
+        if isinstance(x, tuple):
+ x = x[0]
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+ elif FLASH_ATTN_2_AVAILABLE:
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+ x = flash_attn.flash_attn_func(q, k, v)
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+ elif SAGE_ATTN_AVAILABLE:
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+ x = sageattn(q, k, v)
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+ else:
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+ x = F.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+ return x
+
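+# flash_attention() takes packed projections of shape (batch, seq, num_heads * head_dim)
+# and dispatches to FlashAttention-3/2, SageAttention, or PyTorch SDPA depending on which
+# backend is installed. A minimal sketch (the shapes are illustrative assumptions):
+#
+#   q = k = v = torch.randn(1, 16, 12 * 64)        # (b, s, n * d)
+#   out = flash_attention(q, k, v, num_heads=12)   # -> (1, 16, 12 * 64)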
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
+ return (x * (1 + scale) + shift)
+
+
+def sinusoidal_embedding_1d(dim, position):
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x.to(position.dtype)
+
+
+def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
+ # 3d rope precompute
+ f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
+ h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+ w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+ return f_freqs_cis, h_freqs_cis, w_freqs_cis
+
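+# The rotary channels are split across (frame, height, width). For example, with
+# head_dim = 128 the split is 128 - 2 * (128 // 3) = 44 channels for the frame axis and
+# 128 // 3 = 42 channels each for height and width (44 + 42 + 42 = 128).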
+
+def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
+ # 1d rope precompute
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
+ [: (dim // 2)].double() / dim))
+ freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
+ return freqs_cis
+
+
+def rope_apply(x, freqs, num_heads):
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
+ x_out = torch.view_as_real(x_out * freqs).flatten(2)
+ return x_out.to(x.dtype)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ dtype = x.dtype
+ return self.norm(x.float()).to(dtype) * self.weight
+
+
+class AttentionModule(nn.Module):
+ def __init__(self, num_heads):
+ super().__init__()
+ self.num_heads = num_heads
+
+ def forward(self, q, k, v):
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
+ return x
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = RMSNorm(dim, eps=eps)
+ self.norm_k = RMSNorm(dim, eps=eps)
+
+ self.attn = AttentionModule(self.num_heads)
+
+ def forward(self, x, freqs):
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(x))
+ v = self.v(x)
+ q = rope_apply(q, freqs, self.num_heads)
+ k = rope_apply(k, freqs, self.num_heads)
+ x = self.attn(q, k, v)
+ return self.o(x)
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = RMSNorm(dim, eps=eps)
+ self.norm_k = RMSNorm(dim, eps=eps)
+ self.has_image_input = has_image_input
+ if has_image_input:
+ self.k_img = nn.Linear(dim, dim)
+ self.v_img = nn.Linear(dim, dim)
+ self.norm_k_img = RMSNorm(dim, eps=eps)
+
+ self.attn = AttentionModule(self.num_heads)
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ if self.has_image_input:
+ img = y[:, :257]
+ ctx = y[:, 257:]
+ else:
+ ctx = y
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(ctx))
+ v = self.v(ctx)
+ x = self.attn(q, k, v)
+ if self.has_image_input:
+ k_img = self.norm_k_img(self.k_img(img))
+ v_img = self.v_img(img)
+ y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
+ x = x + y
+ return self.o(x)
+
+
+class GateModule(nn.Module):
+ def __init__(self,):
+ super().__init__()
+
+ def forward(self, x, gate, residual):
+ return x + gate * residual
+
+class DiTBlock(nn.Module):
+ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.ffn_dim = ffn_dim
+ self.block_name = ""
+
+ self.self_attn = SelfAttention(dim, num_heads, eps)
+ self.cross_attn = CrossAttention(
+ dim, num_heads, eps, has_image_input=has_image_input)
+ self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+ self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+ self.norm3 = nn.LayerNorm(dim, eps=eps)
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
+ approximate='tanh'), nn.Linear(ffn_dim, dim))
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+ self.gate = GateModule()
+
+ def forward(self, x, context, t_mod, freqs):
+        # msa: multi-head self-attention; mlp: multi-layer perceptron
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
+
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["block_name"] = self.block_name
+ _VISUALIZE_ATTENTION_CONFIG["attn_type"] = "self"
+ input_x = modulate(self.norm1(x), shift_msa, scale_msa)
+ x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
+
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["attn_type"] = "cross"
+ x = x + self.cross_attn(self.norm3(x), context)
+ input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
+ x = self.gate(x, gate_mlp, self.ffn(input_x))
+ return x
+
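+# Each DiTBlock applies adaLN-style modulation: the learned `modulation` table plus the
+# per-timestep projection t_mod is split into
+#   shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp,
+# so both branches follow x -> x + gate * f(norm(x) * (1 + scale) + shift).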
+
+class MLP(torch.nn.Module):
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+ self.proj = torch.nn.Sequential(
+ nn.LayerNorm(in_dim),
+ nn.Linear(in_dim, in_dim),
+ nn.GELU(),
+ nn.Linear(in_dim, out_dim),
+ nn.LayerNorm(out_dim)
+ )
+
+ def forward(self, x):
+ return self.proj(x)
+
+
+class Head(nn.Module):
+ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
+ super().__init__()
+ self.dim = dim
+ self.patch_size = patch_size
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+ self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+ def forward(self, x, t_mod):
+ shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
+ x = (self.head(self.norm(x) * (1 + scale) + shift))
+ return x
+
+
+class WanModel(torch.nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ in_dim: int,
+ ffn_dim: int,
+ out_dim: int,
+ text_dim: int,
+ freq_dim: int,
+ eps: float,
+ patch_size: Tuple[int, int, int],
+ num_heads: int,
+ num_layers: int,
+ has_image_input: bool,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.freq_dim = freq_dim
+ self.has_image_input = has_image_input
+ self.patch_size = patch_size
+
+ self.patch_embedding = nn.Conv3d(
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
+ self.text_embedding = nn.Sequential(
+ nn.Linear(text_dim, dim),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(dim, dim)
+ )
+ self.time_embedding = nn.Sequential(
+ nn.Linear(freq_dim, dim),
+ nn.SiLU(),
+ nn.Linear(dim, dim)
+ )
+ self.time_projection = nn.Sequential(
+ nn.SiLU(), nn.Linear(dim, dim * 6))
+ self.blocks = nn.ModuleList([
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
+ for _ in range(num_layers)
+ ])
+ for i, block in enumerate(self.blocks):
+ block.block_name = f"block_{i}"
+ self.head = Head(dim, out_dim, patch_size, eps)
+ head_dim = dim // num_heads
+ self.freqs = precompute_freqs_cis_3d(head_dim)
+
+ if has_image_input:
+ self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
+
+ def patchify(self, x: torch.Tensor):
+ x = self.patch_embedding(x)
+ grid_size = x.shape[2:]
+ x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
+        return x, grid_size  # grid_size = (f, h, w) after patch embedding
+
+    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
+ return rearrange(
+ x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
+ f=grid_size[0], h=grid_size[1], w=grid_size[2],
+ x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
+ )
+
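+    # Shape flow (illustrative): patchify() maps a latent (b, in_dim, F, H, W) through the
+    # Conv3d patch embedding to tokens of shape (b, f * h * w, dim), where
+    # (f, h, w) = (F // pf, H // ph, W // pw) for patch_size (pf, ph, pw); unpatchify()
+    # inverts this after the head projects each token to out_dim * pf * ph * pw values.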
+ def forward(self,
+ x: torch.Tensor,
+ timestep: torch.Tensor,
+ context: torch.Tensor,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ use_gradient_checkpointing: bool = False,
+ use_gradient_checkpointing_offload: bool = False,
+ **kwargs,
+ ):
+ t = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
+ context = self.text_embedding(context)
+
+ if self.has_image_input:
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
+            clip_embedding = self.img_emb(clip_feature)
+            context = torch.cat([clip_embedding, context], dim=1)
+
+ x, (f, h, w) = self.patchify(x)
+
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["grid_size"] = (f, h, w)
+ if timestep.numel() == 1:
+ _VISUALIZE_ATTENTION_CONFIG["step"] = int(timestep.item())
+
+ freqs = torch.cat([
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ for block in self.blocks:
+ if self.training and use_gradient_checkpointing:
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = block(x, context, t_mod, freqs)
+
+ x = self.head(x, t)
+ x = self.unpatchify(x, (f, h, w))
+ return x
+
+ @staticmethod
+ def state_dict_converter():
+ return WanModelStateDictConverter()
+
+
+class WanModelStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
+ "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
+ "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
+ "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
+ "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
+ "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
+ "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
+ "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
+ "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
+ "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
+ "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
+ "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
+ "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
+ "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
+ "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
+ "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
+ "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
+ "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
+ "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
+ "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
+ "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
+ "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
+ "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
+ "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
+ "blocks.0.norm2.bias": "blocks.0.norm3.bias",
+ "blocks.0.norm2.weight": "blocks.0.norm3.weight",
+ "blocks.0.scale_shift_table": "blocks.0.modulation",
+ "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
+ "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
+ "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
+ "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
+ "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
+ "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
+ "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
+ "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
+ "condition_embedder.time_proj.bias": "time_projection.1.bias",
+ "condition_embedder.time_proj.weight": "time_projection.1.weight",
+ "patch_embedding.bias": "patch_embedding.bias",
+ "patch_embedding.weight": "patch_embedding.weight",
+ "scale_shift_table": "head.modulation",
+ "proj_out.bias": "head.head.bias",
+ "proj_out.weight": "head.head.weight",
+ }
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ state_dict_[rename_dict[name]] = param
+ else:
+ name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
+ if name_ in rename_dict:
+ name_ = rename_dict[name_]
+ name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
+ state_dict_[name_] = param
+ if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
+ config = {
+ "model_type": "t2v",
+ "patch_size": (1, 2, 2),
+ "text_len": 512,
+ "in_dim": 16,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "window_size": (-1, -1),
+ "qk_norm": True,
+ "cross_attn_norm": True,
+ "eps": 1e-6,
+ }
+ else:
+ config = {}
+ return state_dict_, config
+
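+    # from_civitai selects an architecture config by hashing the checkpoint's key layout.
+    # As far as the configs below go, dim=1536 / 30 layers appears to correspond to the
+    # 1.3B variants and dim=5120 / 40 layers to the 14B variants, with has_image_input
+    # marking image-conditioned (I2V-style) checkpoints.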
+ def from_civitai(self, state_dict):
+ state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
+ config = {
+ "has_image_input": False,
+ "patch_size": [1, 2, 2],
+ "in_dim": 16,
+ "dim": 1536,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 12,
+ "num_layers": 30,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
+ config = {
+ "has_image_input": False,
+ "patch_size": [1, 2, 2],
+ "in_dim": 16,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 36,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 36,
+ "dim": 1536,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 12,
+ "num_layers": 30,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 48,
+ "dim": 1536,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 12,
+ "num_layers": 30,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 48,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "eps": 1e-6
+ }
+ else:
+ config = {}
+ return state_dict, config
diff --git a/PusaV1/diffsynth/models/wan_video_image_encoder.py b/PusaV1/diffsynth/models/wan_video_image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca878b1fd6ed6dc00420f092f87479fb65ef63a
--- /dev/null
+++ b/PusaV1/diffsynth/models/wan_video_image_encoder.py
@@ -0,0 +1,902 @@
+"""
+Concise re-implementation of
+``https://github.com/openai/CLIP'' and
+``https://github.com/mlfoundations/open_clip''.
+"""
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+from .wan_video_dit import flash_attention
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, mask):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
+
+ # compute attention
+ p = self.dropout.p if self.training else 0.0
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
+
+ # output
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # layers
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
+ nn.Dropout(dropout))
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, x, mask):
+ if self.post_norm:
+ x = self.norm1(x + self.attn(x, mask))
+ x = self.norm2(x + self.ffn(x))
+ else:
+ x = x + self.attn(self.norm1(x), mask)
+ x = x + self.ffn(self.norm2(x))
+ return x
+
+
+class XLMRoberta(nn.Module):
+ """
+ XLMRobertaModel with no pooler and no LM head.
+ """
+
+ def __init__(self,
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5):
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.max_seq_len = max_seq_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.dim = dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.post_norm = post_norm
+ self.eps = eps
+
+ # embeddings
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
+ self.type_embedding = nn.Embedding(type_size, dim)
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
+ self.dropout = nn.Dropout(dropout)
+
+ # blocks
+ self.blocks = nn.ModuleList([
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
+ for _ in range(num_layers)
+ ])
+
+ # norm layer
+ self.norm = nn.LayerNorm(dim, eps=eps)
+
+ def forward(self, ids):
+ """
+ ids: [B, L] of torch.LongTensor.
+ """
+ b, s = ids.shape
+ mask = ids.ne(self.pad_id).long()
+
+ # embeddings
+ x = self.token_embedding(ids) + \
+ self.type_embedding(torch.zeros_like(ids)) + \
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
+ if self.post_norm:
+ x = self.norm(x)
+ x = self.dropout(x)
+
+ # blocks
+ mask = torch.where(
+ mask.view(b, 1, 1, s).gt(0), 0.0,
+ torch.finfo(x.dtype).min)
+ for block in self.blocks:
+ x = block(x, mask)
+
+ # output
+ if not self.post_norm:
+ x = self.norm(x)
+ return x
+
+
+def xlm_roberta_large(pretrained=False,
+ return_tokenizer=False,
+ device='cpu',
+ **kwargs):
+ """
+ XLMRobertaLarge adapted from Huggingface.
+ """
+ # params
+ cfg = dict(
+ vocab_size=250002,
+ max_seq_len=514,
+ type_size=1,
+ pad_id=1,
+ dim=1024,
+ num_heads=16,
+ num_layers=24,
+ post_norm=True,
+ dropout=0.1,
+ eps=1e-5)
+ cfg.update(**kwargs)
+
+ # init model
+ if pretrained:
+ from sora import DOWNLOAD_TO_CACHE
+
+ # init a meta model
+ with torch.device('meta'):
+ model = XLMRoberta(**cfg)
+
+ # load checkpoint
+ model.load_state_dict(
+ torch.load(
+ DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
+ map_location=device),
+ assign=True)
+ else:
+ # init a model on device
+ with torch.device(device):
+ model = XLMRoberta(**cfg)
+
+ # init tokenizer
+ if return_tokenizer:
+ from sora.data import HuggingfaceTokenizer
+ tokenizer = HuggingfaceTokenizer(
+ name='xlm-roberta-large',
+ seq_len=model.text_len,
+ clean='whitespace')
+ return model, tokenizer
+ else:
+ return model
+
+
+
+def pos_interpolate(pos, seq_len):
+ if pos.size(1) == seq_len:
+ return pos
+ else:
+ src_grid = int(math.sqrt(pos.size(1)))
+ tar_grid = int(math.sqrt(seq_len))
+ n = pos.size(1) - src_grid * src_grid
+ return torch.cat([
+ pos[:, :n],
+ F.interpolate(
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
+ 0, 3, 1, 2),
+ size=(tar_grid, tar_grid),
+ mode='bicubic',
+ align_corners=False).flatten(2).transpose(1, 2)
+ ],
+ dim=1)
+
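+# pos_interpolate() keeps any leading class-token embeddings unchanged and bicubically
+# resizes the square grid portion of the positional table, so a ViT trained at one
+# resolution can be evaluated at another (e.g. a 16x16 grid of 256 patch tokens resized
+# to 32x32 when seq_len grows accordingly).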
+
+class QuickGELU(nn.Module):
+
+ def forward(self, x):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerNorm(nn.LayerNorm):
+
+ def forward(self, x):
+ return super().forward(x).type_as(x)
+
+
+class SelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ causal=False,
+ attn_dropout=0.0,
+ proj_dropout=0.0):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.causal = causal
+ self.attn_dropout = attn_dropout
+ self.proj_dropout = proj_dropout
+
+ # layers
+ self.to_qkv = nn.Linear(dim, dim * 3)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+
+ # compute attention
+ x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+ return x
+
+
+class SwiGLU(nn.Module):
+
+ def __init__(self, dim, mid_dim):
+ super().__init__()
+ self.dim = dim
+ self.mid_dim = mid_dim
+
+ # layers
+ self.fc1 = nn.Linear(dim, mid_dim)
+ self.fc2 = nn.Linear(dim, mid_dim)
+ self.fc3 = nn.Linear(mid_dim, dim)
+
+ def forward(self, x):
+ x = F.silu(self.fc1(x)) * self.fc2(x)
+ x = self.fc3(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ mlp_ratio,
+ num_heads,
+ post_norm=False,
+ causal=False,
+ activation='quick_gelu',
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ norm_eps=1e-5):
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.post_norm = post_norm
+ self.causal = causal
+ self.norm_eps = norm_eps
+
+ # layers
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
+ proj_dropout)
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
+ if activation == 'swi_glu':
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+ else:
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+ def forward(self, x):
+ if self.post_norm:
+ x = x + self.norm1(self.attn(x))
+ x = x + self.norm2(self.mlp(x))
+ else:
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class AttentionPool(nn.Module):
+
+ def __init__(self,
+ dim,
+ mlp_ratio,
+ num_heads,
+ activation='gelu',
+ proj_dropout=0.0,
+ norm_eps=1e-5):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.proj_dropout = proj_dropout
+ self.norm_eps = norm_eps
+
+ # layers
+ gain = 1.0 / math.sqrt(dim)
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.to_q = nn.Linear(dim, dim)
+ self.to_kv = nn.Linear(dim, dim * 2)
+ self.proj = nn.Linear(dim, dim)
+ self.norm = LayerNorm(dim, eps=norm_eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, int(dim * mlp_ratio)),
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+ def forward(self, x):
+ """
+ x: [B, L, C].
+ """
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
+ k, v = self.to_kv(x).chunk(2, dim=-1)
+
+ # compute attention
+ x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
+ x = x.reshape(b, 1, c)
+
+ # output
+ x = self.proj(x)
+ x = F.dropout(x, self.proj_dropout, self.training)
+
+ # mlp
+ x = x + self.mlp(self.norm(x))
+ return x[:, 0]
+
+
+class VisionTransformer(nn.Module):
+
+ def __init__(self,
+ image_size=224,
+ patch_size=16,
+ dim=768,
+ mlp_ratio=4,
+ out_dim=512,
+ num_heads=12,
+ num_layers=12,
+ pool_type='token',
+ pre_norm=True,
+ post_norm=False,
+ activation='quick_gelu',
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5):
+ if image_size % patch_size != 0:
+ print(
+ '[WARNING] image_size is not divisible by patch_size',
+ flush=True)
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
+ out_dim = out_dim or dim
+ super().__init__()
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = (image_size // patch_size)**2
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.pool_type = pool_type
+ self.post_norm = post_norm
+ self.norm_eps = norm_eps
+
+ # embeddings
+ gain = 1.0 / math.sqrt(dim)
+ self.patch_embedding = nn.Conv2d(
+ 3,
+ dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=not pre_norm)
+ if pool_type in ('token', 'token_fc'):
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
+ 1, self.num_patches +
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
+ self.dropout = nn.Dropout(embedding_dropout)
+
+ # transformer
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+ self.transformer = nn.Sequential(*[
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
+ activation, attn_dropout, proj_dropout, norm_eps)
+ for _ in range(num_layers)
+ ])
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
+
+ # head
+ if pool_type == 'token':
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+ elif pool_type == 'token_fc':
+ self.head = nn.Linear(dim, out_dim)
+ elif pool_type == 'attn_pool':
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
+ proj_dropout, norm_eps)
+
+ def forward(self, x, interpolation=False, use_31_block=False):
+ b = x.size(0)
+
+ # embeddings
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+ if self.pool_type in ('token', 'token_fc'):
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
+ if interpolation:
+ e = pos_interpolate(self.pos_embedding, x.size(1))
+ else:
+ e = self.pos_embedding
+ e = e.to(dtype=x.dtype, device=x.device)
+ x = self.dropout(x + e)
+ if self.pre_norm is not None:
+ x = self.pre_norm(x)
+
+ # transformer
+ if use_31_block:
+ x = self.transformer[:-1](x)
+ return x
+ else:
+ x = self.transformer(x)
+ return x
+
+
+class CLIP(nn.Module):
+
+ def __init__(self,
+ embed_dim=512,
+ image_size=224,
+ patch_size=16,
+ vision_dim=768,
+ vision_mlp_ratio=4,
+ vision_heads=12,
+ vision_layers=12,
+ vision_pool='token',
+ vision_pre_norm=True,
+ vision_post_norm=False,
+ vocab_size=49408,
+ text_len=77,
+ text_dim=512,
+ text_mlp_ratio=4,
+ text_heads=8,
+ text_layers=12,
+ text_causal=True,
+ text_pool='argmax',
+ text_head_bias=False,
+ logit_bias=None,
+ activation='quick_gelu',
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.vision_dim = vision_dim
+ self.vision_mlp_ratio = vision_mlp_ratio
+ self.vision_heads = vision_heads
+ self.vision_layers = vision_layers
+ self.vision_pool = vision_pool
+ self.vision_pre_norm = vision_pre_norm
+ self.vision_post_norm = vision_post_norm
+ self.vocab_size = vocab_size
+ self.text_len = text_len
+ self.text_dim = text_dim
+ self.text_mlp_ratio = text_mlp_ratio
+ self.text_heads = text_heads
+ self.text_layers = text_layers
+ self.text_causal = text_causal
+ self.text_pool = text_pool
+ self.text_head_bias = text_head_bias
+ self.norm_eps = norm_eps
+
+ # models
+ self.visual = VisionTransformer(
+ image_size=image_size,
+ patch_size=patch_size,
+ dim=vision_dim,
+ mlp_ratio=vision_mlp_ratio,
+ out_dim=embed_dim,
+ num_heads=vision_heads,
+ num_layers=vision_layers,
+ pool_type=vision_pool,
+ pre_norm=vision_pre_norm,
+ post_norm=vision_post_norm,
+ activation=activation,
+ attn_dropout=attn_dropout,
+ proj_dropout=proj_dropout,
+ embedding_dropout=embedding_dropout,
+ norm_eps=norm_eps)
+ self.textual = TextTransformer(
+ vocab_size=vocab_size,
+ text_len=text_len,
+ dim=text_dim,
+ mlp_ratio=text_mlp_ratio,
+ out_dim=embed_dim,
+ num_heads=text_heads,
+ num_layers=text_layers,
+ causal=text_causal,
+ pool_type=text_pool,
+ head_bias=text_head_bias,
+ activation=activation,
+ attn_dropout=attn_dropout,
+ proj_dropout=proj_dropout,
+ embedding_dropout=embedding_dropout,
+ norm_eps=norm_eps)
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
+ if logit_bias is not None:
+ self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
+
+ # initialize weights
+ self.init_weights()
+
+ def forward(self, imgs, txt_ids):
+ """
+ imgs: [B, 3, H, W] of torch.float32.
+ - mean: [0.48145466, 0.4578275, 0.40821073]
+ - std: [0.26862954, 0.26130258, 0.27577711]
+ txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.
+ """
+ xi = self.visual(imgs)
+ xt = self.textual(txt_ids)
+ return xi, xt
+
+ def init_weights(self):
+ # embeddings
+ nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
+
+ # attentions
+ for modality in ['visual', 'textual']:
+ dim = self.vision_dim if modality == 'visual' else self.text_dim
+ transformer = getattr(self, modality).transformer
+ proj_gain = (1.0 / math.sqrt(dim)) * (
+ 1.0 / math.sqrt(2 * len(transformer)))
+ attn_gain = 1.0 / math.sqrt(dim)
+ mlp_gain = 1.0 / math.sqrt(2.0 * dim)
+ for block in transformer:
+ nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
+ nn.init.normal_(block.attn.proj.weight, std=proj_gain)
+ nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
+ nn.init.normal_(block.mlp[2].weight, std=proj_gain)
+
+ def param_groups(self):
+ groups = [{
+ 'params': [
+ p for n, p in self.named_parameters()
+ if 'norm' in n or n.endswith('bias')
+ ],
+ 'weight_decay': 0.0
+ }, {
+ 'params': [
+ p for n, p in self.named_parameters()
+ if not ('norm' in n or n.endswith('bias'))
+ ]
+ }]
+ return groups
+
+
+class XLMRobertaWithHead(XLMRoberta):
+
+ def __init__(self, **kwargs):
+ self.out_dim = kwargs.pop('out_dim')
+ super().__init__(**kwargs)
+
+ # head
+ mid_dim = (self.dim + self.out_dim) // 2
+ self.head = nn.Sequential(
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
+ nn.Linear(mid_dim, self.out_dim, bias=False))
+
+ def forward(self, ids):
+ # xlm-roberta
+ x = super().forward(ids)
+
+ # average pooling
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
+
+ # head
+ x = self.head(x)
+ return x
+
+
+class XLMRobertaCLIP(nn.Module):
+
+ def __init__(self,
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool='token',
+ vision_pre_norm=True,
+ vision_post_norm=False,
+ activation='gelu',
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0,
+ norm_eps=1e-5):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.vision_dim = vision_dim
+ self.vision_mlp_ratio = vision_mlp_ratio
+ self.vision_heads = vision_heads
+ self.vision_layers = vision_layers
+ self.vision_pre_norm = vision_pre_norm
+ self.vision_post_norm = vision_post_norm
+ self.activation = activation
+ self.vocab_size = vocab_size
+ self.max_text_len = max_text_len
+ self.type_size = type_size
+ self.pad_id = pad_id
+ self.text_dim = text_dim
+ self.text_heads = text_heads
+ self.text_layers = text_layers
+ self.text_post_norm = text_post_norm
+ self.norm_eps = norm_eps
+
+ # models
+ self.visual = VisionTransformer(
+ image_size=image_size,
+ patch_size=patch_size,
+ dim=vision_dim,
+ mlp_ratio=vision_mlp_ratio,
+ out_dim=embed_dim,
+ num_heads=vision_heads,
+ num_layers=vision_layers,
+ pool_type=vision_pool,
+ pre_norm=vision_pre_norm,
+ post_norm=vision_post_norm,
+ activation=activation,
+ attn_dropout=attn_dropout,
+ proj_dropout=proj_dropout,
+ embedding_dropout=embedding_dropout,
+ norm_eps=norm_eps)
+ self.textual = None
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
+
+ def forward(self, imgs, txt_ids):
+ """
+ imgs: [B, 3, H, W] of torch.float32.
+ - mean: [0.48145466, 0.4578275, 0.40821073]
+ - std: [0.26862954, 0.26130258, 0.27577711]
+ txt_ids: [B, L] of torch.long.
+ Encoded by data.CLIPTokenizer.
+ """
+ xi = self.visual(imgs)
+ xt = self.textual(txt_ids)
+ return xi, xt
+
+ def param_groups(self):
+ groups = [{
+ 'params': [
+ p for n, p in self.named_parameters()
+ if 'norm' in n or n.endswith('bias')
+ ],
+ 'weight_decay': 0.0
+ }, {
+ 'params': [
+ p for n, p in self.named_parameters()
+ if not ('norm' in n or n.endswith('bias'))
+ ]
+ }]
+ return groups
+
+
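+# Factory for CLIP-style models: builds the requested model class (optionally loading a
+# pretrained checkpoint) and can additionally return the matching image transforms and tokenizer.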
+def _clip(pretrained=False,
+ pretrained_name=None,
+ model_cls=CLIP,
+ return_transforms=False,
+ return_tokenizer=False,
+ tokenizer_padding='eos',
+ dtype=torch.float32,
+ device='cpu',
+ **kwargs):
+ # init model
+ if pretrained and pretrained_name:
+ from sora import BUCKET, DOWNLOAD_TO_CACHE
+
+ # init a meta model
+ with torch.device('meta'):
+ model = model_cls(**kwargs)
+
+ # checkpoint path
+ checkpoint = f'models/clip/{pretrained_name}'
+ if dtype in (torch.float16, torch.bfloat16):
+ suffix = '-' + {
+ torch.float16: 'fp16',
+ torch.bfloat16: 'bf16'
+ }[dtype]
+ if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
+ checkpoint = f'{checkpoint}{suffix}'
+ checkpoint += '.pth'
+
+ # load
+ model.load_state_dict(
+ torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
+ assign=True,
+ strict=False)
+ else:
+ # init a model on device
+ with torch.device(device):
+ model = model_cls(**kwargs)
+
+ # set device
+ output = (model,)
+
+ # init transforms
+ if return_transforms:
+ # mean and std
+ if 'siglip' in pretrained_name.lower():
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+ else:
+ mean = [0.48145466, 0.4578275, 0.40821073]
+ std = [0.26862954, 0.26130258, 0.27577711]
+
+ # transforms
+ transforms = T.Compose([
+ T.Resize((model.image_size, model.image_size),
+ interpolation=T.InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=mean, std=std)
+ ])
+ output += (transforms,)
+
+ # init tokenizer
+ if return_tokenizer:
+ from sora import data
+ if 'siglip' in pretrained_name.lower():
+ tokenizer = data.HuggingfaceTokenizer(
+ name=f'timm/{pretrained_name}',
+ seq_len=model.text_len,
+ clean='canonicalize')
+ elif 'xlm' in pretrained_name.lower():
+ tokenizer = data.HuggingfaceTokenizer(
+ name='xlm-roberta-large',
+ seq_len=model.max_text_len - 2,
+ clean='whitespace')
+ elif 'mba' in pretrained_name.lower():
+ tokenizer = data.HuggingfaceTokenizer(
+ name='facebook/xlm-roberta-xl',
+ seq_len=model.max_text_len - 2,
+ clean='whitespace')
+ else:
+ tokenizer = data.CLIPTokenizer(
+ seq_len=model.text_len, padding=tokenizer_padding)
+ output += (tokenizer,)
+ return output[0] if len(output) == 1 else output
+
+
+def clip_xlm_roberta_vit_h_14(
+ pretrained=False,
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
+ **kwargs):
+ cfg = dict(
+ embed_dim=1024,
+ image_size=224,
+ patch_size=14,
+ vision_dim=1280,
+ vision_mlp_ratio=4,
+ vision_heads=16,
+ vision_layers=32,
+ vision_pool='token',
+ activation='gelu',
+ vocab_size=250002,
+ max_text_len=514,
+ type_size=1,
+ pad_id=1,
+ text_dim=1024,
+ text_heads=16,
+ text_layers=24,
+ text_post_norm=True,
+ text_dropout=0.1,
+ attn_dropout=0.0,
+ proj_dropout=0.0,
+ embedding_dropout=0.0)
+ cfg.update(**kwargs)
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
+
+
+class WanImageEncoder(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ # init model
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
+ pretrained=False,
+ return_transforms=True,
+ return_tokenizer=False,
+ dtype=torch.float32,
+ device="cpu")
+
+ def encode_image(self, videos):
+        # preprocess: resize to the CLIP input resolution, map [-1, 1] -> [0, 1], then normalize
+ size = (self.model.image_size,) * 2
+ videos = torch.cat([
+ F.interpolate(
+ u,
+ size=size,
+ mode='bicubic',
+ align_corners=False) for u in videos
+ ])
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
+
+ # forward
+ dtype = next(iter(self.model.visual.parameters())).dtype
+ videos = videos.to(dtype)
+ out = self.model.visual(videos, use_31_block=True)
+ return out
+
+ @staticmethod
+ def state_dict_converter():
+ return WanImageEncoderStateDictConverter()
+
+
+class WanImageEncoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict
+
+ def from_civitai(self, state_dict):
+ state_dict_ = {}
+ for name, param in state_dict.items():
+ if name.startswith("textual."):
+ continue
+ name = "model." + name
+ state_dict_[name] = param
+ return state_dict_
+
diff --git a/PusaV1/diffsynth/models/wan_video_motion_controller.py b/PusaV1/diffsynth/models/wan_video_motion_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..518c1c66edca1cae11d5f3371af0455808b2a66a
--- /dev/null
+++ b/PusaV1/diffsynth/models/wan_video_motion_controller.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from .wan_video_dit import sinusoidal_embedding_1d
+
+
+
+class WanMotionControllerModel(torch.nn.Module):
+ def __init__(self, freq_dim=256, dim=1536):
+ super().__init__()
+ self.freq_dim = freq_dim
+ self.linear = nn.Sequential(
+ nn.Linear(freq_dim, dim),
+ nn.SiLU(),
+ nn.Linear(dim, dim),
+ nn.SiLU(),
+ nn.Linear(dim, dim * 6),
+ )
+
+ def forward(self, motion_bucket_id):
+ emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
+ emb = self.linear(emb)
+ return emb
+
+ def init(self):
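+        # zero the weights and bias of the final linear layer so the controller initially outputs zeros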
+ state_dict = self.linear[-1].state_dict()
+ state_dict = {i: state_dict[i] * 0 for i in state_dict}
+ self.linear[-1].load_state_dict(state_dict)
+
+ @staticmethod
+ def state_dict_converter():
+ return WanMotionControllerModelDictConverter()
+
+
+
+class WanMotionControllerModelDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict
+
+ def from_civitai(self, state_dict):
+ return state_dict
+
diff --git a/PusaV1/diffsynth/models/wan_video_pusa.py b/PusaV1/diffsynth/models/wan_video_pusa.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f197d06293dab0b083059f7f4724600bdf5921
--- /dev/null
+++ b/PusaV1/diffsynth/models/wan_video_pusa.py
@@ -0,0 +1,704 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import Tuple, Optional
+from einops import rearrange
+from .utils import hash_state_dict_keys
+try:
+ import flash_attn_interface
+ FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False
+
+try:
+ import flash_attn
+ FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_2_AVAILABLE = False
+
+try:
+ from sageattention import sageattn
+ SAGE_ATTN_AVAILABLE = True
+except ModuleNotFoundError:
+ SAGE_ATTN_AVAILABLE = False
+
+
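+# Module-level configuration for optional attention-map visualization; disabled by
+# default and expected to be toggled/updated from outside this module.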
+_VISUALIZE_ATTENTION_CONFIG = {
+ "enabled": False, "path": None, "step": 0, "block_name": "", "attn_type": "", "grid_size": None,
+}
+
+
+def _visualize_cross_attention_from_center(q, k, config):
+ try:
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import os
+ except ImportError:
+ print("Please install matplotlib and seaborn to visualize attention maps.")
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ return
+
+ f, h, w = config["grid_size"]
+ query_patch_idx_t = f // 2
+ query_patch_idx_h = h // 2
+ query_patch_idx_w = w // 2
+ query_patch_idx = query_patch_idx_t * (h * w) + query_patch_idx_h * w + query_patch_idx_w
+
+ b, n_heads, s_q, d_head = q.shape
+ if query_patch_idx >= s_q:
+ return
+
+ q_center = q[:, :, query_patch_idx:query_patch_idx+1, :]
+
+ attn_scores = torch.matmul(q_center, k.transpose(-2, -1)) / math.sqrt(d_head)
+ attn_weights = F.softmax(attn_scores, dim=-1)
+
+ token_attention = attn_weights.mean(dim=(0, 1)).squeeze(0).detach().float().cpu().numpy()
+
+ sub_type = config.get("sub_attn_type", "text")
+ path_prefix = os.path.join(config["path"], f'{config["block_name"]}_cross_attn_{sub_type}_step{config["step"]}')
+
+ plt.figure(figsize=(16, 2))
+ sns.heatmap(token_attention[None, :], cmap="viridis", cbar=True)
+ plt.title(f'Cross-Attention: {sub_type} (from center patch)\n{config["block_name"]}, step {config["step"]}')
+ plt.xlabel("Key token index")
+ plt.ylabel("Query patch")
+ plt.tight_layout()
+ plt.savefig(f"{path_prefix}_center_patch.png")
+ plt.close()
+
+def _visualize_frame_self_attention(q, k, config):
+ try:
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import os
+ except ImportError:
+ print("Please install matplotlib and seaborn to visualize attention maps.")
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ return
+
+ b, n_heads, s, d_head = q.shape
+ f, h, w = config["grid_size"]
+ s_frame = h * w
+ if s != f * h * w:
+ return
+
+ q_frames = q.view(b, n_heads, f, s_frame, d_head)
+ k_frames = k.view(b, n_heads, f, s_frame, d_head)
+
+    # Averaging the tokens of each frame first is equivalent to computing all token-level scores and then averaging them per frame
+ q_frame_avg = q_frames.mean(dim=3)
+ k_frame_avg = k_frames.mean(dim=3)
+
+ frame_similarity_map = torch.einsum('bhid,bhjd->bhij', q_frame_avg, k_frame_avg) / math.sqrt(d_head)
+
+ frame_attention_map = F.softmax(frame_similarity_map, dim=-1)
+ frame_attention_map = frame_attention_map.mean(dim=(0,1)).detach().float().cpu().numpy()
+
+ path_prefix = os.path.join(config["path"], f'{config["block_name"]}_self_attn_step{config["step"]}')
+ plt.figure(figsize=(10, 8))
+ sns.heatmap(frame_attention_map, cmap="viridis", cbar=True, annot=True, fmt=".2f")
+ plt.title(f'Frame-to-Frame Self-Attention\n{config["block_name"]}, step {config["step"]}')
+ plt.xlabel("Key Frame Index")
+ plt.ylabel("Query Frame Index")
+ plt.tight_layout()
+ plt.savefig(f"{path_prefix}_frame_similarity.png")
+ plt.close()
+
+
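+# Attention dispatcher: prefers FlashAttention-3, then FlashAttention-2, then
+# SageAttention, and falls back to torch.nn.functional.scaled_dot_product_attention.
+# compatibility_mode forces the PyTorch fallback; the optional visualization hook
+# only inspects q/k and does not change the attention output.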
+def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ config = _VISUALIZE_ATTENTION_CONFIG
+ with torch.no_grad():
+ q_vis = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k_vis = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+
+ if config['attn_type'] == 'self':
+ _visualize_frame_self_attention(q_vis, k_vis, config)
+ elif config['attn_type'] == 'cross':
+ _visualize_cross_attention_from_center(q_vis, k_vis, config)
+
+ if compatibility_mode:
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+ x = F.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+ elif FLASH_ATTN_3_AVAILABLE:
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+ x = flash_attn_interface.flash_attn_func(q, k, v)
+        if isinstance(x, tuple):
+ x = x[0]
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+ elif FLASH_ATTN_2_AVAILABLE:
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+ x = flash_attn.flash_attn_func(q, k, v)
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+ elif SAGE_ATTN_AVAILABLE:
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+ x = sageattn(q, k, v)
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+ else:
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+ x = F.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+ return x
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
+ return (x * (1 + scale) + shift)
+
+def sinusoidal_embedding_1d(dim, position):
+ # Handle both 1D and 2D position inputs
+ original_shape = position.shape
+
+ # Flatten to 1D if input is 2D
+ if len(original_shape) == 2:
+ position = position.reshape(-1) # Flatten to (B*T)
+
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+
+ # Reshape back to original batch shape if input was 2D
+ if len(original_shape) == 2:
+ x = x.reshape(original_shape[0], original_shape[1], dim)
+
+ return x.to(position.dtype)
+
+
+def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
+ # 3d rope precompute
+ f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
+ h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+ w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
+ return f_freqs_cis, h_freqs_cis, w_freqs_cis
+
+
+def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
+ # 1d rope precompute
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
+ [: (dim // 2)].double() / dim))
+ freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
+ return freqs_cis
+
+
+def rope_apply(x, freqs, num_heads):
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
+ x_out = torch.view_as_real(x_out * freqs).flatten(2)
+ return x_out.to(x.dtype)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ dtype = x.dtype
+ return self.norm(x.float()).to(dtype) * self.weight
+
+
+class AttentionModule(nn.Module):
+ def __init__(self, num_heads):
+ super().__init__()
+ self.num_heads = num_heads
+
+ def forward(self, q, k, v):
+ x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
+ return x
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = RMSNorm(dim, eps=eps)
+ self.norm_k = RMSNorm(dim, eps=eps)
+
+ self.attn = AttentionModule(self.num_heads)
+
+ def forward(self, x, freqs):
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(x))
+ v = self.v(x)
+ q = rope_apply(q, freqs, self.num_heads)
+ k = rope_apply(k, freqs, self.num_heads)
+ x = self.attn(q, k, v)
+ return self.o(x)
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = RMSNorm(dim, eps=eps)
+ self.norm_k = RMSNorm(dim, eps=eps)
+ self.has_image_input = has_image_input
+ if has_image_input:
+ self.k_img = nn.Linear(dim, dim)
+ self.v_img = nn.Linear(dim, dim)
+ self.norm_k_img = RMSNorm(dim, eps=eps)
+
+ self.attn = AttentionModule(self.num_heads)
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ if self.has_image_input:
+ img = y[:, :257]
+ ctx = y[:, 257:]
+ else:
+ ctx = y
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(ctx))
+ v = self.v(ctx)
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG['sub_attn_type'] = 'text'
+ x = self.attn(q, k, v)
+ if self.has_image_input:
+ k_img = self.norm_k_img(self.k_img(img))
+ v_img = self.v_img(img)
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG['sub_attn_type'] = 'image'
+ y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
+ x = x + y
+ return self.o(x)
+
+
+class GateModule(nn.Module):
+ def __init__(self,):
+ super().__init__()
+
+ def forward(self, x, gate, residual):
+ return x + gate * residual
+
+class DiTBlock(nn.Module):
+ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.ffn_dim = ffn_dim
+ self.block_name = ""
+
+ self.self_attn = SelfAttention(dim, num_heads, eps)
+ self.cross_attn = CrossAttention(
+ dim, num_heads, eps, has_image_input=has_image_input)
+ self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+ self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+ self.norm3 = nn.LayerNorm(dim, eps=eps)
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
+ approximate='tanh'), nn.Linear(ffn_dim, dim))
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+ self.gate = GateModule()
+
+ def forward(self, x, context, t_mod, freqs):
+        # msa: multi-head self-attention; mlp: multi-layer perceptron
+ # Handle the new sequence dimension in t_mod [B, 6, N, D]
+ # Reshape modulation to [1, 6, 1, D] for proper broadcasting
+ modulation = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device).unsqueeze(2)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ modulation + t_mod).chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
+            m.squeeze(1) for m in (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)]
+
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["block_name"] = self.block_name
+ _VISUALIZE_ATTENTION_CONFIG["attn_type"] = "self"
+
+ input_x = modulate(self.norm1(x), shift_msa, scale_msa)
+ x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["attn_type"] = "cross"
+ x = x + self.cross_attn(self.norm3(x), context)
+ input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
+ x = self.gate(x, gate_mlp, self.ffn(input_x))
+ return x
+
+class MLP(torch.nn.Module):
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+ self.proj = torch.nn.Sequential(
+ nn.LayerNorm(in_dim),
+ nn.Linear(in_dim, in_dim),
+ nn.GELU(),
+ nn.Linear(in_dim, out_dim),
+ nn.LayerNorm(out_dim)
+ )
+
+ def forward(self, x):
+ return self.proj(x)
+
+
+class Head(nn.Module):
+ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
+ super().__init__()
+ self.dim = dim
+ self.patch_size = patch_size
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
+ self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+ def forward(self, x, t_mod):
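+        # t_mod here is the per-token time embedding [B, S, D]; expand it to the two
+        # (shift, scale) slots and combine it with the learned modulation table.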
+ t_mod = t_mod.unsqueeze(1).repeat(1,2,1,1).permute(0, 1, 3, 2)
+ modulation = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device).unsqueeze(3)
+
+ shift, scale = (modulation + t_mod).chunk(2, dim=1)
+
+ shift, scale = shift.permute(0, 1, 3, 2).squeeze(1), scale.permute(0, 1, 3, 2).squeeze(1)
+
+ x = (self.head(self.norm(x) * (1 + scale) + shift))
+ return x
+
+
+class WanModelPusa(torch.nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ in_dim: int,
+ ffn_dim: int,
+ out_dim: int,
+ text_dim: int,
+ freq_dim: int,
+ eps: float,
+ patch_size: Tuple[int, int, int],
+ num_heads: int,
+ num_layers: int,
+ has_image_input: bool,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.freq_dim = freq_dim
+ self.has_image_input = has_image_input
+ self.patch_size = patch_size
+
+ self.patch_embedding = nn.Conv3d(
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
+ self.text_embedding = nn.Sequential(
+ nn.Linear(text_dim, dim),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(dim, dim)
+ )
+ self.time_embedding = nn.Sequential(
+ nn.Linear(freq_dim, dim),
+ nn.SiLU(),
+ nn.Linear(dim, dim)
+ )
+ self.time_projection = nn.Sequential(
+ nn.SiLU(), nn.Linear(dim, dim * 6))
+ self.blocks = nn.ModuleList([
+ DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
+ for _ in range(num_layers)
+ ])
+ for i, block in enumerate(self.blocks):
+ block.block_name = f"block_{i}"
+ self.head = Head(dim, out_dim, patch_size, eps)
+ head_dim = dim // num_heads
+ self.freqs = precompute_freqs_cis_3d(head_dim)
+
+ if has_image_input:
+ self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280
+
+ def patchify(self, x: torch.Tensor):
+ x = self.patch_embedding(x)
+ grid_size = x.shape[2:]
+ x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
+ return x, grid_size # x, grid_size: (f, h, w)
+
+ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
+ return rearrange(
+ x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
+ f=grid_size[0], h=grid_size[1], w=grid_size[2],
+ x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
+ )
+
+ def forward(self,
+ x: torch.Tensor,
+ timestep: torch.Tensor,
+ context: torch.Tensor,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ use_gradient_checkpointing: bool = False,
+ use_gradient_checkpointing_offload: bool = False,
+ **kwargs,
+ ):
+ t = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
+
+ B, C, T, H, W = x.shape
+ pH, pW = H // self.patch_size[1], W // self.patch_size[2]
+
+ x = x.to(self.patch_embedding.weight.dtype)
+ if y is not None:
+ y = y.to(self.patch_embedding.weight.dtype)
+
+ t_mod = self.time_projection(t).unflatten(2, (6, self.dim))
+ context = self.text_embedding(context)
+
+
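+        # The per-frame time embedding t ([B, T, dim]) and its 6-way modulation t_mod are
+        # tiled over the pH x pW patch grid so each spatiotemporal token is modulated by
+        # its own frame's timestep.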
+ t = t.unsqueeze(2).unsqueeze(3).repeat(1, 1, pH, pW, 1)
+ t = rearrange(t, 'b f h w d -> b (f h w) d').contiguous()
+ t_mod = t_mod.unsqueeze(3).unsqueeze(4).repeat(1, 1, 1, pH, pW, 1)
+ t_mod = rearrange(t_mod, 'b f m h w d -> b m (f h w) d').contiguous()
+
+
+ if self.has_image_input:
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
+            clip_embedding = self.img_emb(clip_feature)
+            context = torch.cat([clip_embedding, context], dim=1)
+
+ x, (f, h, w) = self.patchify(x)
+
+ freqs = torch.cat([
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+
+ for block in self.blocks:
+ if self.training and use_gradient_checkpointing:
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ x, context, t_mod, freqs,
+ use_reentrant=False,
+ )
+ else:
+ x = block(x, context, t_mod, freqs)
+
+ x = self.head(x, t)
+ x = self.unpatchify(x, (f, h, w))
+ return x
+
+ @staticmethod
+ def state_dict_converter():
+ return WanModelPusaStateDictConverter()
+
+
+class WanModelPusaStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ rename_dict = {
+ "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
+ "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
+ "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
+ "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
+ "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
+ "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
+ "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
+ "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
+ "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
+ "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
+ "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
+ "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
+ "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
+ "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
+ "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
+ "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
+ "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
+ "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
+ "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
+ "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
+ "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
+ "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
+ "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
+ "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
+ "blocks.0.norm2.bias": "blocks.0.norm3.bias",
+ "blocks.0.norm2.weight": "blocks.0.norm3.weight",
+ "blocks.0.scale_shift_table": "blocks.0.modulation",
+ "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
+ "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
+ "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
+ "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
+ "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
+ "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
+ "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
+ "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
+ "condition_embedder.time_proj.bias": "time_projection.1.bias",
+ "condition_embedder.time_proj.weight": "time_projection.1.weight",
+ "patch_embedding.bias": "patch_embedding.bias",
+ "patch_embedding.weight": "patch_embedding.weight",
+ "scale_shift_table": "head.modulation",
+ "proj_out.bias": "head.head.bias",
+ "proj_out.weight": "head.head.weight",
+ }
+ state_dict_ = {}
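+        # Keys for blocks other than block 0 are renamed via the block-0 template:
+        # substitute "0" for the block index, look up the rename, then restore the index.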
+ for name, param in state_dict.items():
+ if name in rename_dict:
+ state_dict_[rename_dict[name]] = param
+ else:
+ name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
+ if name_ in rename_dict:
+ name_ = rename_dict[name_]
+ name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
+ state_dict_[name_] = param
+ if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
+ config = {
+ "model_type": "t2v",
+ "patch_size": (1, 2, 2),
+ "text_len": 512,
+ "in_dim": 16,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "window_size": (-1, -1),
+ "qk_norm": True,
+ "cross_attn_norm": True,
+ "eps": 1e-6,
+ }
+ else:
+ config = {}
+ return state_dict_, config
+
+ def from_civitai(self, state_dict):
+ state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
+ if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
+ config = {
+ "has_image_input": False,
+ "patch_size": [1, 2, 2],
+ "in_dim": 16,
+ "dim": 1536,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 12,
+ "num_layers": 30,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
+ config = {
+ "has_image_input": False,
+ "patch_size": [1, 2, 2],
+ "in_dim": 16,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 36,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 36,
+ "dim": 1536,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 12,
+ "num_layers": 30,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 36,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 48,
+ "dim": 1536,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 12,
+ "num_layers": 30,
+ "eps": 1e-6
+ }
+ elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
+ config = {
+ "has_image_input": True,
+ "patch_size": [1, 2, 2],
+ "in_dim": 48,
+ "dim": 5120,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "text_dim": 4096,
+ "out_dim": 16,
+ "num_heads": 40,
+ "num_layers": 40,
+ "eps": 1e-6
+ }
+ else:
+ config = {}
+ return state_dict, config
diff --git a/PusaV1/diffsynth/models/wan_video_text_encoder.py b/PusaV1/diffsynth/models/wan_video_text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28873722ee92f23914712c9d5b2c3a26fd2adb7
--- /dev/null
+++ b/PusaV1/diffsynth/models/wan_video_text_encoder.py
@@ -0,0 +1,269 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def fp16_clamp(x):
+ if x.dtype == torch.float16 and torch.isinf(x).any():
+ clamp = torch.finfo(x.dtype).max - 1000
+ x = torch.clamp(x, min=-clamp, max=clamp)
+ return x
+
+
+class GELU(nn.Module):
+
+ def forward(self, x):
+ return 0.5 * x * (1.0 + torch.tanh(
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+class T5LayerNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-6):
+ super(T5LayerNorm, self).__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
+ self.eps)
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ x = x.type_as(self.weight)
+ return self.weight * x
+
+
+class T5Attention(nn.Module):
+
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
+ assert dim_attn % num_heads == 0
+ super(T5Attention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.num_heads = num_heads
+ self.head_dim = dim_attn // num_heads
+
+ # layers
+ self.q = nn.Linear(dim, dim_attn, bias=False)
+ self.k = nn.Linear(dim, dim_attn, bias=False)
+ self.v = nn.Linear(dim, dim_attn, bias=False)
+ self.o = nn.Linear(dim_attn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, context=None, mask=None, pos_bias=None):
+ """
+ x: [B, L1, C].
+ context: [B, L2, C] or None.
+ mask: [B, L2] or [B, L1, L2] or None.
+ """
+ # check inputs
+ context = x if context is None else context
+ b, n, c = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.q(x).view(b, -1, n, c)
+ k = self.k(context).view(b, -1, n, c)
+ v = self.v(context).view(b, -1, n, c)
+
+ # attention bias
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
+ if pos_bias is not None:
+ attn_bias += pos_bias
+ if mask is not None:
+ assert mask.ndim in [2, 3]
+ mask = mask.view(b, 1, 1,
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
+
+ # compute attention (T5 does not use scaling)
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
+
+ # output
+ x = x.reshape(b, -1, n * c)
+ x = self.o(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5FeedForward(nn.Module):
+
+ def __init__(self, dim, dim_ffn, dropout=0.1):
+ super(T5FeedForward, self).__init__()
+ self.dim = dim
+ self.dim_ffn = dim_ffn
+
+ # layers
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.fc1(x) * self.gate(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ return x
+
+
+class T5SelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ dim_attn,
+ dim_ffn,
+ num_heads,
+ num_buckets,
+ shared_pos=True,
+ dropout=0.1):
+ super(T5SelfAttention, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.norm1 = T5LayerNorm(dim)
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+ self.norm2 = T5LayerNorm(dim)
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=True)
+
+ def forward(self, x, mask=None, pos_bias=None):
+ e = pos_bias if self.shared_pos else self.pos_embedding(
+ x.size(1), x.size(1))
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
+ return x
+
+
+class T5RelativeEmbedding(nn.Module):
+
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+ super(T5RelativeEmbedding, self).__init__()
+ self.num_buckets = num_buckets
+ self.num_heads = num_heads
+ self.bidirectional = bidirectional
+ self.max_dist = max_dist
+
+ # layers
+ self.embedding = nn.Embedding(num_buckets, num_heads)
+
+ def forward(self, lq, lk):
+ device = self.embedding.weight.device
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
+ torch.arange(lq, device=device).unsqueeze(1)
+ rel_pos = self._relative_position_bucket(rel_pos)
+ rel_pos_embeds = self.embedding(rel_pos)
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
+ 0) # [1, N, Lq, Lk]
+ return rel_pos_embeds.contiguous()
+
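+    # T5-style bucketing: nearby relative positions get their own bucket, while larger
+    # distances are mapped to logarithmically spaced buckets up to max_dist.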
+ def _relative_position_bucket(self, rel_pos):
+ # preprocess
+ if self.bidirectional:
+ num_buckets = self.num_buckets // 2
+ rel_buckets = (rel_pos > 0).long() * num_buckets
+ rel_pos = torch.abs(rel_pos)
+ else:
+ num_buckets = self.num_buckets
+ rel_buckets = 0
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
+
+ # embeddings for small and large positions
+ max_exact = num_buckets // 2
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
+ math.log(self.max_dist / max_exact) *
+ (num_buckets - max_exact)).long()
+ rel_pos_large = torch.min(
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
+ return rel_buckets
+
+def init_weights(m):
+ if isinstance(m, T5LayerNorm):
+ nn.init.ones_(m.weight)
+ elif isinstance(m, T5FeedForward):
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
+ elif isinstance(m, T5Attention):
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
+ elif isinstance(m, T5RelativeEmbedding):
+ nn.init.normal_(
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
+
+
+class WanTextEncoder(torch.nn.Module):
+
+ def __init__(self,
+ vocab=256384,
+ dim=4096,
+ dim_attn=4096,
+ dim_ffn=10240,
+ num_heads=64,
+ num_layers=24,
+ num_buckets=32,
+ shared_pos=False,
+ dropout=0.1):
+ super(WanTextEncoder, self).__init__()
+ self.dim = dim
+ self.dim_attn = dim_attn
+ self.dim_ffn = dim_ffn
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.num_buckets = num_buckets
+ self.shared_pos = shared_pos
+
+ # layers
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
+ else nn.Embedding(vocab, dim)
+ self.pos_embedding = T5RelativeEmbedding(
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
+ self.dropout = nn.Dropout(dropout)
+ self.blocks = nn.ModuleList([
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
+ shared_pos, dropout) for _ in range(num_layers)
+ ])
+ self.norm = T5LayerNorm(dim)
+
+ # initialize weights
+ self.apply(init_weights)
+
+ def forward(self, ids, mask=None):
+ x = self.token_embedding(ids)
+ x = self.dropout(x)
+ e = self.pos_embedding(x.size(1),
+ x.size(1)) if self.shared_pos else None
+ for block in self.blocks:
+ x = block(x, mask, pos_bias=e)
+ x = self.norm(x)
+ x = self.dropout(x)
+ return x
+
+ @staticmethod
+ def state_dict_converter():
+ return WanTextEncoderStateDictConverter()
+
+
+class WanTextEncoderStateDictConverter:
+ def __init__(self):
+ pass
+
+ def from_diffusers(self, state_dict):
+ return state_dict
+
+ def from_civitai(self, state_dict):
+ return state_dict
diff --git a/PusaV1/diffsynth/models/wan_video_vace.py b/PusaV1/diffsynth/models/wan_video_vace.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c9c2d74bf271bf70b7a17f6611f84b6c9508193
--- /dev/null
+++ b/PusaV1/diffsynth/models/wan_video_vace.py
@@ -0,0 +1,77 @@
+import torch
+from .wan_video_dit import DiTBlock
+
+
+class VaceWanAttentionBlock(DiTBlock):
+ def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
+ super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = torch.nn.Linear(self.dim, self.dim)
+ self.after_proj = torch.nn.Linear(self.dim, self.dim)
+
+ def forward(self, c, x, context, t_mod, freqs):
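+        # Block 0 projects the VACE conditioning and adds the DiT hidden states; later
+        # blocks pop the running context from the stack, refine it, and append both a
+        # skip projection (after_proj) and the updated context for the next block.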
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ all_c = []
+ else:
+ all_c = list(torch.unbind(c))
+ c = all_c.pop(-1)
+ c = super().forward(c, context, t_mod, freqs)
+ c_skip = self.after_proj(c)
+ all_c += [c_skip, c]
+ c = torch.stack(all_c)
+ return c
+
+
+class VaceWanModel(torch.nn.Module):
+ def __init__(
+ self,
+ vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
+ vace_in_dim=96,
+ patch_size=(1, 2, 2),
+ has_image_input=False,
+ dim=1536,
+ num_heads=12,
+ ffn_dim=8960,
+ eps=1e-6,
+ ):
+ super().__init__()
+ self.vace_layers = vace_layers
+ self.vace_in_dim = vace_in_dim
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
+
+ # vace blocks
+ self.vace_blocks = torch.nn.ModuleList([
+ VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
+ for i in self.vace_layers
+ ])
+
+ # vace patch embeddings
+ self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x, vace_context, context, t_mod, freqs):
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+ c = [u.flatten(2).transpose(1, 2) for u in c]
+ c = torch.cat([
+ torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
+ dim=1) for u in c
+ ])
+
+ for block in self.vace_blocks:
+ c = block(c, x, context, t_mod, freqs)
+ hints = torch.unbind(c)[:-1]
+ return hints
+
+ @staticmethod
+ def state_dict_converter():
+ return VaceWanModelDictConverter()
+
+
+class VaceWanModelDictConverter:
+ def __init__(self):
+ pass
+
+ def from_civitai(self, state_dict):
+ state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
+ return state_dict_
diff --git a/PusaV1/diffsynth/models/wan_video_vae.py b/PusaV1/diffsynth/models/wan_video_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..df23076166bf99231cac9e5822028a5eadf8ccea
--- /dev/null
+++ b/PusaV1/diffsynth/models/wan_video_vae.py
@@ -0,0 +1,807 @@
+from einops import rearrange, repeat
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+
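+# Number of trailing frames cached per causal conv so chunked encoding/decoding
+# stays consistent with processing the whole video at once.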
+CACHE_T = 2
+
+
+def check_is_instance(model, module_class):
+ if isinstance(model, module_class):
+ return True
+ if hasattr(model, "module") and isinstance(model.module, module_class):
+ return True
+ return False
+
+
+def block_causal_mask(x, block_size):
+ # params
+ b, n, s, _, device = *x.size(), x.device
+ assert s % block_size == 0
+ num_blocks = s // block_size
+
+ # build mask
+ mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
+ for i in range(num_blocks):
+ mask[:, :,
+ i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
+ return mask
+
+
+class CausalConv3d(nn.Conv3d):
+ """
+ Causal 3d convolusion.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
+ self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+
+ return super().forward(x)
+
+
+class RMS_norm(nn.Module):
+
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
+
+ def forward(self, x):
+ return F.normalize(
+ x, dim=(1 if self.channel_first else
+ -1)) * self.scale * self.gamma + self.bias
+
+
+class Upsample(nn.Upsample):
+
+ def forward(self, x):
+ """
+ Fix bfloat16 support for nearest neighbor interpolation.
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class Resample(nn.Module):
+
+ def __init__(self, dim, mode):
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
+ 'downsample3d')
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == 'upsample2d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ elif mode == 'upsample3d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ self.time_conv = CausalConv3d(dim,
+ dim * 2, (3, 1, 1),
+ padding=(1, 0, 0))
+
+ elif mode == 'downsample2d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == 'downsample3d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(dim,
+ dim, (3, 1, 1),
+ stride=(2, 1, 1),
+ padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == 'upsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = 'Rep'
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] != 'Rep':
+                        # cache the last frame of the last two chunks
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] == 'Rep':
+ cache_x = torch.cat([
+ torch.zeros_like(cache_x).to(cache_x.device),
+ cache_x
+ ],
+ dim=2)
+ if feat_cache[idx] == 'Rep':
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
+ 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.resample(x)
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+
+ if self.mode == 'downsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -1:, :, :].clone()
+ x = self.time_conv(
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+ def init_weight(self, conv):
+ conv_weight = conv.weight
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ one_matrix = torch.eye(c1, c2)
+ init_matrix = one_matrix
+ nn.init.zeros_(conv_weight)
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def init_weight2(self, conv):
+ conv_weight = conv.weight.data
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ init_matrix = torch.eye(c1 // 2, c2)
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_dim, out_dim, dropout=0.0):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # layers
+ self.residual = nn.Sequential(
+ RMS_norm(in_dim, images=False), nn.SiLU(),
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
+ if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ h = self.shortcut(x)
+ for layer in self.residual:
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # cache the last frame of the last two chunks
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ Causal self-attention with a single head.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = RMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ # zero out the last layer params
+ nn.init.zeros_(self.proj.weight)
+
+ def forward(self, x):
+ identity = x
+ b, c, t, h, w = x.size()
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.norm(x)
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
+ 0, 1, 3, 2).contiguous().chunk(3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ #attn_mask=block_causal_mask(q, block_size=h * w)
+ )
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
+
+ # output
+ x = self.proj(x)
+        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+ return x + identity
+
+
+class Encoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ downsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = 'downsample3d' if temperal_downsample[
+ i] else 'downsample2d'
+ downsamples.append(Resample(out_dim, mode=mode))
+ scale /= 2.0
+ self.downsamples = nn.Sequential(*downsamples)
+
+ # middle blocks
+ self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
+ AttentionBlock(out_dim),
+ ResidualBlock(out_dim, out_dim, dropout))
+
+ # output blocks
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # pad with the last frame of the previous chunk so the cache keeps two frames
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## downsamples
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ for layer in self.middle:
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # pad with the last frame of the previous chunk so the cache keeps two frames
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Decoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2**(len(dim_mult) - 2)
+
+ # init block
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
+ AttentionBlock(dims[0]),
+ ResidualBlock(dims[0], dims[0], dropout))
+
+ # upsample blocks
+ upsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i == 1 or i == 2 or i == 3:
+ in_dim = in_dim // 2
+ for _ in range(num_res_blocks + 1):
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ upsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # upsample block
+ if i != len(dim_mult) - 1:
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
+ upsamples.append(Resample(out_dim, mode=mode))
+ scale *= 2.0
+ self.upsamples = nn.Sequential(*upsamples)
+
+ # output blocks
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, 3, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # pad with the last frame of the previous chunk so the cache keeps two frames
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## middle
+ for layer in self.middle:
+ if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## upsamples
+ for layer in self.upsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # pad with the last frame of the previous chunk so the cache keeps two frames
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+def count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if check_is_instance(m, CausalConv3d):
+ count += 1
+ return count
+
+
+class VideoVAE_(nn.Module):
+
+ def __init__(self,
+ dim=96,
+ z_dim=16,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[False, True, True],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ # modules
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_downsample, dropout)
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_upsample, dropout)
+
+ def forward(self, x):
+ mu, log_var = self.encode(x)
+ z = self.reparameterize(mu, log_var)
+ x_recon = self.decode(z)
+ return x_recon, mu, log_var
+
+ def encode(self, x, scale):
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(x[:, :, :1, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ out = torch.cat([out, out_], 2)
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
+ if isinstance(scale[0], torch.Tensor):
+ scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ scale = scale.to(dtype=mu.dtype, device=mu.device)
+ mu = (mu - scale[0]) * scale[1]
+ return mu
+
+ def decode(self, z, scale):
+ self.clear_cache()
+ # z: [b,c,t,h,w]
+ if isinstance(scale[0], torch.Tensor):
+ scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ scale = scale.to(dtype=z.dtype, device=z.device)
+ z = z / scale[1] + scale[0]
+ iter_ = z.shape[2]
+ x = self.conv2(z)
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2) # may add tensor offload
+ return out
+
+ def reparameterize(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def sample(self, imgs, deterministic=False):
+ mu, log_var = self.encode(imgs)
+ if deterministic:
+ return mu
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+ return mu + std * torch.randn_like(std)
+
+ def clear_cache(self):
+ self._conv_num = count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ # cache encode
+ self._enc_conv_num = count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+
+class WanVideoVAE(nn.Module):
+
+ def __init__(self, z_dim=16):
+ super().__init__()
+
+ mean = [
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
+ ]
+ std = [
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
+ ]
+ self.mean = torch.tensor(mean)
+ self.std = torch.tensor(std)
+ self.scale = [self.mean, 1.0 / self.std]
+
+ # init model
+ self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
+ self.upsampling_factor = 8
+
+
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
+ x = torch.ones((length,))
+ if not left_bound:
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
+ if not right_bound:
+ x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
+ return x
+
+
+ def build_mask(self, data, is_bound, border_width):
+ _, _, _, H, W = data.shape
+ h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
+ w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
+
+ h = repeat(h, "H -> H W", H=H, W=W)
+ w = repeat(w, "W -> H W", H=H, W=W)
+
+ mask = torch.stack([h, w]).min(dim=0).values
+ mask = rearrange(mask, "H W -> 1 1 1 H W")
+ return mask
+
+
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
+ _, _, T, H, W = hidden_states.shape
+ size_h, size_w = tile_size
+ stride_h, stride_w = tile_stride
+
+ # Split tasks
+ tasks = []
+ for h in range(0, H, stride_h):
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+ for w in range(0, W, stride_w):
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+ h_, w_ = h + size_h, w + size_w
+ tasks.append((h, h_, w, w_))
+
+ data_device = "cpu"
+ computation_device = device
+
+ out_T = T * 4 - 3
+ weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
+ values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
+
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
+ hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
+
+ mask = self.build_mask(
+ hidden_states_batch,
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
+ border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
+ ).to(dtype=hidden_states.dtype, device=data_device)
+
+ target_h = h * self.upsampling_factor
+ target_w = w * self.upsampling_factor
+ values[
+ :,
+ :,
+ :,
+ target_h:target_h + hidden_states_batch.shape[3],
+ target_w:target_w + hidden_states_batch.shape[4],
+ ] += hidden_states_batch * mask
+ weight[
+ :,
+ :,
+ :,
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += mask
+ values = values / weight
+ values = values.clamp_(-1, 1)
+ return values
+
+
+ def tiled_encode(self, video, device, tile_size, tile_stride):
+ _, _, T, H, W = video.shape
+ size_h, size_w = tile_size
+ stride_h, stride_w = tile_stride
+
+ # Split tasks
+ tasks = []
+ for h in range(0, H, stride_h):
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+ for w in range(0, W, stride_w):
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+ h_, w_ = h + size_h, w + size_w
+ tasks.append((h, h_, w, w_))
+
+ data_device = "cpu"
+ computation_device = device
+
+ out_T = (T + 3) // 4
+ weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
+ values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
+
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
+ hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
+ hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
+
+ mask = self.build_mask(
+ hidden_states_batch,
+ is_bound=(h==0, h_>=H, w==0, w_>=W),
+ border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
+ ).to(dtype=video.dtype, device=data_device)
+
+ target_h = h // self.upsampling_factor
+ target_w = w // self.upsampling_factor
+ values[
+ :,
+ :,
+ :,
+ target_h:target_h + hidden_states_batch.shape[3],
+ target_w:target_w + hidden_states_batch.shape[4],
+ ] += hidden_states_batch * mask
+ weight[
+ :,
+ :,
+ :,
+ target_h: target_h + hidden_states_batch.shape[3],
+ target_w: target_w + hidden_states_batch.shape[4],
+ ] += mask
+ values = values / weight
+ return values
+
+
+ def single_encode(self, video, device):
+ video = video.to(device)
+ x = self.model.encode(video, self.scale)
+ return x
+
+
+ def single_decode(self, hidden_state, device):
+ hidden_state = hidden_state.to(device)
+ video = self.model.decode(hidden_state, self.scale)
+ return video.clamp_(-1, 1)
+
+
+ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
+
+ videos = [video.to("cpu") for video in videos]
+ hidden_states = []
+ for video in videos:
+ video = video.unsqueeze(0)
+ if tiled:
+                # tile size / stride are given in latent units; convert to pixel units in
+                # local variables so the scaling is not applied again for later videos
+                tile_size_pixel = (tile_size[0] * 8, tile_size[1] * 8)
+                tile_stride_pixel = (tile_stride[0] * 8, tile_stride[1] * 8)
+                hidden_state = self.tiled_encode(video, device, tile_size_pixel, tile_stride_pixel)
+ else:
+ hidden_state = self.single_encode(video, device)
+ hidden_state = hidden_state.squeeze(0)
+ hidden_states.append(hidden_state)
+ hidden_states = torch.stack(hidden_states)
+ return hidden_states
+
+
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
+ hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
+ videos = []
+ for hidden_state in hidden_states:
+ hidden_state = hidden_state.unsqueeze(0)
+ if tiled:
+ video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
+ else:
+ video = self.single_decode(hidden_state, device)
+ video = video.squeeze(0)
+ videos.append(video)
+ videos = torch.stack(videos)
+ return videos
+
+
+ @staticmethod
+ def state_dict_converter():
+ return WanVideoVAEStateDictConverter()
+
+
+class WanVideoVAEStateDictConverter:
+
+ def __init__(self):
+ pass
+
+ def from_civitai(self, state_dict):
+ state_dict_ = {}
+ if 'model_state' in state_dict:
+ state_dict = state_dict['model_state']
+ for name in state_dict:
+ state_dict_['model.' + name] = state_dict[name]
+ return state_dict_
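
As a quick orientation for the WanVideoVAE interface added above, the sketch below runs a shape round trip with randomly initialised weights. It is a minimal illustration only: the device, resolution, and frame count are arbitrary assumptions, and no real checkpoint is loaded (weight loading would go through state_dict_converter().from_civitai(...) as defined above).

    import torch

    # Minimal shape round trip with random weights; sizes and device are illustrative assumptions.
    vae = WanVideoVAE(z_dim=16).to("cuda")

    # One video of shape (3, T, H, W) in [-1, 1]; with T = 81 the latent has (81 + 3) // 4 = 21 frames
    # and the spatial grid shrinks by the upsampling factor of 8.
    video = torch.rand(3, 81, 480, 832) * 2 - 1

    # encode/decode take a list of per-sample tensors, run tile by tile on `device`,
    # and blend/accumulate the result on CPU; tile_size / tile_stride are in latent units.
    latents = vae.encode([video], device="cuda", tiled=True)  # -> (1, 16, 21, 60, 104)
    frames = vae.decode(latents, device="cuda", tiled=True)   # -> (1, 3, 81, 480, 832)
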
diff --git a/PusaV1/diffsynth/pipelines/__init__.py b/PusaV1/diffsynth/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..29bfdc8520923c9d61b722bc3fe093e52d6d1038
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/__init__.py
@@ -0,0 +1,19 @@
+from .sd_image import SDImagePipeline
+from .sd_video import SDVideoPipeline
+from .sdxl_image import SDXLImagePipeline
+from .sdxl_video import SDXLVideoPipeline
+from .sd3_image import SD3ImagePipeline
+from .hunyuan_image import HunyuanDiTImagePipeline
+from .svd_video import SVDVideoPipeline
+from .flux_image import FluxImagePipeline
+from .cog_video import CogVideoPipeline
+from .omnigen_image import OmnigenImagePipeline
+from .pipeline_runner import SDVideoPipelineRunner
+from .hunyuan_video import HunyuanVideoPipeline
+from .step_video import StepVideoPipeline
+from .wan_video import WanVideoPipeline
+from .wan_video_pusa import WanVideoPusaPipeline
+from .wan_video_pusa_multi_frames import PusaMultiFramesPipeline
+from .wan_video_pusa_v2v import PusaV2VPipeline
+
+KolorsImagePipeline = SDXLImagePipeline
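
The re-exports above make the new Pusa pipelines importable from the package root of diffsynth.pipelines. A one-line orientation sketch, assuming PusaV1 is on sys.path so that diffsynth resolves to this vendored copy:

    from diffsynth.pipelines import WanVideoPusaPipeline, PusaMultiFramesPipeline, PusaV2VPipeline
    from diffsynth.pipelines import KolorsImagePipeline  # alias of SDXLImagePipeline, as defined above
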
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7fc00cd8fe5a8670fff5cecf70939fdab9ef92b3
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f64ab2397361b8d73edfc4c3ea364aa47bfdb98
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/base.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c65924bb13b3f2eee1bc0e72aa5738e6a2d5c403
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/base.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/base.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e712a18bea7875651cd2f00f1ac51011d4a5292b
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/base.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/cog_video.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/cog_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f64abe6342d6dfbb339063805a54f6f8e027b7c0
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/cog_video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/cog_video.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/cog_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5879fbe4d3f5f9bbf94bc4f1cf56267afa67cc2c
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/cog_video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/dancer.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/dancer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe30b2a00b428717197093f62741afce1e2088ea
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/dancer.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/dancer.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/dancer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cbdde3636fc691346d3f88e21b4fcefecdbaf8f
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/dancer.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/flux_image.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/flux_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ac7a3f865698051de80061202c11e6d78fa3247
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/flux_image.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/flux_image.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/flux_image.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25331024494b67ac6534148f80f1a8ab07d52d91
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/flux_image.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59e781e575308de7efa114ef1ac842c350690e55
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab1407e1e35e4fdcc4514fd4301b3e6b9bdee109
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_image.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..527c2c0433e8c7007086c025c0a77eec73c4adfc
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..862df276b5e9205e3118d15ec3a5471828c01353
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/hunyuan_video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/omnigen_image.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/omnigen_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f03435ae128b55e06ff3f024b7a1ee419a86ce6b
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/omnigen_image.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/omnigen_image.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/omnigen_image.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6880cb323b289e47df11f13b6966bd076e3c224b
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/omnigen_image.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1d6f62e7e4583e4aced200dd8d817eb1e09a102
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b612717abb64a6448ea89abfd1f3cb35ac3dd74
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/pipeline_runner.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sd3_image.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sd3_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1288ad77345fb7949ad72b1d5fec97a1448abdc
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sd3_image.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sd3_image.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sd3_image.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..294b146d0f3139bd8ba03db01aea58ad274976e8
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sd3_image.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sd_image.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sd_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ad6e1719a874c8fbc98d9f2d02389107e9339bd
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sd_image.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sd_image.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sd_image.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..755be8498fd8f93bdfeaa2b9c93c0cf17295b1a5
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sd_image.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sd_video.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sd_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..263ac7264f5db5ef89e80f5a4b67ee2a6ac1782f
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sd_video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sd_video.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sd_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c17bea6605cc2000f913f8eab1a35b65be6a8187
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sd_video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sdxl_image.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9055e9d6fe6c13ffb4b4ca48529c0cc91fa60f5b
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_image.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sdxl_image.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_image.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72cd748c68d0ad973a8ceca21100ece2422bd4a8
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_image.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sdxl_video.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef904354130e8101c649e73391e6b4f272e2ee0d
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/sdxl_video.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8599b0f54f085eb5494a2288ed5492c193b774bc
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/sdxl_video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/step_video.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/step_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6782eb1714dbf10775e8d455c37a382a4312c28b
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/step_video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/step_video.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/step_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f89ef0da8e5d1bbebcbb6ef1994f0a6d09de1bfe
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/step_video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/svd_video.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/svd_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73b926a8dbe26f56e7d6a23dba6c27094b757b82
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/svd_video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/svd_video.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/svd_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3affafa0cd5eb3f29d4765bc6efb0a17d48f5c5f
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/svd_video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/wan_video.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/wan_video.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e40c38e811de7eaebde1ec059ce50f4f6b4ec115
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/wan_video.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/wan_video.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/wan_video.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ea71b2fff98dcdd0e6ee7aa301b41a4419127e5
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/wan_video.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4943b2af0227c2abe28ba9c2ed874044d1059a59
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa.cpython-312.pyc b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecd179adceee3c9d1b618f2866857bcbfaad20b1
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa_multi_frames.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa_multi_frames.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44d4b58d40d2f8f24bfabeb8d33fa7d5ad1816cd
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa_multi_frames.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa_v2v.cpython-310.pyc b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa_v2v.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93457a3a0fd9ef5bde3a569bb6ff94b31da1042e
Binary files /dev/null and b/PusaV1/diffsynth/pipelines/__pycache__/wan_video_pusa_v2v.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/pipelines/base.py b/PusaV1/diffsynth/pipelines/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a4f01cff55dc0fcca02dc5234227bd65efc7434
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/base.py
@@ -0,0 +1,127 @@
+import torch
+import numpy as np
+from PIL import Image
+from torchvision.transforms import GaussianBlur
+
+
+
+class BasePipeline(torch.nn.Module):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
+ super().__init__()
+ self.device = device
+ self.torch_dtype = torch_dtype
+ self.height_division_factor = height_division_factor
+ self.width_division_factor = width_division_factor
+ self.cpu_offload = False
+ self.model_names = []
+
+
+ def check_resize_height_width(self, height, width):
+ if height % self.height_division_factor != 0:
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
+ print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
+ if width % self.width_division_factor != 0:
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
+ print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
+ return height, width
+
+
+ def preprocess_image(self, image):
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
+ return image
+
+
+ def preprocess_images(self, images):
+ return [self.preprocess_image(image) for image in images]
+
+
+ def vae_output_to_image(self, vae_output):
+ image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+ return image
+
+
+ def vae_output_to_video(self, vae_output):
+ video = vae_output.cpu().permute(1, 2, 0).numpy()
+ video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
+ return video
+
+
+ def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
+ if len(latents) > 0:
+ blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
+ height, width = value.shape[-2:]
+ weight = torch.ones_like(value)
+ for latent, mask, scale in zip(latents, masks, scales):
+ mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
+ mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
+ mask = blur(mask)
+ value += latent * mask * scale
+ weight += mask * scale
+ value /= weight
+ return value
+
+
+ def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
+ if special_kwargs is None:
+ noise_pred_global = inference_callback(prompt_emb_global)
+ else:
+ noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
+ if special_local_kwargs_list is None:
+ noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
+ else:
+ noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
+ noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
+ return noise_pred
+
+
+ def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
+ local_prompts = local_prompts or []
+ masks = masks or []
+ mask_scales = mask_scales or []
+ extended_prompt_dict = self.prompter.extend_prompt(prompt)
+ prompt = extended_prompt_dict.get("prompt", prompt)
+ local_prompts += extended_prompt_dict.get("prompts", [])
+ masks += extended_prompt_dict.get("masks", [])
+ mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
+ return prompt, local_prompts, masks, mask_scales
+
+
+ def enable_cpu_offload(self):
+ self.cpu_offload = True
+
+
+ def load_models_to_device(self, loadmodel_names=[]):
+ # only load models to device if cpu_offload is enabled
+ if not self.cpu_offload:
+ return
+ # offload the unneeded models to cpu
+ for model_name in self.model_names:
+ if model_name not in loadmodel_names:
+ model = getattr(self, model_name)
+ if model is not None:
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+ for module in model.modules():
+ if hasattr(module, "offload"):
+ module.offload()
+ else:
+ model.cpu()
+ # load the needed models to device
+ for model_name in loadmodel_names:
+ model = getattr(self, model_name)
+ if model is not None:
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+ for module in model.modules():
+ if hasattr(module, "onload"):
+ module.onload()
+ else:
+ model.to(self.device)
+        # free the CUDA cache
+ torch.cuda.empty_cache()
+
+
+ def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
+ generator = None if seed is None else torch.Generator(device).manual_seed(seed)
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ return noise
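
To make the helper behaviour in BasePipeline concrete, here is a small sketch exercising the resolution rounding, the image preprocessing round trip, and seeded noise generation. It assumes the vendored package is importable as diffsynth; the sizes are arbitrary.

    import torch
    from PIL import Image
    from diffsynth.pipelines.base import BasePipeline

    pipe = BasePipeline(device="cpu", torch_dtype=torch.float32)

    # Heights/widths are rounded up to the division factors (64 by default), with a printed notice.
    print(pipe.check_resize_height_width(500, 700))  # -> (512, 704)

    # preprocess_image maps a PIL image to a (1, 3, H, W) tensor in [-1, 1];
    # vae_output_to_image inverts that mapping for a single decoded sample.
    image = Image.new("RGB", (64, 64), (128, 64, 255))
    tensor = pipe.preprocess_image(image)            # shape (1, 3, 64, 64)
    restored = pipe.vae_output_to_image(tensor)      # back to a PIL.Image

    # Noise generation is reproducible when a seed is given.
    noise_a = pipe.generate_noise((1, 4, 64, 64), seed=42)
    noise_b = pipe.generate_noise((1, 4, 64, 64), seed=42)
    assert torch.equal(noise_a, noise_b)
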
diff --git a/PusaV1/diffsynth/pipelines/cog_video.py b/PusaV1/diffsynth/pipelines/cog_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42d295187e718617cc7d4e327067700f2a689fd
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/cog_video.py
@@ -0,0 +1,135 @@
+from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder
+from ..prompters import CogPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange
+
+
+
+class CogVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
+ self.prompter = CogPrompter()
+ # models
+ self.text_encoder: FluxTextEncoder2 = None
+ self.dit: CogDiT = None
+ self.vae_encoder: CogVAEEncoder = None
+ self.vae_decoder: CogVAEDecoder = None
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ self.text_encoder = model_manager.fetch_model("flux_text_encoder_2")
+ self.dit = model_manager.fetch_model("cog_dit")
+ self.vae_encoder = model_manager.fetch_model("cog_vae_encoder")
+ self.vae_decoder = model_manager.fetch_model("cog_vae_decoder")
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
+ pipe = CogVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype
+ )
+ pipe.fetch_models(model_manager, prompt_refiner_classes)
+ return pipe
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def encode_prompt(self, prompt, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
+ return {"prompt_emb": prompt_emb}
+
+
+ def prepare_extra_input(self, latents):
+ return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_video=None,
+ cfg_scale=7.0,
+ denoising_strength=1.0,
+ num_frames=49,
+ height=480,
+ width=720,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=(60, 90),
+ tile_stride=(30, 45),
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+
+ # Prepare latent tensors
+ noise = self.generate_noise((1, 16, num_frames // 4 + 1, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
+
+ if denoising_strength == 1.0:
+ latents = noise.clone()
+ else:
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2)
+ latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
+ if not tiled: latents = latents.to(self.device)
+
+ # Encode prompt
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ noise_pred_posi = self.dit(
+ latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
+ )
+ if cfg_scale != 1.0:
+ noise_pred_nega = self.dit(
+ latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # DDIM
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Update progress bar
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode video
+ video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd)
+ video = self.tensor2video(video[0])
+
+ return video
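
For context, CogVideoPipeline is normally constructed from a ModelManager (declared in ..models) and then called directly. The sketch below is a hedged illustration: the ModelManager constructor arguments, the load_models call, and the checkpoint paths are assumptions about the surrounding package and are not defined in this diff.

    import torch
    from diffsynth.models import ModelManager
    from diffsynth.pipelines import CogVideoPipeline

    # Hypothetical checkpoint locations; adjust to wherever the CogVideoX weights actually live.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
    model_manager.load_models([
        "models/CogVideoX-5b/text_encoder",
        "models/CogVideoX-5b/transformer",
        "models/CogVideoX-5b/vae",
    ])
    pipe = CogVideoPipeline.from_model_manager(model_manager)

    frames = pipe(
        prompt="a panda playing guitar in a bamboo forest",
        num_frames=49, height=480, width=720,
        cfg_scale=7.0, num_inference_steps=50, seed=0,
    )  # returns a list of PIL.Image frames
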
diff --git a/PusaV1/diffsynth/pipelines/dancer.py b/PusaV1/diffsynth/pipelines/dancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..593b57c8363f94e312debf7c7f69bf6decdb7dbd
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/dancer.py
@@ -0,0 +1,236 @@
+import torch
+from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
+from ..models.sd_unet import PushBlock, PopBlock
+from ..controlnets import MultiControlNetManager
+
+
+def lets_dance(
+ unet: SDUNet,
+ motion_modules: SDMotionModel = None,
+ controlnet: MultiControlNetManager = None,
+ sample = None,
+ timestep = None,
+ encoder_hidden_states = None,
+ ipadapter_kwargs_list = {},
+ controlnet_frames = None,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ device = "cuda",
+ vram_limit_level = 0,
+):
+ # 0. Text embedding alignment (only for video processing)
+ if encoder_hidden_states.shape[0] != sample.shape[0]:
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
+
+ # 1. ControlNet
+ # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
+ # I leave it here because I intend to do something interesting on the ControlNets.
+ controlnet_insert_block_id = 30
+ if controlnet is not None and controlnet_frames is not None:
+ res_stacks = []
+ # process controlnet frames with batch
+ for batch_id in range(0, sample.shape[0], controlnet_batch_size):
+ batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
+ res_stack = controlnet(
+ sample[batch_id: batch_id_],
+ timestep,
+ encoder_hidden_states[batch_id: batch_id_],
+ controlnet_frames[:, batch_id: batch_id_],
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+ )
+ if vram_limit_level >= 1:
+ res_stack = [res.cpu() for res in res_stack]
+ res_stacks.append(res_stack)
+ # concat the residual
+ additional_res_stack = []
+ for i in range(len(res_stacks[0])):
+ res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
+ additional_res_stack.append(res)
+ else:
+ additional_res_stack = None
+
+ # 2. time
+ time_emb = unet.time_proj(timestep).to(sample.dtype)
+ time_emb = unet.time_embedding(time_emb)
+
+ # 3. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = unet.conv_in(sample)
+ text_emb = encoder_hidden_states
+ res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
+
+ # 4. blocks
+ for block_id, block in enumerate(unet.blocks):
+ # 4.1 UNet
+ if isinstance(block, PushBlock):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].cpu()
+ elif isinstance(block, PopBlock):
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].to(device)
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ else:
+ hidden_states_input = hidden_states
+ hidden_states_output = []
+ for batch_id in range(0, sample.shape[0], unet_batch_size):
+ batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
+ hidden_states, _, _, _ = block(
+ hidden_states_input[batch_id: batch_id_],
+ time_emb,
+ text_emb[batch_id: batch_id_],
+ res_stack,
+ cross_frame_attention=cross_frame_attention,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+ )
+ hidden_states_output.append(hidden_states)
+ hidden_states = torch.concat(hidden_states_output, dim=0)
+ # 4.2 AnimateDiff
+ if motion_modules is not None:
+ if block_id in motion_modules.call_block_id:
+ motion_module_id = motion_modules.call_block_id[block_id]
+ hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
+ hidden_states, time_emb, text_emb, res_stack,
+ batch_size=1
+ )
+ # 4.3 ControlNet
+ if block_id == controlnet_insert_block_id and additional_res_stack is not None:
+ hidden_states += additional_res_stack.pop().to(device)
+ if vram_limit_level>=1:
+ res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
+ else:
+ res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
+
+ # 5. output
+ hidden_states = unet.conv_norm_out(hidden_states)
+ hidden_states = unet.conv_act(hidden_states)
+ hidden_states = unet.conv_out(hidden_states)
+
+ return hidden_states
+
+
+
+
+def lets_dance_xl(
+ unet: SDXLUNet,
+ motion_modules: SDXLMotionModel = None,
+ controlnet: MultiControlNetManager = None,
+ sample = None,
+ add_time_id = None,
+ add_text_embeds = None,
+ timestep = None,
+ encoder_hidden_states = None,
+ ipadapter_kwargs_list = {},
+ controlnet_frames = None,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ device = "cuda",
+ vram_limit_level = 0,
+):
+ # 0. Text embedding alignment (only for video processing)
+ if encoder_hidden_states.shape[0] != sample.shape[0]:
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
+ if add_text_embeds.shape[0] != sample.shape[0]:
+ add_text_embeds = add_text_embeds.repeat(sample.shape[0], 1)
+
+ # 1. ControlNet
+ controlnet_insert_block_id = 22
+ if controlnet is not None and controlnet_frames is not None:
+ res_stacks = []
+ # process controlnet frames with batch
+ for batch_id in range(0, sample.shape[0], controlnet_batch_size):
+ batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
+ res_stack = controlnet(
+ sample[batch_id: batch_id_],
+ timestep,
+ encoder_hidden_states[batch_id: batch_id_],
+ controlnet_frames[:, batch_id: batch_id_],
+ add_time_id=add_time_id,
+ add_text_embeds=add_text_embeds,
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+ unet=unet, # for Kolors, some modules in ControlNets will be replaced.
+ )
+ if vram_limit_level >= 1:
+ res_stack = [res.cpu() for res in res_stack]
+ res_stacks.append(res_stack)
+ # concat the residual
+ additional_res_stack = []
+ for i in range(len(res_stacks[0])):
+ res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
+ additional_res_stack.append(res)
+ else:
+ additional_res_stack = None
+
+ # 2. time
+ t_emb = unet.time_proj(timestep).to(sample.dtype)
+ t_emb = unet.time_embedding(t_emb)
+
+ time_embeds = unet.add_time_proj(add_time_id)
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
+ add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(sample.dtype)
+ add_embeds = unet.add_time_embedding(add_embeds)
+
+ time_emb = t_emb + add_embeds
+
+ # 3. pre-process
+ height, width = sample.shape[2], sample.shape[3]
+ hidden_states = unet.conv_in(sample)
+ text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
+ res_stack = [hidden_states]
+
+ # 4. blocks
+ for block_id, block in enumerate(unet.blocks):
+ # 4.1 UNet
+ if isinstance(block, PushBlock):
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].cpu()
+ elif isinstance(block, PopBlock):
+ if vram_limit_level>=1:
+ res_stack[-1] = res_stack[-1].to(device)
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
+ else:
+ hidden_states_input = hidden_states
+ hidden_states_output = []
+ for batch_id in range(0, sample.shape[0], unet_batch_size):
+ batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
+ hidden_states, _, _, _ = block(
+ hidden_states_input[batch_id: batch_id_],
+ time_emb[batch_id: batch_id_],
+ text_emb[batch_id: batch_id_],
+ res_stack,
+ cross_frame_attention=cross_frame_attention,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+ )
+ hidden_states_output.append(hidden_states)
+ hidden_states = torch.concat(hidden_states_output, dim=0)
+ # 4.2 AnimateDiff
+ if motion_modules is not None:
+ if block_id in motion_modules.call_block_id:
+ motion_module_id = motion_modules.call_block_id[block_id]
+ hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
+ hidden_states, time_emb, text_emb, res_stack,
+ batch_size=1
+ )
+ # 4.3 ControlNet
+ if block_id == controlnet_insert_block_id and additional_res_stack is not None:
+ hidden_states += additional_res_stack.pop().to(device)
+ res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
+
+ # 5. output
+ hidden_states = unet.conv_norm_out(hidden_states)
+ hidden_states = unet.conv_act(hidden_states)
+ hidden_states = unet.conv_out(hidden_states)
+
+ return hidden_states
\ No newline at end of file
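
Both lets_dance and lets_dance_xl rely on the same memory-saving trick: the frame axis is pushed through each block in chunks of unet_batch_size and the outputs are re-concatenated, so only a few frames are resident in the block at once. A self-contained toy version of that loop follows (the nn.Linear merely stands in for a UNet block):

    import torch

    block = torch.nn.Linear(8, 8)   # stand-in for a UNet block
    sample = torch.randn(16, 8)     # 16 "frames"
    unet_batch_size = 4

    outputs = []
    for batch_id in range(0, sample.shape[0], unet_batch_size):
        batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
        outputs.append(block(sample[batch_id: batch_id_]))
    hidden_states = torch.concat(outputs, dim=0)

    # Chunked execution matches processing all frames at once.
    assert torch.allclose(hidden_states, block(sample), atol=1e-6)
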
diff --git a/PusaV1/diffsynth/pipelines/flux_image.py b/PusaV1/diffsynth/pipelines/flux_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0729fc5470d26c7099a498a28c33a550a94de12
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/flux_image.py
@@ -0,0 +1,722 @@
+from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
+from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import FluxPrompter
+from ..schedulers import FlowMatchScheduler
+from .base import BasePipeline
+from typing import List
+import torch
+from tqdm import tqdm
+import numpy as np
+from PIL import Image
+from ..models.tiler import FastTileWorker
+from transformers import SiglipVisionModel
+from copy import deepcopy
+from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
+from ..models.flux_dit import RMSNorm
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+
+
+class FluxImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = FlowMatchScheduler()
+ self.prompter = FluxPrompter()
+ # models
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: FluxTextEncoder2 = None
+ self.dit: FluxDiT = None
+ self.vae_decoder: FluxVAEDecoder = None
+ self.vae_encoder: FluxVAEEncoder = None
+ self.controlnet: FluxMultiControlNetManager = None
+ self.ipadapter: FluxIpAdapter = None
+ self.ipadapter_image_encoder: SiglipVisionModel = None
+ self.infinityou_processor: InfinitYou = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
+ dtype = next(iter(self.text_encoder_1.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_1,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.text_encoder_2.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_2,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ T5DenseActDense: AutoWrappedModule,
+ T5DenseGatedActDense: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cuda",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae_decoder.parameters())).dtype
+ enable_vram_management(
+ self.vae_decoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.GroupNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae_encoder.parameters())).dtype
+ enable_vram_management(
+ self.vae_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.GroupNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[]):
+ self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
+ self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2")
+ self.dit = model_manager.fetch_model("flux_dit")
+ self.vae_decoder = model_manager.fetch_model("flux_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("flux_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+ self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes)
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device, skip_processor=config.skip_processor),
+ model_manager.fetch_model("flux_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = FluxMultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("flux_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
+
+ # InfiniteYou
+ self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
+ if self.image_proj_model is not None:
+ self.infinityou_processor = InfinitYou(device=self.device)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
+ pipe = FluxImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
+ prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
+ prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
+ )
+ return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
+
+
+ def prepare_extra_input(self, latents=None, guidance=1.0):
+ latent_image_ids = self.dit.prepare_image_ids(latents)
+ guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
+ return {"image_ids": latent_image_ids, "guidance": guidance}
+
+
+ def apply_controlnet_mask_on_latents(self, latents, mask):
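+        # Downsample the inpaint mask to the latent resolution, invert it, and append it
+        # to the latents as an extra channel for the inpaint ControlNet.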
+ mask = (self.preprocess_image(mask) + 1) / 2
+ mask = mask.mean(dim=1, keepdim=True)
+ mask = mask.to(dtype=self.torch_dtype, device=self.device)
+ mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
+ latents = torch.concat([latents, mask], dim=1)
+ return latents
+
+
+ def apply_controlnet_mask_on_image(self, image, mask):
+ mask = mask.resize(image.size)
+ mask = self.preprocess_image(mask).mean(dim=[0, 1])
+ image = np.array(image)
+ image[mask > 0] = 0
+ image = Image.fromarray(image)
+ return image
+
+
+ def prepare_controlnet_input(self, controlnet_image, controlnet_inpaint_mask, tiler_kwargs):
+ if isinstance(controlnet_image, Image.Image):
+ controlnet_image = [controlnet_image] * len(self.controlnet.processors)
+
+ controlnet_frames = []
+ for i in range(len(self.controlnet.processors)):
+ # image annotator
+ image = self.controlnet.process_image(controlnet_image[i], processor_id=i)[0]
+ if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
+ image = self.apply_controlnet_mask_on_image(image, controlnet_inpaint_mask)
+
+ # image to tensor
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+
+ # vae encoder
+ image = self.encode_image(image, **tiler_kwargs)
+ if controlnet_inpaint_mask is not None and self.controlnet.processors[i].processor_id == "inpaint":
+ image = self.apply_controlnet_mask_on_latents(image, controlnet_inpaint_mask)
+
+ # store it
+ controlnet_frames.append(image)
+ return controlnet_frames
+
+
+ def prepare_ipadapter_inputs(self, images, height=384, width=384):
+ images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
+ images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
+ return torch.cat(images, dim=0)
+
+
+ def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
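+        # Recover the prediction target that maps the clean inpaint latents to the current
+        # noisy latents (latents = inpaint_latents + sigma * pred), then keep the model's
+        # prediction inside the foreground mask and blend it into the background with background_weight.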
+ # inpaint noise
+ inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id]
+ # merge noise
+ weight = torch.ones_like(inpaint_noise)
+ inpaint_noise[fg_mask] = pred_noise[fg_mask]
+ inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight
+ weight[bg_mask] += background_weight
+ inpaint_noise /= weight
+ return inpaint_noise
+
+
+ def preprocess_masks(self, masks, height, width, dim):
+ out_masks = []
+ for mask in masks:
+ mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
+ mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+ out_masks.append(mask)
+ return out_masks
+
+
+ def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False):
+ fg_mask, bg_mask = None, None
+ if enable_eligen_inpaint:
+ masks_ = deepcopy(entity_masks)
+ fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_])
+ fg_masks = (fg_masks > 0).float()
+ fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
+ bg_mask = ~fg_mask
+ entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1)
+ entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
+ entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
+ return entity_prompts, entity_masks, fg_mask, bg_mask
+
+
+ def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride):
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ input_latents = None
+ return latents, input_latents
+
+
+ def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
+ if ipadapter_images is not None:
+ self.load_models_to_device(['ipadapter_image_encoder'])
+ ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
+ self.load_models_to_device(['ipadapter'])
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+ return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega
+
+
+ def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative):
+ if controlnet_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
+ if len(masks) > 0 and controlnet_inpaint_mask is not None:
+ print("The controlnet_inpaint_mask will be overridden by masks.")
+ local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
+ else:
+ local_controlnet_kwargs = None
+ else:
+ controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
+ controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {}
+ return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs
+
+
+ def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale):
+ if eligen_entity_masks is not None:
+ entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint)
+ if enable_eligen_on_negative and cfg_scale != 1.0:
+ entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1)
+ entity_masks_nega = entity_masks_posi
+ else:
+ entity_prompt_emb_nega, entity_masks_nega = None, None
+ else:
+ entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None
+ fg_mask, bg_mask = None, None
+ eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
+ eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
+ return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
+
+
+ def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
+ # Extend prompt
+ self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
+ prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
+
+ # Encode prompts
+ prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
+ prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
+ return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
+
+
+ def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
+ if self.infinityou_processor is not None and id_image is not None:
+ return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
+ else:
+ return {}, controlnet_image
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ # Prompt
+ prompt,
+ negative_prompt="",
+ cfg_scale=1.0,
+ embedded_guidance=3.5,
+ t5_sequence_length=512,
+ # Image
+ input_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ seed=None,
+ # Steps
+ num_inference_steps=30,
+ # local prompts
+ local_prompts=(),
+ masks=(),
+ mask_scales=(),
+ # ControlNet
+ controlnet_image=None,
+ controlnet_inpaint_mask=None,
+ enable_controlnet_on_negative=False,
+ # IP-Adapter
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ # EliGen
+ eligen_entity_prompts=None,
+ eligen_entity_masks=None,
+ enable_eligen_on_negative=False,
+ enable_eligen_inpaint=False,
+ # InfiniteYou
+ infinityou_id_image=None,
+ infinityou_guidance=1.0,
+ # TeaCache
+ tea_cache_l1_thresh=None,
+ # Tile
+ tiled=False,
+ tile_size=128,
+ tile_stride=64,
+ # Progress bar
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
+
+ # Prompt
+ prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
+
+ # InfiniteYou
+ infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(infinityou_id_image, controlnet_image, infinityou_guidance, height, width)
+
+ # Entity control
+ eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
+
+ # IP-Adapter
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
+
+ # ControlNets
+ controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
+
+ # TeaCache
+ tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
+
+ # Denoise
+ self.load_models_to_device(['dit', 'controlnet'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Positive side
+ inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
+ dit=self.dit, controlnet=self.controlnet,
+ hidden_states=latents, timestep=timestep,
+ **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(
+ prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
+ special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
+ )
+
+ # Inpaint
+ if enable_eligen_inpaint:
+ noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
+
+ # Classifier-free guidance
+ if cfg_scale != 1.0:
+ # Negative side
+ noise_pred_nega = lets_dance_flux(
+ dit=self.dit, controlnet=self.controlnet,
+ hidden_states=latents, timestep=timestep,
+ **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Iterate
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, **tiler_kwargs)
+
+ # Offload all models
+ self.load_models_to_device([])
+ return image
+
+
+
+class InfinitYou:
+ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
+ from facexlib.recognition import init_recognition_model
+ from insightface.app import FaceAnalysis
+ self.device = device
+ self.torch_dtype = torch_dtype
+ insightface_root_path = 'models/InfiniteYou/insightface'
+ self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_640.prepare(ctx_id=0, det_size=(640, 640))
+ self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_320.prepare(ctx_id=0, det_size=(320, 320))
+ self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ self.app_160.prepare(ctx_id=0, det_size=(160, 160))
+ self.arcface_model = init_recognition_model('arcface', device=self.device)
+
+ def _detect_face(self, id_image_cv2):
+ face_info = self.app_640.get(id_image_cv2)
+ if len(face_info) > 0:
+ return face_info
+ face_info = self.app_320.get(id_image_cv2)
+ if len(face_info) > 0:
+ return face_info
+ face_info = self.app_160.get(id_image_cv2)
+ return face_info
+
+ def extract_arcface_bgr_embedding(self, in_image, landmark):
+ from insightface.utils import face_align
+ arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
+ arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
+ arc_face_image = 2 * arc_face_image - 1
+ arc_face_image = arc_face_image.contiguous().to(self.device)
+ face_emb = self.arcface_model(arc_face_image)[0] # [512], normalized
+ return face_emb
+
+ def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
+ import cv2
+ if id_image is None:
+ return {'id_emb': None}, controlnet_image
+ id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
+ face_info = self._detect_face(id_image_cv2)
+ if len(face_info) == 0:
+ raise ValueError('No face detected in the input ID image')
+        landmark = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]['kps'] # use only the largest detected face
+ id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
+ id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
+ if controlnet_image is None:
+ controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
+ infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
+ return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
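+    # check() decides whether the DiT blocks can be skipped for this step: it tracks a
+    # rescaled relative L1 distance of the modulated input and, while the accumulated
+    # value stays below rel_l1_thresh, reuses the cached residual from the last full pass.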
+ def check(self, dit: FluxDiT, hidden_states, conditioning):
+ inp = hidden_states.clone()
+ temb_ = conditioning.clone()
+ modulated_inp, _, _, _, _ = dit.blocks[0].norm1_a(inp, emb=temb_)
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = hidden_states.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+def lets_dance_flux(
+ dit: FluxDiT,
+ controlnet: FluxMultiControlNetManager = None,
+ hidden_states=None,
+ timestep=None,
+ prompt_emb=None,
+ pooled_prompt_emb=None,
+ guidance=None,
+ text_ids=None,
+ image_ids=None,
+ controlnet_frames=None,
+ tiled=False,
+ tile_size=128,
+ tile_stride=64,
+ entity_prompt_emb=None,
+ entity_masks=None,
+ ipadapter_kwargs_list={},
+ id_emb=None,
+ infinityou_guidance=None,
+ tea_cache: TeaCache = None,
+ **kwargs
+):
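+    # Tiled mode: split the latent grid into overlapping spatial tiles, run the
+    # non-tiled path on each tile (tiled=False in the recursive call), and let
+    # FastTileWorker blend the per-tile outputs back together.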
+ if tiled:
+ def flux_forward_fn(hl, hr, wl, wr):
+ tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
+ return lets_dance_flux(
+ dit=dit,
+ controlnet=controlnet,
+ hidden_states=hidden_states[:, :, hl: hr, wl: wr],
+ timestep=timestep,
+ prompt_emb=prompt_emb,
+ pooled_prompt_emb=pooled_prompt_emb,
+ guidance=guidance,
+ text_ids=text_ids,
+ image_ids=None,
+ controlnet_frames=tiled_controlnet_frames,
+ tiled=False,
+ **kwargs
+ )
+ return FastTileWorker().tiled_forward(
+ flux_forward_fn,
+ hidden_states,
+ tile_size=tile_size,
+ tile_stride=tile_stride,
+ tile_device=hidden_states.device,
+ tile_dtype=hidden_states.dtype
+ )
+
+
+ # ControlNet
+ if controlnet is not None and controlnet_frames is not None:
+ controlnet_extra_kwargs = {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "prompt_emb": prompt_emb,
+ "pooled_prompt_emb": pooled_prompt_emb,
+ "guidance": guidance,
+ "text_ids": text_ids,
+ "image_ids": image_ids,
+ "tiled": tiled,
+ "tile_size": tile_size,
+ "tile_stride": tile_stride,
+ }
+ if id_emb is not None:
+ controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
+ controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': infinityou_guidance})
+ controlnet_res_stack, controlnet_single_res_stack = controlnet(
+ controlnet_frames, **controlnet_extra_kwargs
+ )
+
+ if image_ids is None:
+ image_ids = dit.prepare_image_ids(hidden_states)
+
+ conditioning = dit.time_embedder(timestep, hidden_states.dtype) + dit.pooled_text_embedder(pooled_prompt_emb)
+ if dit.guidance_embedder is not None:
+ guidance = guidance * 1000
+ conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
+
+ height, width = hidden_states.shape[-2:]
+ hidden_states = dit.patchify(hidden_states)
+ hidden_states = dit.x_embedder(hidden_states)
+
+ if entity_prompt_emb is not None and entity_masks is not None:
+ prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
+ else:
+ prompt_emb = dit.context_embedder(prompt_emb)
+ image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+ attention_mask = None
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
+ else:
+ tea_cache_update = False
+
+ if tea_cache_update:
+ hidden_states = tea_cache.update(hidden_states)
+ else:
+ # Joint Blocks
+ for block_id, block in enumerate(dit.blocks):
+ hidden_states, prompt_emb = block(
+ hidden_states,
+ prompt_emb,
+ conditioning,
+ image_rotary_emb,
+ attention_mask,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
+ )
+ # ControlNet
+ if controlnet is not None and controlnet_frames is not None:
+ hidden_states = hidden_states + controlnet_res_stack[block_id]
+
+ # Single Blocks
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+ num_joint_blocks = len(dit.blocks)
+ for block_id, block in enumerate(dit.single_blocks):
+ hidden_states, prompt_emb = block(
+ hidden_states,
+ prompt_emb,
+ conditioning,
+ image_rotary_emb,
+ attention_mask,
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
+ )
+ # ControlNet
+ if controlnet is not None and controlnet_frames is not None:
+ hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
+ hidden_states = hidden_states[:, prompt_emb.shape[1]:]
+
+ if tea_cache is not None:
+ tea_cache.store(hidden_states)
+
+ hidden_states = dit.final_norm_out(hidden_states, conditioning)
+ hidden_states = dit.final_proj_out(hidden_states)
+ hidden_states = dit.unpatchify(hidden_states, height, width)
+
+ return hidden_states
diff --git a/PusaV1/diffsynth/pipelines/hunyuan_image.py b/PusaV1/diffsynth/pipelines/hunyuan_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6f6d5dedc6aac50b06a9f10701f7f8ab33117f
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/hunyuan_image.py
@@ -0,0 +1,288 @@
+from ..models.hunyuan_dit import HunyuanDiT
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models import ModelManager
+from ..prompters import HunyuanDiTPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+import numpy as np
+
+
+
+class ImageSizeManager:
+ def __init__(self):
+ pass
+
+
+ def _to_tuple(self, x):
+ if isinstance(x, int):
+ return x, x
+ else:
+ return x
+
+
+ def get_fill_resize_and_crop(self, src, tgt):
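+        # Scale the source aspect ratio to fit inside the target canvas and center it,
+        # returning the top-left and bottom-right corners of the resulting region.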
+ th, tw = self._to_tuple(tgt)
+ h, w = self._to_tuple(src)
+
+        tr = th / tw  # aspect ratio of the base resolution
+        r = h / w  # aspect ratio of the target resolution
+
+ # resize
+ if r > tr:
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+            resize_height = int(round(tw / w * h))  # resize the target resolution down according to the base resolution
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+ def get_meshgrid(self, start, *args):
+ if len(args) == 0:
+ # start is grid_size
+ num = self._to_tuple(start)
+ start = (0, 0)
+ stop = num
+ elif len(args) == 1:
+ # start is start, args[0] is stop, step is 1
+ start = self._to_tuple(start)
+ stop = self._to_tuple(args[0])
+ num = (stop[0] - start[0], stop[1] - start[1])
+ elif len(args) == 2:
+ # start is start, args[0] is stop, args[1] is num
+            start = self._to_tuple(start)  # top-left corner, e.g. 12,0
+            stop = self._to_tuple(args[0])  # bottom-right corner, e.g. 20,32
+            num = self._to_tuple(args[1])  # target size, e.g. 32,124
+ else:
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+        grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 interpolated values between 12 and 20, 124 values between 0 and 32
+ grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0) # [2, W, H]
+ return grid
+
+
+ def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
+ grid = self.get_meshgrid(start, *args) # [2, H, w]
+        grid = grid.reshape([2, 1, *grid.shape[1:]])  # returns a sampling grid whose resolution matches the target resolution
+ pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+ def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
+ assert embed_dim % 4 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
+ emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
+
+ if use_real:
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
+ return cos, sin
+ else:
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
+ return emb
+
+
+ def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
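+        # Standard 1D RoPE: frequencies 1/theta^(2i/dim) are outer-multiplied with the positions;
+        # with use_real=True the cos/sin tables are returned separately instead of complex phasors.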
+ if isinstance(pos, int):
+ pos = np.arange(pos)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+
+ def calc_rope(self, height, width):
+ patch_size = 2
+ head_size = 88
+ th = height // 8 // patch_size
+ tw = width // 8 // patch_size
+ base_size = 512 // 8 // patch_size
+ start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
+ sub_args = [start, stop, (th, tw)]
+ rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
+ return rope
+
+
+
+class HunyuanDiTImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
+ self.prompter = HunyuanDiTPrompter()
+ self.image_size_manager = ImageSizeManager()
+ # models
+ self.text_encoder: HunyuanDiTCLIPTextEncoder = None
+ self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
+ self.dit: HunyuanDiT = None
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ self.model_names = ['text_encoder', 'text_encoder_t5', 'dit', 'vae_decoder', 'vae_encoder']
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
+ self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
+ self.dit = model_manager.fetch_model("hunyuan_dit")
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
+ pipe = HunyuanDiTImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
+ text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
+ prompt,
+ clip_skip=clip_skip,
+ clip_skip_2=clip_skip_2,
+ positive=positive,
+ device=self.device
+ )
+ return {
+ "text_emb": text_emb,
+ "text_emb_mask": text_emb_mask,
+ "text_emb_t5": text_emb_t5,
+ "text_emb_mask_t5": text_emb_mask_t5
+ }
+
+
+ def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
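+        # Build the image-size conditioning vector and the 2D rotary position embeddings
+        # expected by HunyuanDiT for the current latent resolution.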
+ batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
+ if tiled:
+ height, width = tile_size * 16, tile_size * 16
+ image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
+ freqs_cis_img = self.image_size_manager.calc_rope(height, width)
+ image_meta_size = torch.stack([image_meta_size] * batch_size)
+ return {
+ "size_emb": image_meta_size,
+ "freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
+ "tiled": tiled,
+ "tile_size": tile_size,
+ "tile_stride": tile_stride
+ }
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ clip_skip_2=1,
+ input_image=None,
+ reference_strengths=[0.4],
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = noise.clone()
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder', 'text_encoder_t5'])
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
+ if cfg_scale != 1.0:
+            prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
+
+ # Prepare positional id
+        extra_input = self.prepare_extra_input(latents, tiled, tile_size, tile_stride)
+
+ # Denoise
+ self.load_models_to_device(['dit'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
+
+ # Positive side
+ inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+
+ if cfg_scale != 1.0:
+ # Negative side
+ noise_pred_nega = self.dit(
+ latents, timestep=timestep, **prompt_emb_nega, **extra_input,
+ )
+ # Classifier-free guidance
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # Offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/PusaV1/diffsynth/pipelines/hunyuan_video.py b/PusaV1/diffsynth/pipelines/hunyuan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8a0411e155f293e86a2b64073fa8b25af3d83d5
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/hunyuan_video.py
@@ -0,0 +1,395 @@
+from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
+from ..models.hunyuan_video_dit import HunyuanVideoDiT
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
+from ..schedulers.flow_match import FlowMatchScheduler
+from .base import BasePipeline
+from ..prompters import HunyuanVideoPrompter
+import torch
+import torchvision.transforms as transforms
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class HunyuanVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
+ self.prompter = HunyuanVideoPrompter()
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: HunyuanVideoLLMEncoder = None
+ self.dit: HunyuanVideoDiT = None
+ self.vae_decoder: HunyuanVideoVAEDecoder = None
+ self.vae_encoder: HunyuanVideoVAEEncoder = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
+ self.vram_management = False
+
+
+ def enable_vram_management(self):
+ self.vram_management = True
+ self.enable_cpu_offload()
+ self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
+ self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
+ self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
+ self.dit = model_manager.fetch_model("hunyuan_video_dit")
+ self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ if enable_vram_management:
+ pipe.enable_vram_management()
+ return pipe
+
+ def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
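+        # Enumerate (width, height) buckets, in multiples of patch_size, whose total patch count
+        # stays within (base_size / patch_size)^2 and whose aspect ratio stays within max_ratio.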
+ num_patches = round((base_size / patch_size)**2)
+ assert max_ratio >= 1.0
+ crop_size_list = []
+ wp, hp = num_patches, 1
+ while wp > 0:
+ if max(wp, hp) / min(wp, hp) <= max_ratio:
+ crop_size_list.append((wp * patch_size, hp * patch_size))
+ if (hp + 1) * wp <= num_patches:
+ hp += 1
+ else:
+ wp -= 1
+ return crop_size_list
+
+
+ def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
+ aspect_ratio = float(height) / float(width)
+ closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
+ closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
+ return buckets[closest_ratio_id], float(closest_ratio)
+
+
+ def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
+ if i2v_resolution == "720p":
+ bucket_hw_base_size = 960
+ elif i2v_resolution == "540p":
+ bucket_hw_base_size = 720
+ elif i2v_resolution == "360p":
+ bucket_hw_base_size = 480
+ else:
+ raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
+ origin_size = semantic_images[0].size
+
+ crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
+ aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
+ closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
+ ref_image_transform = transforms.Compose([
+ transforms.Resize(closest_size),
+ transforms.CenterCrop(closest_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5])
+ ])
+
+ semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
+ semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
+ target_height, target_width = closest_size
+ return semantic_image_pixel_values, target_height, target_width
+
+
+ def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
+ prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
+ prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length, images=input_images
+ )
+ return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
+
+
+ def prepare_extra_input(self, latents=None, guidance=1.0):
+ freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
+ guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
+ return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
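+        # Convert tile sizes from latent units to pixel/frame units: the causal video VAE
+        # compresses time by 4x (plus the first frame) and space by 8x.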
+ tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
+ tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
+ latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_video=None,
+ input_images=None,
+ i2v_resolution="720p",
+ i2v_stability=True,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device=None,
+ height=720,
+ width=1280,
+ num_frames=129,
+ embedded_guidance=6.0,
+ cfg_scale=1.0,
+ num_inference_steps=30,
+ tea_cache_l1_thresh=None,
+ tile_size=(17, 30, 30),
+ tile_stride=(12, 20, 20),
+ step_processor=None,
+ progress_bar_cmd=lambda x: x,
+ progress_bar_st=None,
+ ):
+ # Tiler parameters
+ tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+        # Encode input images
+ if input_images is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
+ with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
+ image_latents = self.vae_encoder(image_pixel_values)
+
+ # Initialize noise
+ rand_device = self.device if rand_device is None else rand_device
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
+ if input_video is not None:
+ self.load_models_to_device(['vae_encoder'])
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2)
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ elif input_images is not None and i2v_stability:
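+            # i2v stability: start from an almost-pure-noise mixture of the noise and the
+            # repeated first-frame latents (t = 0.999) instead of pure noise.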
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
+ t = torch.tensor([0.999]).to(device=self.device)
+ latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
+ latents = latents.to(dtype=image_latents.dtype)
+ else:
+ latents = noise
+
+ # Encode prompts
+ # current mllm does not support vram_management
+ self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
+
+ # TeaCache
+ tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
+
+ # Denoise
+ self.load_models_to_device([] if self.vram_management else ["dit"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+ print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
+
+ forward_func = lets_dance_hunyuan_video
+ if input_images is not None:
+ latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
+ forward_func = lets_dance_hunyuan_video_i2v
+
+ # Inference
+ with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
+ noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
+ if cfg_scale != 1.0:
+ noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # (Experimental feature, may be removed in the future)
+ if step_processor is not None:
+ self.load_models_to_device(['vae_decoder'])
+ rendered_frames = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents, to_final=True)
+ rendered_frames = self.vae_decoder.decode_video(rendered_frames, **tiler_kwargs)
+ rendered_frames = self.tensor2video(rendered_frames[0])
+ rendered_frames = step_processor(rendered_frames, original_frames=input_video)
+ self.load_models_to_device(['vae_encoder'])
+ rendered_frames = self.preprocess_images(rendered_frames)
+ rendered_frames = torch.stack(rendered_frames, dim=2)
+ target_latents = self.encode_video(rendered_frames).to(dtype=self.torch_dtype, device=self.device)
+ noise_pred = self.scheduler.return_to_timestep(self.scheduler.timesteps[progress_id], latents, target_latents)
+ self.load_models_to_device([] if self.vram_management else ["dit"])
+
+ # Scheduler
+ if input_images is not None:
+ latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
+ latents = torch.concat([image_latents, latents], dim=2)
+ else:
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Decode
+ self.load_models_to_device(['vae_decoder'])
+ frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
+
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ def check(self, dit: HunyuanVideoDiT, img, vec):
+ img_ = img.clone()
+ vec_ = vec.clone()
+ img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
+ normed_inp = dit.double_blocks[0].component_a.norm1(img_)
+ modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = img.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
+def lets_dance_hunyuan_video(
+ dit: HunyuanVideoDiT,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ prompt_emb: torch.Tensor = None,
+ text_mask: torch.Tensor = None,
+ pooled_prompt_emb: torch.Tensor = None,
+ freqs_cos: torch.Tensor = None,
+ freqs_sin: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ tea_cache: TeaCache = None,
+ **kwargs
+):
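+    # HunyuanVideo DiT forward: build the modulation vector from timestep, pooled text and
+    # guidance embeddings, patchify the video latents, run the double-stream (separate
+    # image/text) blocks, then the single-stream blocks on the concatenated sequence.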
+ B, C, T, H, W = x.shape
+
+ vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
+ img = dit.img_in(x)
+ txt = dit.txt_in(prompt_emb, t, text_mask)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, img, vec)
+ else:
+ tea_cache_update = False
+
+ if tea_cache_update:
+ print("TeaCache skip forward.")
+ img = tea_cache.update(img)
+ else:
+ split_token = int(text_mask.sum(dim=1))
+ txt_len = int(txt.shape[1])
+ for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
+ img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
+
+ x = torch.concat([img, txt], dim=1)
+ for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
+ x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
+ img = x[:, :-txt_len]
+
+ if tea_cache is not None:
+ tea_cache.store(img)
+ img = dit.final_layer(img, vec)
+ img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
+ return img
+
+
+def lets_dance_hunyuan_video_i2v(
+ dit: HunyuanVideoDiT,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ prompt_emb: torch.Tensor = None,
+ text_mask: torch.Tensor = None,
+ pooled_prompt_emb: torch.Tensor = None,
+ freqs_cos: torch.Tensor = None,
+ freqs_sin: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ tea_cache: TeaCache = None,
+ **kwargs
+):
+ B, C, T, H, W = x.shape
+    # Uncomment below to match the official implementation
+ # guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
+ vec = dit.time_in(t, dtype=torch.bfloat16)
+ vec_2 = dit.vector_in(pooled_prompt_emb)
+ vec = vec + vec_2
+ vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
+
+ token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
+ tr_token = (H // 2) * (W // 2)
+ token_replace_vec = token_replace_vec + vec_2
+
+ img = dit.img_in(x)
+ txt = dit.txt_in(prompt_emb, t, text_mask)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, img, vec)
+ else:
+ tea_cache_update = False
+
+ if tea_cache_update:
+ print("TeaCache skip forward.")
+ img = tea_cache.update(img)
+ else:
+ split_token = int(text_mask.sum(dim=1))
+ txt_len = int(txt.shape[1])
+ for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
+ img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
+
+ x = torch.concat([img, txt], dim=1)
+ for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
+ x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
+ img = x[:, :-txt_len]
+
+ if tea_cache is not None:
+ tea_cache.store(img)
+ img = dit.final_layer(img, vec)
+ img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
+ return img
diff --git a/PusaV1/diffsynth/pipelines/omnigen_image.py b/PusaV1/diffsynth/pipelines/omnigen_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb2ae656639550084b7143fe690186602c0387d
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/omnigen_image.py
@@ -0,0 +1,289 @@
+from ..models.omnigen import OmniGenTransformer
+from ..models.sdxl_vae_encoder import SDXLVAEEncoder
+from ..models.sdxl_vae_decoder import SDXLVAEDecoder
+from ..models.model_manager import ModelManager
+from ..prompters.omnigen_prompter import OmniGenPrompter
+from ..schedulers import FlowMatchScheduler
+from .base import BasePipeline
+from typing import Optional, Dict, Any, Tuple, List
+from transformers.cache_utils import DynamicCache
+import torch, os
+from tqdm import tqdm
+
+
+
+class OmniGenCache(DynamicCache):
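+    # KV cache that stores only the condition-token keys/values (the noisy image tokens are
+    # recomputed every step) and can offload per-layer tensors to CPU, prefetching the next
+    # layer on a side CUDA stream while evicting the previous one.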
+ def __init__(self,
+ num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
+        if not torch.cuda.is_available():
+            # The prefetch stream created below requires CUDA, so the offloaded cache cannot run without a GPU.
+            raise RuntimeError("OffloadedCache can only be used with a GPU")
+ super().__init__()
+ self.original_device = []
+ self.prefetch_stream = torch.cuda.Stream()
+ self.num_tokens_for_img = num_tokens_for_img
+ self.offload_kv_cache = offload_kv_cache
+
+ def prefetch_layer(self, layer_idx: int):
+ "Starts prefetching the next layer cache"
+ if layer_idx < len(self):
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ device = self.original_device[layer_idx]
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ if len(self) > 2:
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
+ if layer_idx == 0:
+ prev_layer_idx = -1
+ else:
+ prev_layer_idx = (layer_idx - 1) % len(self)
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
+
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        "Gets the cache for this layer, moving it to its original device. Prefetches the next layer and evicts the previous one."
+ if layer_idx < len(self):
+ if self.offload_kv_cache:
+ # Evict the previous layer if necessary
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+ # Load current layer cache to its original device if not already there
+ original_device = self.original_device[layer_idx]
+ # self.prefetch_stream.synchronize(original_device)
+ torch.cuda.synchronize(self.prefetch_stream)
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+
+ # Prefetch the next layer
+ self.prefetch_layer((layer_idx + 1) % len(self))
+ else:
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+ return (key_tensor, value_tensor)
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Update the cache
+ if len(self.key_cache) < layer_idx:
+ raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
+ elif len(self.key_cache) == layer_idx:
+ # only cache the states for condition tokens
+ key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
+ value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ self.original_device.append(key_states.device)
+ if self.offload_kv_cache:
+ self.evict_previous_layer(layer_idx)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ # only cache the states for condition tokens
+ key_tensor, value_tensor = self[layer_idx]
+ k = torch.cat([key_tensor, key_states], dim=-2)
+ v = torch.cat([value_tensor, value_states], dim=-2)
+ return k, v
+
+
+
+class OmnigenImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(num_train_timesteps=1, shift=1, inverse_timesteps=True, sigma_min=0, sigma_max=1)
+ # models
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ self.transformer: OmniGenTransformer = None
+ self.prompter: OmniGenPrompter = None
+ self.model_names = ['transformer', 'vae_decoder', 'vae_encoder']
+
+
+ def denoising_model(self):
+ return self.transformer
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ # Main models
+ self.transformer, model_path = model_manager.fetch_model("omnigen_transformer", require_model_path=True)
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+ self.prompter = OmniGenPrompter.from_pretrained(os.path.dirname(model_path))
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
+ pipe = OmnigenImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, prompt_refiner_classes=[])
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def encode_images(self, images, tiled=False, tile_size=64, tile_stride=32):
+ latents = [self.encode_image(image.to(device=self.device), tiled, tile_size, tile_stride).to(self.torch_dtype) for image in images]
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
+ return {"encoder_hidden_states": prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
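+        # After the first step the prompt tokens are served from the KV cache, so only the
+        # last num_tokens_for_img + 1 positions are kept for each per-step forward pass.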
+ if isinstance(position_ids, list):
+ for i in range(len(position_ids)):
+ position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
+ else:
+ position_ids = position_ids[:, -(num_tokens_for_img+1):]
+ return position_ids
+
+
+ def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
+ if isinstance(attention_mask, list):
+ return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
+ return attention_mask[..., -(num_tokens_for_img+1):, :]
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ reference_images=[],
+ cfg_scale=2.0,
+ image_cfg_scale=2.0,
+ use_kv_cache=True,
+ offload_kv_cache=True,
+ input_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = latents.repeat(3, 1, 1, 1)
+
+ # Encode prompts
+ input_data = self.prompter(prompt, reference_images, height=height, width=width, use_img_cfg=True, separate_cfg_input=True, use_input_image_size_as_output=False)
+
+ # Encode images
+ reference_latents = [self.encode_images(images, **tiler_kwargs) for images in input_data['input_pixel_values']]
+
+ # Pack all parameters
+ model_kwargs = dict(input_ids=[input_ids.to(self.device) for input_ids in input_data['input_ids']],
+ input_img_latents=reference_latents,
+ input_image_sizes=input_data['input_image_sizes'],
+ attention_mask=[attention_mask.to(self.device) for attention_mask in input_data["attention_mask"]],
+ position_ids=[position_ids.to(self.device) for position_ids in input_data["position_ids"]],
+ cfg_scale=cfg_scale,
+ img_cfg_scale=image_cfg_scale,
+ use_img_cfg=True,
+ use_kv_cache=use_kv_cache,
+ offload_model=False,
+ )
+
+ # Denoise
+ self.load_models_to_device(['transformer'])
+ cache = [OmniGenCache(latents.size(-1)*latents.size(-2) // 4, offload_kv_cache) for _ in range(len(model_kwargs['input_ids']))] if use_kv_cache else None
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).repeat(latents.shape[0]).to(self.device)
+
+ # Forward
+ noise_pred, cache = self.transformer.forward_with_separate_cfg(latents, timestep, past_key_values=cache, **model_kwargs)
+
+ # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Update KV cache
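+            # After the first step, the condition tokens are already stored in the KV cache, so input_ids are dropped
+            # and position_ids / attention_mask are cropped to the last num_tokens_for_img + 1 positions.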
+ if progress_id == 0 and use_kv_cache:
+ num_tokens_for_img = latents.size(-1)*latents.size(-2) // 4
+ if isinstance(cache, list):
+ model_kwargs['input_ids'] = [None] * len(cache)
+ else:
+ model_kwargs['input_ids'] = None
+ model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
+ model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ del cache
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/PusaV1/diffsynth/pipelines/pipeline_runner.py b/PusaV1/diffsynth/pipelines/pipeline_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b842f9bd7b25edca1c9951e67ebe5c364deca81
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/pipeline_runner.py
@@ -0,0 +1,105 @@
+import os, torch, json
+from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
+from ..processors.sequencial_processor import SequencialProcessor
+from ..data import VideoData, save_frames, save_video
+
+
+
+class SDVideoPipelineRunner:
+ def __init__(self, in_streamlit=False):
+ self.in_streamlit = in_streamlit
+
+
+ def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
+ # Load models
+ model_manager = ModelManager(torch_dtype=torch.float16, device=device)
+ model_manager.load_models(model_list)
+ pipe = SDVideoPipeline.from_model_manager(
+ model_manager,
+ [
+ ControlNetConfigUnit(
+ processor_id=unit["processor_id"],
+ model_path=unit["model_path"],
+ scale=unit["scale"]
+ ) for unit in controlnet_units
+ ]
+ )
+ textual_inversion_paths = []
+ for file_name in os.listdir(textual_inversion_folder):
+ if file_name.endswith(".pt") or file_name.endswith(".bin") or file_name.endswith(".pth") or file_name.endswith(".safetensors"):
+ textual_inversion_paths.append(os.path.join(textual_inversion_folder, file_name))
+ pipe.prompter.load_textual_inversions(textual_inversion_paths)
+ return model_manager, pipe
+
+
+ def load_smoother(self, model_manager, smoother_configs):
+ smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
+ return smoother
+
+
+ def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
+ torch.manual_seed(seed)
+ if self.in_streamlit:
+ import streamlit as st
+ progress_bar_st = st.progress(0.0)
+ output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
+ progress_bar_st.progress(1.0)
+ else:
+ output_video = pipe(**pipeline_inputs, smoother=smoother)
+ model_manager.to("cpu")
+ return output_video
+
+
+ def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
+ video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
+ if start_frame_id is None:
+ start_frame_id = 0
+ if end_frame_id is None:
+ end_frame_id = len(video)
+ frames = [video[i] for i in range(start_frame_id, end_frame_id)]
+ return frames
+
+
+ def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
+ pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
+ pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
+ pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
+ if len(data["controlnet_frames"]) > 0:
+ pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
+ return pipeline_inputs
+
+
+ def save_output(self, video, output_folder, fps, config):
+ os.makedirs(output_folder, exist_ok=True)
+ save_frames(video, os.path.join(output_folder, "frames"))
+ save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
+ config["pipeline"]["pipeline_inputs"]["input_frames"] = []
+ config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
+ with open(os.path.join(output_folder, "config.json"), 'w') as file:
+ json.dump(config, file, indent=4)
+
+
+ def run(self, config):
+ if self.in_streamlit:
+ import streamlit as st
+ if self.in_streamlit: st.markdown("Loading videos ...")
+ config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
+ if self.in_streamlit: st.markdown("Loading videos ... done!")
+ if self.in_streamlit: st.markdown("Loading models ...")
+ model_manager, pipe = self.load_pipeline(**config["models"])
+ if self.in_streamlit: st.markdown("Loading models ... done!")
+ if "smoother_configs" in config:
+ if self.in_streamlit: st.markdown("Loading smoother ...")
+ smoother = self.load_smoother(model_manager, config["smoother_configs"])
+ if self.in_streamlit: st.markdown("Loading smoother ... done!")
+ else:
+ smoother = None
+ if self.in_streamlit: st.markdown("Synthesizing videos ...")
+ output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
+ if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
+ if self.in_streamlit: st.markdown("Saving videos ...")
+ self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
+ if self.in_streamlit: st.markdown("Saving videos ... done!")
+ if self.in_streamlit: st.markdown("Finished!")
+        if self.in_streamlit:
+            with open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb') as video_file:
+                st.video(video_file.read())
diff --git a/PusaV1/diffsynth/pipelines/sd3_image.py b/PusaV1/diffsynth/pipelines/sd3_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6098739b2701d59958ef3fa85b0dc96b5ffe86a
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/sd3_image.py
@@ -0,0 +1,147 @@
+from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
+from ..prompters import SD3Prompter
+from ..schedulers import FlowMatchScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+
+
+
+class SD3ImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16)
+ self.scheduler = FlowMatchScheduler()
+ self.prompter = SD3Prompter()
+ # models
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: SD3TextEncoder2 = None
+ self.text_encoder_3: SD3TextEncoder3 = None
+ self.dit: SD3DiT = None
+ self.vae_decoder: SD3VAEDecoder = None
+ self.vae_encoder: SD3VAEEncoder = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'text_encoder_3', 'dit', 'vae_decoder', 'vae_encoder']
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
+ self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
+ self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
+ self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
+ self.dit = model_manager.fetch_model("sd3_dit")
+ self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[], device=None):
+ pipe = SD3ImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, positive=True, t5_sequence_length=77):
+ prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
+ prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
+ )
+ return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ input_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ t5_sequence_length=77,
+ tiled=False,
+ tile_size=128,
+ tile_stride=64,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder_1', 'text_encoder_2', 'text_encoder_3'])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True, t5_sequence_length=t5_sequence_length)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
+
+ # Denoise
+ self.load_models_to_device(['dit'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ inference_callback = lambda prompt_emb_posi: self.dit(
+ latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+ noise_pred_nega = self.dit(
+ latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+ # DDIM
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/PusaV1/diffsynth/pipelines/sd_image.py b/PusaV1/diffsynth/pipelines/sd_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c22c3fe69578f28925be900036bf21afeb750f17
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/sd_image.py
@@ -0,0 +1,191 @@
+from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+from .dancer import lets_dance
+from typing import List
+import torch
+from tqdm import tqdm
+
+
+
+class SDImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler()
+ self.prompter = SDPrompter()
+ # models
+ self.text_encoder: SDTextEncoder = None
+ self.unet: SDUNet = None
+ self.vae_decoder: SDVAEDecoder = None
+ self.vae_encoder: SDVAEEncoder = None
+ self.controlnet: MultiControlNetManager = None
+ self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
+ self.ipadapter: SDIpAdapter = None
+ self.model_names = ['text_encoder', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
+
+
+ def denoising_model(self):
+ return self.unet
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sd_text_encoder")
+ self.unet = model_manager.fetch_model("sd_unet")
+ self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device),
+ model_manager.fetch_model("sd_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = MultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sd_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
+ pipe = SDImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+        pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
+ return {"encoder_hidden_states": prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ input_image=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ controlnet_image=None,
+ denoising_strength=1.0,
+ height=512,
+ width=512,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder'])
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
+
+ # IP-Adapter
+ if ipadapter_images is not None:
+ self.load_models_to_device(['ipadapter_image_encoder'])
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ self.load_models_to_device(['ipadapter'])
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_image is not None:
+ self.load_models_to_device(['controlnet'])
+ controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
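+            # Add a frame dimension so the single control image goes through the video-style "controlnet_frames" interface.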
+ controlnet_image = controlnet_image.unsqueeze(1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_image}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Denoise
+ self.load_models_to_device(['controlnet', 'unet'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ inference_callback = lambda prompt_emb_posi: lets_dance(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep,
+ **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
+ device=self.device,
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+ noise_pred_nega = lets_dance(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+ # DDIM
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/PusaV1/diffsynth/pipelines/sd_video.py b/PusaV1/diffsynth/pipelines/sd_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..4337beb4f7a2d4a08c5955fdbd5f528ea328b39e
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/sd_video.py
@@ -0,0 +1,269 @@
+from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .sd_image import SDImagePipeline
+from .dancer import lets_dance
+from typing import List
+import torch
+from tqdm import tqdm
+
+
+
+def lets_dance_with_long_video(
+ unet: SDUNet,
+ motion_modules: SDMotionModel = None,
+ controlnet: MultiControlNetManager = None,
+ sample = None,
+ timestep = None,
+ encoder_hidden_states = None,
+ ipadapter_kwargs_list = {},
+ controlnet_frames = None,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ device="cuda",
+ animatediff_batch_size=16,
+ animatediff_stride=8,
+):
+ num_frames = sample.shape[0]
+ hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
+
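+    # Denoise the frames in overlapping windows of animatediff_batch_size, sliding by animatediff_stride;
+    # overlapping predictions are blended with a running weighted average whose weight peaks at each window's center.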
+ for batch_id in range(0, num_frames, animatediff_stride):
+ batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
+
+ # process this batch
+ hidden_states_batch = lets_dance(
+ unet, motion_modules, controlnet,
+ sample[batch_id: batch_id_].to(device),
+ timestep,
+ encoder_hidden_states,
+ ipadapter_kwargs_list=ipadapter_kwargs_list,
+ controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
+ unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
+ cross_frame_attention=cross_frame_attention,
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
+ ).cpu()
+
+ # update hidden_states
+ for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
+ bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
+ hidden_states, num = hidden_states_output[i]
+ hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
+ hidden_states_output[i] = (hidden_states, num + bias)
+
+ if batch_id_ == num_frames:
+ break
+
+ # output
+ hidden_states = torch.stack([h for h, _ in hidden_states_output])
+ return hidden_states
+
+
+
+class SDVideoPipeline(SDImagePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
+ self.prompter = SDPrompter()
+ # models
+ self.text_encoder: SDTextEncoder = None
+ self.unet: SDUNet = None
+ self.vae_decoder: SDVAEDecoder = None
+ self.vae_encoder: SDVAEEncoder = None
+ self.controlnet: MultiControlNetManager = None
+ self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
+ self.ipadapter: SDIpAdapter = None
+ self.motion_modules: SDMotionModel = None
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sd_text_encoder")
+ self.unet = model_manager.fetch_model("sd_unet")
+ self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device),
+ model_manager.fetch_model("sd_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = MultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sd_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
+
+ # Motion Modules
+ self.motion_modules = model_manager.fetch_model("sd_motion_modules")
+ if self.motion_modules is None:
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ pipe = SDVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
+ return pipe
+
+
+ def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
+ images = [
+ self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ for frame_id in range(latents.shape[0])
+ ]
+ return images
+
+
+ def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
+ latents = []
+ for image in processed_images:
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ latents.append(latent.cpu())
+ latents = torch.concat(latents, dim=0)
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ num_frames=None,
+ input_frames=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ controlnet_frames=None,
+ denoising_strength=1.0,
+ height=512,
+ width=512,
+ num_inference_steps=20,
+ animatediff_batch_size = 16,
+ animatediff_stride = 8,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ smoother=None,
+ smoother_progress_ids=[],
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters, batch size ...
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+ other_kwargs = {
+ "animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
+ "unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
+ "cross_frame_attention": cross_frame_attention,
+ }
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
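+        # Without motion modules the same noise is repeated for every frame; with AnimateDiff motion modules each frame gets independent noise.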
+ if self.motion_modules is None:
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
+ else:
+ noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
+ if input_frames is None or denoising_strength == 1.0:
+ latents = noise
+ else:
+ latents = self.encode_video(input_frames, **tiler_kwargs)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+
+ # Encode prompts
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
+
+ # IP-Adapter
+ if ipadapter_images is not None:
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_frames is not None:
+ if isinstance(controlnet_frames[0], list):
+ controlnet_frames_ = []
+ for processor_id in range(len(controlnet_frames)):
+ controlnet_frames_.append(
+ torch.stack([
+ self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
+ ], dim=1)
+ )
+ controlnet_frames = torch.concat(controlnet_frames_, dim=0)
+ else:
+ controlnet_frames = torch.stack([
+ self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames)
+ ], dim=1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_frames}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ noise_pred_posi = lets_dance_with_long_video(
+ self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
+ sample=latents, timestep=timestep,
+ **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred_nega = lets_dance_with_long_video(
+ self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
+ sample=latents, timestep=timestep,
+ **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+ # DDIM and smoother
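+            # If a smoother is scheduled for this step, render the current prediction to frames (to_final=True),
+            # smooth them in pixel space, re-encode them, and convert the result back into an equivalent noise prediction.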
+ if smoother is not None and progress_id in smoother_progress_ids:
+ rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
+ rendered_frames = self.decode_video(rendered_frames)
+ rendered_frames = smoother(rendered_frames, original_frames=input_frames)
+ target_latents = self.encode_video(rendered_frames)
+ noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ output_frames = self.decode_video(latents, **tiler_kwargs)
+
+ # Post-process
+ if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
+ output_frames = smoother(output_frames, original_frames=input_frames)
+
+ return output_frames
diff --git a/PusaV1/diffsynth/pipelines/sdxl_image.py b/PusaV1/diffsynth/pipelines/sdxl_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..499c4bbce707fa7cfd026c66af8c8dca3e554127
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/sdxl_image.py
@@ -0,0 +1,226 @@
+from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
+from ..models.kolors_text_encoder import ChatGLMModel
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDXLPrompter, KolorsPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .base import BasePipeline
+from .dancer import lets_dance_xl
+from typing import List
+import torch
+from tqdm import tqdm
+from einops import repeat
+
+
+
+class SDXLImagePipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler()
+ self.prompter = SDXLPrompter()
+ # models
+ self.text_encoder: SDXLTextEncoder = None
+ self.text_encoder_2: SDXLTextEncoder2 = None
+ self.text_encoder_kolors: ChatGLMModel = None
+ self.unet: SDXLUNet = None
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ self.controlnet: MultiControlNetManager = None
+ self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
+ self.ipadapter: SDXLIpAdapter = None
+ self.model_names = ['text_encoder', 'text_encoder_2', 'text_encoder_kolors', 'unet', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter_image_encoder', 'ipadapter']
+
+
+ def denoising_model(self):
+ return self.unet
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
+ self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
+ self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
+ self.unet = model_manager.fetch_model("sdxl_unet")
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+
+ # ControlNets
+ controlnet_units = []
+ for config in controlnet_config_units:
+ controlnet_unit = ControlNetUnit(
+ Annotator(config.processor_id, device=self.device),
+ model_manager.fetch_model("sdxl_controlnet", config.model_path),
+ config.scale
+ )
+ controlnet_units.append(controlnet_unit)
+ self.controlnet = MultiControlNetManager(controlnet_units)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
+
+ # Kolors
+ if self.text_encoder_kolors is not None:
+ print("Switch to Kolors. The prompter and scheduler will be replaced.")
+ self.prompter = KolorsPrompter()
+ self.prompter.fetch_models(self.text_encoder_kolors)
+ self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
+ else:
+ self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], device=None):
+ pipe = SDXLImagePipeline(
+ device=model_manager.device if device is None else device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
+ return pipe
+
+
+ def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ image = self.vae_output_to_image(image)
+ return image
+
+
+ def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
+ add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
+ prompt,
+ clip_skip=clip_skip, clip_skip_2=clip_skip_2,
+ device=self.device,
+ positive=positive,
+ )
+ return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
+
+
+ def prepare_extra_input(self, latents=None):
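+        # SDXL micro-conditioning: (original_size, crops_coords_top_left, target_size) packed into a single add_time_id tensor.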
+ height, width = latents.shape[2] * 8, latents.shape[3] * 8
+ add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device).repeat(latents.shape[0])
+ return {"add_time_id": add_time_id}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ local_prompts=[],
+ masks=[],
+ mask_scales=[],
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ clip_skip_2=2,
+ input_image=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ ipadapter_use_instant_style=False,
+ controlnet_image=None,
+ denoising_strength=1.0,
+ height=1024,
+ width=1024,
+ num_inference_steps=20,
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if input_image is not None:
+ self.load_models_to_device(['vae_encoder'])
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
+ latents = self.encode_image(image, **tiler_kwargs)
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = self.generate_noise((1, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+
+ # Encode prompts
+ self.load_models_to_device(['text_encoder', 'text_encoder_2', 'text_encoder_kolors'])
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
+ prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
+
+ # IP-Adapter
+ if ipadapter_images is not None:
+ if ipadapter_use_instant_style:
+ self.ipadapter.set_less_adapter()
+ else:
+ self.ipadapter.set_full_adapter()
+ self.load_models_to_device(['ipadapter_image_encoder'])
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ self.load_models_to_device(['ipadapter'])
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_image is not None:
+ self.load_models_to_device(['controlnet'])
+ controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
+ controlnet_image = controlnet_image.unsqueeze(1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_image}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Prepare extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # Denoise
+ self.load_models_to_device(['controlnet', 'unet'])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ inference_callback = lambda prompt_emb_posi: lets_dance_xl(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep, **extra_input,
+ **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
+ device=self.device,
+ )
+ noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
+
+ if cfg_scale != 1.0:
+ noise_pred_nega = lets_dance_xl(
+ self.unet, motion_modules=None, controlnet=self.controlnet,
+ sample=latents, timestep=timestep, **extra_input,
+ **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # DDIM
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ self.load_models_to_device(['vae_decoder'])
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+
+ # offload all models
+ self.load_models_to_device([])
+ return image
diff --git a/PusaV1/diffsynth/pipelines/sdxl_video.py b/PusaV1/diffsynth/pipelines/sdxl_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..308590ca6a874c5803da95db1d90fced26126893
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/sdxl_video.py
@@ -0,0 +1,226 @@
+from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
+from ..models.kolors_text_encoder import ChatGLMModel
+from ..models.model_manager import ModelManager
+from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
+from ..prompters import SDXLPrompter, KolorsPrompter
+from ..schedulers import EnhancedDDIMScheduler
+from .sdxl_image import SDXLImagePipeline
+from .dancer import lets_dance_xl
+from typing import List
+import torch
+from tqdm import tqdm
+
+
+
+class SDXLVideoPipeline(SDXLImagePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
+ self.prompter = SDXLPrompter()
+ # models
+ self.text_encoder: SDXLTextEncoder = None
+ self.text_encoder_2: SDXLTextEncoder2 = None
+ self.text_encoder_kolors: ChatGLMModel = None
+ self.unet: SDXLUNet = None
+ self.vae_decoder: SDXLVAEDecoder = None
+ self.vae_encoder: SDXLVAEEncoder = None
+ # self.controlnet: MultiControlNetManager = None (TODO)
+ self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
+ self.ipadapter: SDXLIpAdapter = None
+ self.motion_modules: SDXLMotionModel = None
+
+
+ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ # Main models
+ self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
+ self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
+ self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
+ self.unet = model_manager.fetch_model("sdxl_unet")
+ self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
+ self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
+        self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
+ self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
+
+ # ControlNets (TODO)
+
+ # IP-Adapters
+ self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
+ self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
+
+ # Motion Modules
+ self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
+ if self.motion_modules is None:
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
+
+ # Kolors
+ if self.text_encoder_kolors is not None:
+ print("Switch to Kolors. The prompter will be replaced.")
+ self.prompter = KolorsPrompter()
+ self.prompter.fetch_models(self.text_encoder_kolors)
+            # The AnimateDiff and Kolors schedulers are incompatible; when motion modules are loaded, keep the AnimateDiff scheduler.
+ if self.motion_modules is None:
+ self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
+ else:
+ self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
+ pipe = SDXLVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype,
+ )
+ pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
+ return pipe
+
+
+ def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
+ images = [
+ self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ for frame_id in range(latents.shape[0])
+ ]
+ return images
+
+
+ def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
+ latents = []
+ for image in processed_images:
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ latents.append(latent.cpu())
+ latents = torch.concat(latents, dim=0)
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ cfg_scale=7.5,
+ clip_skip=1,
+ num_frames=None,
+ input_frames=None,
+ ipadapter_images=None,
+ ipadapter_scale=1.0,
+ ipadapter_use_instant_style=False,
+ controlnet_frames=None,
+ denoising_strength=1.0,
+ height=512,
+ width=512,
+ num_inference_steps=20,
+ animatediff_batch_size = 16,
+ animatediff_stride = 8,
+ unet_batch_size = 1,
+ controlnet_batch_size = 1,
+ cross_frame_attention = False,
+ smoother=None,
+ smoother_progress_ids=[],
+ tiled=False,
+ tile_size=64,
+ tile_stride=32,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Tiler parameters, batch size ...
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Prepare latent tensors
+ if self.motion_modules is None:
+ noise = self.generate_noise((1, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
+ else:
+ noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device="cpu", dtype=self.torch_dtype)
+ if input_frames is None or denoising_strength == 1.0:
+ latents = noise
+ else:
+ latents = self.encode_video(input_frames, **tiler_kwargs)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+        latents = latents.to(self.device)  # TODO: remove this once long videos are supported
+
+ # Encode prompts
+ prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
+ prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
+
+ # IP-Adapter
+ if ipadapter_images is not None:
+ if ipadapter_use_instant_style:
+ self.ipadapter.set_less_adapter()
+ else:
+ self.ipadapter.set_full_adapter()
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+ ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
+ ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
+ else:
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
+
+ # Prepare ControlNets
+ if controlnet_frames is not None:
+ if isinstance(controlnet_frames[0], list):
+ controlnet_frames_ = []
+ for processor_id in range(len(controlnet_frames)):
+ controlnet_frames_.append(
+ torch.stack([
+ self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
+ ], dim=1)
+ )
+ controlnet_frames = torch.concat(controlnet_frames_, dim=0)
+ else:
+ controlnet_frames = torch.stack([
+ self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
+ for controlnet_frame in progress_bar_cmd(controlnet_frames)
+ ], dim=1)
+ controlnet_kwargs = {"controlnet_frames": controlnet_frames}
+ else:
+ controlnet_kwargs = {"controlnet_frames": None}
+
+ # Prepare extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(self.device)
+
+ # Classifier-free guidance
+ noise_pred_posi = lets_dance_xl(
+ self.unet, motion_modules=self.motion_modules, controlnet=None,
+ sample=latents, timestep=timestep,
+ **prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred_nega = lets_dance_xl(
+ self.unet, motion_modules=self.motion_modules, controlnet=None,
+ sample=latents, timestep=timestep,
+ **prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
+ device=self.device,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+
+ # DDIM and smoother
+ if smoother is not None and progress_id in smoother_progress_ids:
+ rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
+ rendered_frames = self.decode_video(rendered_frames)
+ rendered_frames = smoother(rendered_frames, original_frames=input_frames)
+ target_latents = self.encode_video(rendered_frames)
+ noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # UI
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ output_frames = self.decode_video(latents, **tiler_kwargs)
+
+ # Post-process
+ if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
+ output_frames = smoother(output_frames, original_frames=input_frames)
+
+ return output_frames
diff --git a/PusaV1/diffsynth/pipelines/step_video.py b/PusaV1/diffsynth/pipelines/step_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..56140178e9d6cdaf5efeca77ea061f8232836f11
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/step_video.py
@@ -0,0 +1,209 @@
+from ..models import ModelManager
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
+from ..models.stepvideo_text_encoder import STEP1TextEncoder
+from ..models.stepvideo_dit import StepVideoModel
+from ..models.stepvideo_vae import StepVideoVAE
+from ..schedulers.flow_match import FlowMatchScheduler
+from .base import BasePipeline
+from ..prompters import StepVideoPrompter
+import torch
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from transformers.models.bert.modeling_bert import BertEmbeddings
+from ..models.stepvideo_dit import RMSNorm
+from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
+
+
+
+class StepVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(sigma_min=0.0, extra_one_step=True, shift=13.0, reverse_sigmas=True, num_train_timesteps=1)
+ self.prompter = StepVideoPrompter()
+ self.text_encoder_1: HunyuanDiTCLIPTextEncoder = None
+ self.text_encoder_2: STEP1TextEncoder = None
+ self.dit: StepVideoModel = None
+ self.vae: StepVideoVAE = None
+ self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae']
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
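+        # Wrap submodules so their weights stay offloaded on the CPU and are cast/moved to the GPU on demand.
+        # For the DiT, roughly num_persistent_param_in_dit parameters are kept onloaded on the GPU; the rest fall back to the CPU-offload overflow config.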
+ dtype = next(iter(self.text_encoder_1.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_1,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ BertEmbeddings: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=torch.float32,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.text_encoder_2.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder_2,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Embedding: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ CausalConv: AutoWrappedModule,
+ CausalConvAfterNorm: AutoWrappedModule,
+ Upsample2D: AutoWrappedModule,
+ BaseGroupNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ self.text_encoder_1 = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
+ self.text_encoder_2 = model_manager.fetch_model("stepvideo_text_encoder_2")
+ self.dit = model_manager.fetch_model("stepvideo_dit")
+ self.vae = model_manager.fetch_model("stepvideo_vae")
+ self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = StepVideoPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ return pipe
+
+
+ def encode_prompt(self, prompt, positive=True):
+ clip_embeds, llm_embeds, llm_mask = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
+ clip_embeds = clip_embeds.to(dtype=self.torch_dtype, device=self.device)
+ llm_embeds = llm_embeds.to(dtype=self.torch_dtype, device=self.device)
+ llm_mask = llm_mask.to(dtype=self.torch_dtype, device=self.device)
+ return {"encoder_hidden_states_2": clip_embeds, "encoder_hidden_states": llm_embeds, "encoder_attention_mask": llm_mask}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_video=None,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device="cpu",
+ height=544,
+ width=992,
+ num_frames=204,
+ cfg_scale=9.0,
+ num_inference_steps=30,
+ tiled=True,
+ tile_size=(34, 34),
+ tile_stride=(16, 16),
+ smooth_scale=0.6,
+ progress_bar_cmd=lambda x: x,
+ progress_bar_st=None,
+ ):
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+ # Initialize noise
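+        # The latent shape follows the VAE compression: roughly 17 video frames per 3 latent frames, 16x spatial downsampling, with 64 latent channels.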
+ latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
+
+ # Encode prompts
+ self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Denoise
+ self.load_models_to_device(["dit"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+ print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
+
+ # Inference
+ noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi)
+ if cfg_scale != 1.0:
+ noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega)
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
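+
+
+# A minimal usage sketch (illustrative only; it assumes the StepVideo text encoders, DiT and
+# VAE have already been loaded into a ModelManager, and the checkpoint paths are placeholders):
+#
+#     model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+#     model_manager.load_models(["path/to/stepvideo/checkpoints"])
+#     pipe = StepVideoPipeline.from_model_manager(model_manager, device="cuda")
+#     pipe.enable_vram_management(num_persistent_param_in_dit=None)
+#     frames = pipe(prompt="a corgi running on the beach", num_inference_steps=30, cfg_scale=9.0)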
diff --git a/PusaV1/diffsynth/pipelines/svd_video.py b/PusaV1/diffsynth/pipelines/svd_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71597efa73783f7e3746a2bcf6b7be5c70c360e
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/svd_video.py
@@ -0,0 +1,300 @@
+from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
+from ..schedulers import ContinuousODEScheduler
+from .base import BasePipeline
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange, repeat
+
+
+
+class SVDVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = ContinuousODEScheduler()
+ # models
+ self.image_encoder: SVDImageEncoder = None
+ self.unet: SVDUNet = None
+ self.vae_encoder: SVDVAEEncoder = None
+ self.vae_decoder: SVDVAEDecoder = None
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ self.image_encoder = model_manager.fetch_model("svd_image_encoder")
+ self.unet = model_manager.fetch_model("svd_unet")
+ self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
+ self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, **kwargs):
+ pipe = SVDVideoPipeline(
+ device=model_manager.device,
+ torch_dtype=model_manager.torch_dtype
+ )
+ pipe.fetch_models(model_manager)
+ return pipe
+
+
+ def encode_image_with_clip(self, image):
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
+ image = (image + 1.0) / 2.0
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+ std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+ image = (image - mean) / std
+ image_emb = self.image_encoder(image)
+ return image_emb
+
+
+ def encode_image_with_vae(self, image, noise_aug_strength, seed=None):
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+ noise = self.generate_noise(image.shape, seed=seed, device=self.device, dtype=self.torch_dtype)
+ image = image + noise_aug_strength * noise
+ image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
+ return image_emb
+
+
+ def encode_video_with_vae(self, video):
+ video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0)
+ video = rearrange(video, "T C H W -> 1 C T H W")
+ video = video.to(device=self.device, dtype=self.torch_dtype)
+ latents = self.vae_encoder.encode_video(video)
+ latents = rearrange(latents[0], "C T H W -> T C H W")
+ return latents
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def calculate_noise_pred(
+ self,
+ latents,
+ timestep,
+ add_time_id,
+ cfg_scales,
+ image_emb_vae_posi, image_emb_clip_posi,
+ image_emb_vae_nega, image_emb_clip_nega
+ ):
+ # Positive side
+ noise_pred_posi = self.unet(
+ torch.cat([latents, image_emb_vae_posi], dim=1),
+ timestep, image_emb_clip_posi, add_time_id
+ )
+ # Negative side
+ noise_pred_nega = self.unet(
+ torch.cat([latents, image_emb_vae_nega], dim=1),
+ timestep, image_emb_clip_nega, add_time_id
+ )
+
+ # Classifier-free guidance
+ noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
+
+ return noise_pred
+
+
+ def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
+ if post_normalize:
+ mean, std = latents.mean(), latents.std()
+ latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
+ latents = latents * contrast_enhance_scale
+ return latents
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ input_image=None,
+ input_video=None,
+ mask_frames=[],
+ mask_frame_ids=[],
+ min_cfg_scale=1.0,
+ max_cfg_scale=3.0,
+ denoising_strength=1.0,
+ num_frames=25,
+ height=576,
+ width=1024,
+ fps=7,
+ motion_bucket_id=127,
+ noise_aug_strength=0.02,
+ num_inference_steps=20,
+ post_normalize=True,
+ contrast_enhance_scale=1.2,
+ seed=None,
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ ):
+ height, width = self.check_resize_height_width(height, width)
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+
+ # Prepare latent tensors
+ noise = self.generate_noise((num_frames, 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
+ if denoising_strength == 1.0:
+ latents = noise.clone()
+ else:
+ latents = self.encode_video_with_vae(input_video)
+ latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
+
+ # Prepare mask frames
+ if len(mask_frames) > 0:
+ mask_latents = self.encode_video_with_vae(mask_frames)
+
+ # Encode image
+ image_emb_clip_posi = self.encode_image_with_clip(input_image)
+ image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
+ image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength, seed=seed), "B C H W -> (B T) C H W", T=num_frames)
+ image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
+
+ # Prepare classifier-free guidance
+ cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames)
+ cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+
+ # Prepare positional id
+ add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
+
+ # Denoise
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+
+ # Mask frames
+ for frame_id, mask_frame_id in enumerate(mask_frame_ids):
+ latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
+
+ # Fetch model output
+ noise_pred = self.calculate_noise_pred(
+ latents, timestep, add_time_id, cfg_scales,
+ image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega
+ )
+
+ # Forward Euler
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+ # Update progress bar
+ if progress_bar_st is not None:
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+ # Decode image
+ latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
+ video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
+ video = self.tensor2video(video)
+
+ return video
+
+
+
+class SVDCLIPImageProcessor:
+ def __init__(self):
+ pass
+
+ def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True):
+ h, w = input.shape[-2:]
+ factors = (h / size[0], w / size[1])
+
+ # First, we have to determine sigma
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+ sigmas = (
+ max((factors[0] - 1.0) / 2.0, 0.001),
+ max((factors[1] - 1.0) / 2.0, 0.001),
+ )
+
+        # Kernel size: 3 sigma gives good results but is slow; Pillow uses 1 sigma
+        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+        # but applies it in two passes, which improves quality. We use 2 sigmas here.
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+ # Make sure it is odd
+ if (ks[0] % 2) == 0:
+ ks = ks[0] + 1, ks[1]
+
+ if (ks[1] % 2) == 0:
+ ks = ks[0], ks[1] + 1
+
+ input = self._gaussian_blur2d(input, ks, sigmas)
+
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+ return output
+
+
+ def _compute_padding(self, kernel_size):
+ """Compute padding tuple."""
+        # Returns (padding_left, padding_right, padding_top, padding_bottom), the order expected by torch.nn.functional.pad
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+ if len(kernel_size) < 2:
+ raise AssertionError(kernel_size)
+ computed = [k - 1 for k in kernel_size]
+
+ # for even kernels we need to do asymmetric padding :(
+ out_padding = 2 * len(kernel_size) * [0]
+
+ for i in range(len(kernel_size)):
+ computed_tmp = computed[-(i + 1)]
+
+ pad_front = computed_tmp // 2
+ pad_rear = computed_tmp - pad_front
+
+ out_padding[2 * i + 0] = pad_front
+ out_padding[2 * i + 1] = pad_rear
+
+ return out_padding
+
+
+ def _filter2d(self, input, kernel):
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+ height, width = tmp_kernel.shape[-2:]
+
+ padding_shape: list[int] = self._compute_padding([height, width])
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+        # Reshape kernel and input so the grouped convolution below applies one kernel per group
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+ # convolve the tensor with the kernel.
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ out = output.view(b, c, h, w)
+ return out
+
+
+ def _gaussian(self, window_size: int, sigma):
+ if isinstance(sigma, float):
+ sigma = torch.tensor([[sigma]])
+
+ batch_size = sigma.shape[0]
+
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+ if window_size % 2 == 0:
+ x = x + 0.5
+
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+ return gauss / gauss.sum(-1, keepdim=True)
+
+
+ def _gaussian_blur2d(self, input, kernel_size, sigma):
+ if isinstance(sigma, tuple):
+ sigma = torch.tensor([sigma], dtype=input.dtype)
+ else:
+ sigma = sigma.to(dtype=input.dtype)
+
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
+ bs = sigma.shape[0]
+ kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1))
+ kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1))
+ out_x = self._filter2d(input, kernel_x[..., None, :])
+ out = self._filter2d(out_x, kernel_y[..., None])
+
+ return out
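+
+
+# A minimal image-to-video usage sketch (illustrative only; the checkpoint path and the
+# ModelManager loading call are placeholders, not defined in this file):
+#
+#     model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+#     model_manager.load_models(["path/to/svd/checkpoint.safetensors"])
+#     pipe = SVDVideoPipeline.from_model_manager(model_manager)
+#     video = pipe(input_image=Image.open("input.jpg").resize((1024, 576)), num_frames=25,
+#                  fps=7, motion_bucket_id=127, num_inference_steps=20, seed=0)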
diff --git a/PusaV1/diffsynth/pipelines/wan_video.py b/PusaV1/diffsynth/pipelines/wan_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b36f9c248c4c7feba32c68f94e3a4a1e68dfa12
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/wan_video.py
@@ -0,0 +1,637 @@
+import types
+from ..models import ModelManager
+from ..models.wan_video_dit import WanModel
+from ..models.wan_video_text_encoder import WanTextEncoder
+from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..schedulers.flow_match import FlowMatchScheduler
+from .base import BasePipeline
+from ..prompters import WanPrompter
+import torch, os
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import Optional
+
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
+from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
+from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+
+
+
+class WanVideoPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
+ self.text_encoder: WanTextEncoder = None
+ self.image_encoder: WanImageEncoder = None
+ self.dit: WanModel = None
+ self.vae: WanVideoVAE = None
+ self.motion_controller: WanMotionControllerModel = None
+ self.vace: VaceWanModel = None
+ self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
+ self.height_division_factor = 16
+ self.width_division_factor = 16
+ self.use_unified_sequence_parallel = False
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
+ dtype = next(iter(self.text_encoder.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5RelativeEmbedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ RMS_norm: AutoWrappedModule,
+ CausalConv3d: AutoWrappedModule,
+ Upsample: AutoWrappedModule,
+ torch.nn.SiLU: AutoWrappedModule,
+ torch.nn.Dropout: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.image_encoder is not None:
+ dtype = next(iter(self.image_encoder.parameters())).dtype
+ enable_vram_management(
+ self.image_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.motion_controller is not None:
+ dtype = next(iter(self.motion_controller.parameters())).dtype
+ enable_vram_management(
+ self.motion_controller,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+        if self.vace is not None:
+            dtype = next(iter(self.vace.parameters())).dtype
+            enable_vram_management(
+ self.vace,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
+ if text_encoder_model_and_path is not None:
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
+ self.dit = model_manager.fetch_model("wan_video_dit")
+ self.vae = model_manager.fetch_model("wan_video_vae")
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+ self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
+ self.vace = model_manager.fetch_model("wan_video_vace")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ if use_usp:
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+ for block in pipe.dit.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
+ pipe.sp_size = get_sequence_parallel_world_size()
+ pipe.use_unified_sequence_parallel = True
+ return pipe
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def encode_prompt(self, prompt, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
+ return {"context": prompt_emb}
+
+
+ def encode_image(self, image, end_image, num_frames, height, width):
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
+ clip_context = self.image_encoder.encode_image([image])
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+ msk[:, -1:] = 1
+ else:
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
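+        # The Wan VAE maps num_frames video frames to (num_frames - 1) // 4 + 1 latent frames,
+        # so repeat the first-frame mask 4x and fold the mask into groups of 4 per latent frame.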
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
+ y = y.to(dtype=self.torch_dtype, device=self.device)
+ return {"clip_feature": clip_context, "y": y}
+
+
+ def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ control_video = self.preprocess_images(control_video)
+ control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ return latents
+
+
+ def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ if control_video is not None:
+ control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if clip_feature is None or y is None:
+ clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+ y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+ else:
+ y = y[:, -16:]
+ y = torch.concat([control_latents, y], dim=1)
+ return {"clip_feature": clip_feature, "y": y}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ model_dtype = next(iter(self.vae.parameters())).dtype
+ model_device = next(iter(self.vae.parameters())).device
+
+ # Convert latents to the correct dtype and device
+ latents = latents.to(dtype=model_dtype, device=model_device)
+
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled,
+ tile_size=tile_size, tile_stride=tile_stride)
+ return frames
+
+
+ def prepare_unified_sequence_parallel(self):
+ return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+ def prepare_motion_bucket_id(self, motion_bucket_id):
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+ return {"motion_bucket_id": motion_bucket_id}
+
+
+ def prepare_vace_kwargs(
+ self,
+ latents,
+ vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
+ height=480, width=832, num_frames=81,
+ seed=None, rand_device="cpu",
+ tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
+ ):
+ if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
+ self.load_models_to_device(["vae"])
+ if vace_video is None:
+ vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
+ else:
+ vace_video = self.preprocess_images(vace_video)
+ vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ if vace_mask is None:
+ vace_mask = torch.ones_like(vace_video)
+ else:
+ vace_mask = self.preprocess_images(vace_mask)
+ vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
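+            # Split the control video into "inactive" (outside the mask) and "reactive" (inside
+            # the mask) parts, encode both with the VAE, and stack them along the channel dim.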
+ inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
+ reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
+ inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
+
+ vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
+ vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
+
+ if vace_reference_image is None:
+ pass
+ else:
+ vace_reference_image = self.preprocess_images([vace_reference_image])
+ vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
+ vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
+ vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
+
+ noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ latents = torch.concat((noise, latents), dim=2)
+
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
+ return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
+ else:
+ return latents, {"vace_context": None, "vace_scale": vace_scale}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_image=None,
+ end_image=None,
+ input_video=None,
+ control_video=None,
+ vace_video=None,
+ vace_video_mask=None,
+ vace_reference_image=None,
+ vace_scale=1.0,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device="cpu",
+ height=480,
+ width=832,
+ num_frames=81,
+ cfg_scale=5.0,
+ num_inference_steps=50,
+ sigma_shift=5.0,
+ motion_bucket_id=None,
+ tiled=True,
+ tile_size=(30, 52),
+ tile_stride=(15, 26),
+ tea_cache_l1_thresh=None,
+ tea_cache_model_id="",
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ visualize_attention=False,
+ output_dir=None,
+ ):
+ # Parameter check
+ height, width = self.check_resize_height_width(height, width)
+        if num_frames % 4 != 1:
+            num_frames = (num_frames + 2) // 4 * 4 + 1
+            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
+
+ if visualize_attention:
+ import datetime
+ import os
+ from ..models.wan_video_dit import _VISUALIZE_ATTENTION_CONFIG
+
+ if output_dir:
+ vis_path = os.path.join(output_dir,"attention_maps")
+ else:
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ vis_path = os.path.join("attention_maps", timestamp)
+ os.makedirs(vis_path, exist_ok=True)
+
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = True
+ _VISUALIZE_ATTENTION_CONFIG["path"] = vis_path
+ print(f"Attention visualization enabled. Maps will be saved to {vis_path}")
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ # Initialize noise
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ if input_video is not None:
+ self.load_models_to_device(['vae'])
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = noise
+
+ # Encode prompts
+ self.load_models_to_device(["text_encoder"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Encode image
+ if input_image is not None and self.image_encoder is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
+ else:
+ image_emb = {}
+
+ # ControlNet
+ if control_video is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+ # Motion Controller
+ if self.motion_controller is not None and motion_bucket_id is not None:
+ motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+ else:
+ motion_kwargs = {}
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # VACE
+ latents, vace_kwargs = self.prepare_vace_kwargs(
+ latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
+ height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
+ )
+
+ # TeaCache
+ tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+ tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+
+ # Unified Sequence Parallel
+ usp_kwargs = self.prepare_unified_sequence_parallel()
+
+ # Denoise
+ self.load_models_to_device(["dit", "motion_controller", "vace"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ if visualize_attention:
+ from ..models.wan_video_dit import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["step"] = progress_id
+
+ timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
+
+ # Inference
+ noise_pred_posi = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_posi, **image_emb, **extra_input,
+ **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ if cfg_scale != 1.0:
+ noise_pred_nega = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_nega, **image_emb, **extra_input,
+ **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+ if visualize_attention:
+ from ..models.wan_video_dit import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ _VISUALIZE_ATTENTION_CONFIG["path"] = None
+ print("Attention visualization finished.")
+
+ if vace_reference_image is not None:
+ latents = latents[:, :, 1:]
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ frames = self.decode_video(latents, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
+
+
+
+class TeaCache:
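+    # TeaCache skips redundant DiT forward passes: check() accumulates a rescaled relative L1
+    # change of the timestep-modulated input and returns True when the cached residual can be
+    # reused, store() caches (output - input) after a full pass, and update() re-applies that
+    # residual on skipped steps.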
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ self.coefficients_dict = {
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+ }
+ if model_id not in self.coefficients_dict:
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+ self.coefficients = self.coefficients_dict[model_id]
+
+ def check(self, dit: WanModel, x, t_mod):
+ modulated_inp = t_mod.clone()
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = self.coefficients
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = x.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
+def model_fn_wan_video(
+ dit: WanModel,
+ motion_controller: WanMotionControllerModel = None,
+ vace: VaceWanModel = None,
+ x: torch.Tensor = None,
+ timestep: torch.Tensor = None,
+ context: torch.Tensor = None,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ vace_context = None,
+ vace_scale = 1.0,
+ tea_cache: TeaCache = None,
+ use_unified_sequence_parallel: bool = False,
+ motion_bucket_id: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if use_unified_sequence_parallel:
+ import torch.distributed as dist
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+ t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+ if motion_bucket_id is not None and motion_controller is not None:
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+ context = dit.text_embedding(context)
+
+ model_dtype = next(iter(dit.parameters())).dtype
+ model_device = next(iter(dit.parameters())).device
+
+ # Convert inputs to the correct dtype and device
+ x = x.to(dtype=model_dtype, device=model_device)
+ t_mod = t_mod.to(dtype=model_dtype, device=model_device)
+ context = context.to(dtype=model_dtype, device=model_device)
+ if y is not None:
+ y = y.to(dtype=model_dtype, device=model_device)
+ if clip_feature is not None:
+ clip_feature = clip_feature.to(dtype=model_dtype, device=model_device)
+
+ if dit.has_image_input:
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
+        clip_embedding = dit.img_emb(clip_feature)
+        context = torch.cat([clip_embedding, context], dim=1)
+
+ x, (f, h, w) = dit.patchify(x)
+
+ from ..models.wan_video_dit import _VISUALIZE_ATTENTION_CONFIG
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["grid_size"] = (f, h, w)
+
+ freqs = torch.cat([
+ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
+ else:
+ tea_cache_update = False
+
+ if vace_context is not None:
+ vace_hints = vace(x, vace_context, context, t_mod, freqs)
+
+ # blocks
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+ if tea_cache_update:
+ x = tea_cache.update(x)
+ else:
+ for block_id, block in enumerate(dit.blocks):
+ x = block(x, context, t_mod, freqs)
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
+ x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
+ if tea_cache is not None:
+ tea_cache.store(x)
+
+ x = dit.head(x, t)
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+ x = dit.unpatchify(x, (f, h, w))
+ return x
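+
+
+# A minimal text-to-video usage sketch (illustrative only; checkpoint paths are placeholders
+# and enable_vram_management() is optional):
+#
+#     model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+#     model_manager.load_models(["path/to/wan_dit", "path/to/wan_text_encoder", "path/to/wan_vae"])
+#     pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
+#     pipe.enable_vram_management(num_persistent_param_in_dit=None)
+#     frames = pipe(prompt="a cat playing piano", num_frames=81, num_inference_steps=50, seed=0)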
diff --git a/PusaV1/diffsynth/pipelines/wan_video_pusa.py b/PusaV1/diffsynth/pipelines/wan_video_pusa.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7f337a68bd1a5e90bb8ceaa2b80fba8750f75c
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/wan_video_pusa.py
@@ -0,0 +1,656 @@
+import types
+from ..models import ModelManager
+from ..models.wan_video_pusa import WanModelPusa
+from ..models.wan_video_text_encoder import WanTextEncoder
+from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..schedulers.flow_match_pusa import FlowMatchSchedulerPusa
+from .base import BasePipeline
+from ..prompters import WanPrompter
+import torch, os
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import Optional
+
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
+from ..models.wan_video_pusa import RMSNorm, sinusoidal_embedding_1d
+from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+
+
+
+class WanVideoPusaPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchSchedulerPusa(shift=5, sigma_min=0.0, extra_one_step=True)
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
+ self.text_encoder: WanTextEncoder = None
+ self.image_encoder: WanImageEncoder = None
+ self.dit: WanModelPusa = None
+ self.vae: WanVideoVAE = None
+ self.motion_controller: WanMotionControllerModel = None
+ self.vace: VaceWanModel = None
+ self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
+ self.height_division_factor = 16
+ self.width_division_factor = 16
+ self.use_unified_sequence_parallel = False
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
+ dtype = next(iter(self.text_encoder.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5RelativeEmbedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ RMS_norm: AutoWrappedModule,
+ CausalConv3d: AutoWrappedModule,
+ Upsample: AutoWrappedModule,
+ torch.nn.SiLU: AutoWrappedModule,
+ torch.nn.Dropout: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.image_encoder is not None:
+ dtype = next(iter(self.image_encoder.parameters())).dtype
+ enable_vram_management(
+ self.image_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.motion_controller is not None:
+ dtype = next(iter(self.motion_controller.parameters())).dtype
+ enable_vram_management(
+ self.motion_controller,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+        if self.vace is not None:
+            dtype = next(iter(self.vace.parameters())).dtype
+            enable_vram_management(
+ self.vace,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
+ if text_encoder_model_and_path is not None:
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
+ self.dit = model_manager.fetch_model("wan_video_pusa")
+ self.vae = model_manager.fetch_model("wan_video_vae")
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+ self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
+ self.vace = model_manager.fetch_model("wan_video_vace")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = WanVideoPusaPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ if use_usp:
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+ for block in pipe.dit.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
+ pipe.sp_size = get_sequence_parallel_world_size()
+ pipe.use_unified_sequence_parallel = True
+ return pipe
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def encode_prompt(self, prompt, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
+ return {"context": prompt_emb}
+
+
+ def encode_image(self, image, end_image, num_frames, height, width):
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
+ clip_context = self.image_encoder.encode_image([image])
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+ msk[:, -1:] = 1
+ else:
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
+        # Example shapes at 720P: y is [16, 21, 90, 160] (VAE latents) and msk is [4, 21, 90, 160].
+
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
+ y = y.to(dtype=self.torch_dtype, device=self.device)
+ return {"clip_feature": clip_context, "y": y}
+
+ def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ control_video = self.preprocess_images(control_video)
+ control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ return latents
+
+
+ def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ if control_video is not None:
+ control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if clip_feature is None or y is None:
+ clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+ y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+ else:
+ y = y[:, -16:]
+ y = torch.concat([control_latents, y], dim=1)
+ return {"clip_feature": clip_feature, "y": y}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ model_dtype = next(iter(self.vae.parameters())).dtype
+ model_device = next(iter(self.vae.parameters())).device
+
+ # Convert latents to the correct dtype and device
+ latents = latents.to(dtype=model_dtype, device=model_device)
+
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled,
+ tile_size=tile_size, tile_stride=tile_stride)
+ return frames
+
+
+ def prepare_unified_sequence_parallel(self):
+ return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+ def prepare_motion_bucket_id(self, motion_bucket_id):
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+ return {"motion_bucket_id": motion_bucket_id}
+
+
+ def prepare_vace_kwargs(
+ self,
+ latents,
+ vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
+ height=480, width=832, num_frames=81,
+ seed=None, rand_device="cpu",
+ tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
+ ):
+ if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
+ self.load_models_to_device(["vae"])
+ if vace_video is None:
+ vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
+ else:
+ vace_video = self.preprocess_images(vace_video)
+ vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ if vace_mask is None:
+ vace_mask = torch.ones_like(vace_video)
+ else:
+ vace_mask = self.preprocess_images(vace_mask)
+ vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
+ reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
+ inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
+
+ vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
+ vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
+
+ if vace_reference_image is None:
+ pass
+ else:
+ vace_reference_image = self.preprocess_images([vace_reference_image])
+ vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
+ vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
+ vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
+
+ noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ latents = torch.concat((noise, latents), dim=2)
+
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
+ return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
+ else:
+ return latents, {"vace_context": None, "vace_scale": vace_scale}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_image=None,
+ end_image=None,
+ input_video=None,
+ control_video=None,
+ vace_video=None,
+ vace_video_mask=None,
+ vace_reference_image=None,
+ vace_scale=1.0,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device="cpu",
+ height=480,
+ width=832,
+ num_frames=81,
+ cfg_scale=5.0,
+ num_inference_steps=50,
+ sigma_shift=5.0,
+ motion_bucket_id=None,
+ tiled=True,
+ tile_size=(30, 52),
+ tile_stride=(15, 26),
+ tea_cache_l1_thresh=None,
+ tea_cache_model_id="",
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ visualize_attention=False,
+ output_dir=None,
+ ):
+ # Parameter check
+ height, width = self.check_resize_height_width(height, width)
+        if num_frames % 4 != 1:
+            num_frames = (num_frames + 2) // 4 * 4 + 1
+            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")
+
+ if visualize_attention:
+ import datetime
+ import os
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+
+ if output_dir:
+ vis_path = os.path.join(output_dir,"attention_maps")
+ else:
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ vis_path = os.path.join("attention_maps", timestamp)
+ os.makedirs(vis_path, exist_ok=True)
+
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = True
+ _VISUALIZE_ATTENTION_CONFIG["path"] = vis_path
+ print(f"Attention visualization enabled. Maps will be saved to {vis_path}")
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ # Initialize noise
+ noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ if input_video is not None:
+ self.load_models_to_device(['vae'])
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = noise
+
+ # Encode prompts
+ self.load_models_to_device(["text_encoder"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Encode image
+ if input_image is not None and self.image_encoder is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
+ else:
+ image_emb = {}
+
+ # ControlNet
+ if control_video is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+ # Motion Controller
+ if self.motion_controller is not None and motion_bucket_id is not None:
+ motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+ else:
+ motion_kwargs = {}
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # VACE
+ latents, vace_kwargs = self.prepare_vace_kwargs(
+ latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
+ height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
+ )
+
+ # TeaCache
+ tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+ tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+
+ # Unified Sequence Parallel
+ usp_kwargs = self.prepare_unified_sequence_parallel()
+
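+        # Pusa image conditioning: overwrite the first latent frame with the VAE-encoded input
+        # image (channels 4: of y hold the image latents; channels 0:4 are the frame mask).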
+ if input_image is not None:
+ latents[:,:,0:1] = image_emb["y"][:,4:,0:1]
+
+ # Denoise
+ self.load_models_to_device(["dit", "motion_controller", "vace"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ if visualize_attention:
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["step"] = progress_id
+
+            # Pusa uses a vectorized timestep: broadcast the scalar step to one value per latent frame.
+            timestep = timestep.unsqueeze(0).unsqueeze(1).repeat(1, latents.shape[2]).to(dtype=self.torch_dtype, device=self.device)
+            # Keep the first latent frame at t = 0 so it stays clean; it carries the image conditioning when input_image is given.
+            timestep[:,0] = 0
+
+ # Inference
+ noise_pred_posi = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_posi, **image_emb, **extra_input,
+ **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ if cfg_scale != 1.0:
+ noise_pred_nega = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_nega, **image_emb, **extra_input,
+ **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ latents = self.scheduler.step(noise_pred, timestep, latents)
+
+
+ if visualize_attention:
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ _VISUALIZE_ATTENTION_CONFIG["path"] = None
+ print("Attention visualization finished.")
+
+ if vace_reference_image is not None:
+ latents = latents[:, :, 1:]
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ frames = self.decode_video(latents, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
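+
+
+# A minimal usage sketch for the Pusa pipeline (illustrative only; checkpoint paths are
+# placeholders). Compared with WanVideoPipeline, it loads the "wan_video_pusa" DiT and uses
+# the per-frame timestep handling in __call__ above:
+#
+#     model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+#     model_manager.load_models(["path/to/pusa_dit", "path/to/wan_text_encoder", "path/to/wan_vae"])
+#     pipe = WanVideoPusaPipeline.from_model_manager(model_manager, device="cuda")
+#     frames = pipe(prompt="a sailboat gliding at sunset", num_frames=81, num_inference_steps=50, seed=0)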
+
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ self.coefficients_dict = {
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+ }
+ if model_id not in self.coefficients_dict:
+ supported_model_ids = ", ".join([i for i in self.coefficients_dict])
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
+ self.coefficients = self.coefficients_dict[model_id]
+
+ def check(self, dit: WanModelPusa, x, t_mod):
+ modulated_inp = t_mod.clone()
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = self.coefficients
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = x.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
+def model_fn_wan_video(
+ dit: WanModelPusa,
+ motion_controller: WanMotionControllerModel = None,
+ vace: VaceWanModel = None,
+ x: torch.Tensor = None,
+ timestep: torch.Tensor = None,
+ context: torch.Tensor = None,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ vace_context = None,
+ vace_scale = 1.0,
+ tea_cache: TeaCache = None,
+ use_unified_sequence_parallel: bool = False,
+ motion_bucket_id: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if use_unified_sequence_parallel:
+ import torch.distributed as dist
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+ t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
+ if motion_bucket_id is not None and motion_controller is not None:
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+ context = dit.text_embedding(context)
+
+
+ B, C, T, H, W = x.shape
+ pH, pW = H // dit.patch_size[1], W // dit.patch_size[2]
+
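+    # Broadcast the per-frame time embedding and its modulation over the spatial patch grid, so every
+    # token carries the timestep of the latent frame it belongs to (Pusa's vectorized timesteps).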
+ t = t.unsqueeze(2).unsqueeze(3).repeat(1, 1, pH, pW, 1)
+ t = rearrange(t, 'b f h w d -> b (f h w) d').contiguous()
+ t_mod = t_mod.unsqueeze(3).unsqueeze(4).repeat(1, 1, 1, pH, pW, 1)
+ t_mod = rearrange(t_mod, 'b f m h w d -> b m (f h w) d').contiguous()
+
+
+ model_dtype = next(iter(dit.parameters())).dtype
+ model_device = next(iter(dit.parameters())).device
+
+ # Convert inputs to the correct dtype and device
+ x = x.to(dtype=model_dtype, device=model_device)
+ t_mod = t_mod.to(dtype=model_dtype, device=model_device)
+ context = context.to(dtype=model_dtype, device=model_device)
+ if y is not None:
+ y = y.to(dtype=model_dtype, device=model_device)
+ if clip_feature is not None:
+ clip_feature = clip_feature.to(dtype=model_dtype, device=model_device)
+
+ x, (f, h, w) = dit.patchify(x)
+
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["grid_size"] = (f, h, w)
+
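+    # Per-axis positional frequency tables for frame, height and width, expanded over the grid and
+    # concatenated per token.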
+ freqs = torch.cat([
+ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
+ else:
+ tea_cache_update = False
+
+ if vace_context is not None:
+ vace_hints = vace(x, vace_context, context, t_mod, freqs)
+
+ # blocks
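+    # Under unified sequence parallelism each rank keeps only its chunk of the token sequence; the chunks
+    # are all-gathered after the output head. If TeaCache decided to skip this step, the cached residual
+    # is applied instead of running the transformer blocks.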
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+ if tea_cache_update:
+ x = tea_cache.update(x)
+ else:
+ for block_id, block in enumerate(dit.blocks):
+ x = block(x, context, t_mod, freqs)
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
+ x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
+ if tea_cache is not None:
+ tea_cache.store(x)
+
+ x = dit.head(x, t)
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+ x = dit.unpatchify(x, (f, h, w))
+ return x
diff --git a/PusaV1/diffsynth/pipelines/wan_video_pusa_multi_frames.py b/PusaV1/diffsynth/pipelines/wan_video_pusa_multi_frames.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bfc08da5bfeb0f24acd30073a8a8616d628661f
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/wan_video_pusa_multi_frames.py
@@ -0,0 +1,690 @@
+import types
+from ..models import ModelManager
+from ..models.wan_video_pusa import WanModelPusa
+from ..models.wan_video_text_encoder import WanTextEncoder
+from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..schedulers.flow_match_pusa_multi_frames import FlowMatchSchedulerPusaMultiFrames
+from .base import BasePipeline
+from ..prompters import WanPrompter
+import torch, os
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import Optional
+
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
+from ..models.wan_video_pusa import RMSNorm, sinusoidal_embedding_1d
+from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+
+
+
+class PusaMultiFramesPipeline(BasePipeline):
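+    """Pusa pipeline that conditions Wan video generation on multiple key frames.
+
+    Minimal usage sketch (checkpoint loading and parameter values are illustrative, not shipped defaults):
+
+        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+        model_manager.load_models([...])  # Wan DiT (Pusa), text encoder and VAE checkpoints
+        pipe = PusaMultiFramesPipeline.from_model_manager(model_manager, device="cuda")
+        frames = pipe(
+            prompt="a corgi running on the beach",
+            # keys are latent-frame indices, values are (PIL image, noise multiplier)
+            multi_frame_images={0: (Image.open("start.jpg"), 0.0), 20: (Image.open("end.jpg"), 0.3)},
+            num_frames=81, height=480, width=832, num_inference_steps=30,
+        )
+    """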
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchSchedulerPusaMultiFrames(shift=5, sigma_min=0.0, extra_one_step=True)
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
+ self.text_encoder: WanTextEncoder = None
+ self.image_encoder: WanImageEncoder = None
+ self.dit: WanModelPusa = None
+ self.vae: WanVideoVAE = None
+ self.motion_controller: WanMotionControllerModel = None
+ self.vace: VaceWanModel = None
+ self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
+ self.height_division_factor = 16
+ self.width_division_factor = 16
+ self.use_unified_sequence_parallel = False
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
+ dtype = next(iter(self.text_encoder.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5RelativeEmbedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ RMS_norm: AutoWrappedModule,
+ CausalConv3d: AutoWrappedModule,
+ Upsample: AutoWrappedModule,
+ torch.nn.SiLU: AutoWrappedModule,
+ torch.nn.Dropout: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.image_encoder is not None:
+ dtype = next(iter(self.image_encoder.parameters())).dtype
+ enable_vram_management(
+ self.image_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.motion_controller is not None:
+ dtype = next(iter(self.motion_controller.parameters())).dtype
+ enable_vram_management(
+ self.motion_controller,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.vace is not None:
+ enable_vram_management(
+ self.vace,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
+ if text_encoder_model_and_path is not None:
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
+ self.dit = model_manager.fetch_model("wan_video_pusa")
+ self.vae = model_manager.fetch_model("wan_video_vae")
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+ self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
+ self.vace = model_manager.fetch_model("wan_video_vace")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = PusaMultiFramesPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ if use_usp:
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+ for block in pipe.dit.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
+ pipe.sp_size = get_sequence_parallel_world_size()
+ pipe.use_unified_sequence_parallel = True
+ return pipe
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def encode_prompt(self, prompt, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
+ return {"context": prompt_emb}
+
+
+ def encode_single_image(self, image: Image.Image, height: int, width: int, tiled: bool, tile_size: tuple, tile_stride: tuple):
+ self.load_models_to_device(["vae"])
+ image = self.preprocess_image(image.resize((width, height), resample=Image.LANCZOS)).to(self.device)
+ image_tensor = image.unsqueeze(2)
+ image_tensor = image_tensor.to(dtype=self.torch_dtype, device=self.device)
+ latents = self.vae.encode(image_tensor, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def encode_image(self, image, end_image, num_frames, height, width):
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
+ clip_context = self.image_encoder.encode_image([image])
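+        # Conditioning mask: 1 marks pixel frames supplied as reference (the first frame, plus the last one
+        # when end_image is given), 0 marks frames to generate. The first frame's mask is repeated 4x so the
+        # mask can be folded to match the VAE's 4x temporal compression below.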
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+ msk[:, -1:] = 1
+ else:
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
+
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
+ y = y.to(dtype=self.torch_dtype, device=self.device)
+ return {"clip_feature": clip_context, "y": y}
+
+ def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ control_video = self.preprocess_images(control_video)
+ control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ return latents
+
+
+ def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ if control_video is not None:
+ control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if clip_feature is None or y is None:
+ clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+ y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+ else:
+ y = y[:, -16:]
+ y = torch.concat([control_latents, y], dim=1)
+ return {"clip_feature": clip_feature, "y": y}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ model_dtype = next(iter(self.vae.parameters())).dtype
+ model_device = next(iter(self.vae.parameters())).device
+
+ # Convert latents to the correct dtype and device
+ latents = latents.to(dtype=model_dtype, device=model_device)
+
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled,
+ tile_size=tile_size, tile_stride=tile_stride)
+ return frames
+
+
+ def prepare_unified_sequence_parallel(self):
+ return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+ def prepare_motion_bucket_id(self, motion_bucket_id):
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+ return {"motion_bucket_id": motion_bucket_id}
+
+
+ def prepare_vace_kwargs(
+ self,
+ latents,
+ vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
+ height=480, width=832, num_frames=81,
+ seed=None, rand_device="cpu",
+ tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
+ ):
+ if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
+ self.load_models_to_device(["vae"])
+ if vace_video is None:
+ vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
+ else:
+ vace_video = self.preprocess_images(vace_video)
+ vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ if vace_mask is None:
+ vace_mask = torch.ones_like(vace_video)
+ else:
+ vace_mask = self.preprocess_images(vace_mask)
+ vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
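+            # Split the VACE input into "inactive" (masked-out, preserved) and "reactive" (masked-in, to be
+            # edited) parts, encode both and stack them channel-wise; the mask itself is folded onto the
+            # latent grid (8x spatially via the rearrange, roughly 4x temporally via interpolation).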
+ inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
+ reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
+ inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
+
+ vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
+ vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
+
+ if vace_reference_image is None:
+ pass
+ else:
+ vace_reference_image = self.preprocess_images([vace_reference_image])
+ vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
+ vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
+ vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
+
+ noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ latents = torch.concat((noise, latents), dim=2)
+
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
+ return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
+ else:
+ return latents, {"vace_context": None, "vace_scale": vace_scale}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_image=None,
+ end_image=None,
+ multi_frame_images: Optional[dict] = None,
+ input_video=None,
+ control_video=None,
+ vace_video=None,
+ vace_video_mask=None,
+ vace_reference_image=None,
+ vace_scale=1.0,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device="cpu",
+ height=480,
+ width=832,
+ num_frames=81,
+ cfg_scale=5.0,
+ num_inference_steps=50,
+ sigma_shift=5.0,
+ motion_bucket_id=None,
+ tiled=True,
+ tile_size=(30, 52),
+ tile_stride=(15, 26),
+ tea_cache_l1_thresh=None,
+ tea_cache_model_id="",
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ visualize_attention=False,
+ output_dir=None,
+ ):
+ # Parameter check
+ height, width = self.check_resize_height_width(height, width)
+ if num_frames % 4 != 1:
+ num_frames = (num_frames + 2) // 4 * 4 + 1
+            print(f"`num_frames` must satisfy `num_frames % 4 == 1`. Rounding it up to {num_frames}.")
+
+ # import ipdb; ipdb.set_trace()
+ if visualize_attention:
+ import datetime
+ import os
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+
+ if output_dir:
+ vis_path = os.path.join(output_dir,"attention_maps")
+ else:
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ vis_path = os.path.join("attention_maps", timestamp)
+ os.makedirs(vis_path, exist_ok=True)
+
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = True
+ _VISUALIZE_ATTENTION_CONFIG["path"] = vis_path
+ print(f"Attention visualization enabled. Maps will be saved to {vis_path}")
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+ # self.scheduler.set_timesteps(1000, num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+ # Initialize noise
+ latent_size = (1, 16, (num_frames - 1) // 4 + 1, height//8, width//8)
+ noise = self.generate_noise(latent_size, seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+
+ if input_video is not None:
+ self.load_models_to_device(['vae'])
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = noise
+
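+        # multi_frame_images maps a latent-frame index to (PIL image, noise multiplier). Each referenced
+        # latent frame is replaced by the VAE latent of its image; the multiplier is later used to scale
+        # that frame's timestep (0 keeps the frame clean, 1 treats it like a fully noised frame).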
+ cond_frame_latent_indices = []
+ noise_multipliers = {}
+ if multi_frame_images is not None:
+ self.load_models_to_device(['vae'])
+ for frame_idx, image_info in multi_frame_images.items():
+ image, noise_mult = image_info
+ latent_idx = frame_idx
+ if 0 <= latent_idx < latents.shape[2]:
+ cond_frame_latent_indices.append(latent_idx)
+ noise_multipliers[latent_idx] = noise_mult
+ cond_latent = self.encode_single_image(image, height, width, **tiler_kwargs)
+ latents[:, :, latent_idx:latent_idx+1] = cond_latent.to(latents.device)
+
+
+ # Encode prompts
+ self.load_models_to_device(["text_encoder"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Encode image
+ if input_image is not None and self.image_encoder is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
+ else:
+ image_emb = {}
+
+ # ControlNet
+ if control_video is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+ # Motion Controller
+ if self.motion_controller is not None and motion_bucket_id is not None:
+ motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+ else:
+ motion_kwargs = {}
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # VACE
+ latents, vace_kwargs = self.prepare_vace_kwargs(
+ latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
+ height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
+ )
+
+ # TeaCache
+ tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+ tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+
+ # Unified Sequence Parallel
+ usp_kwargs = self.prepare_unified_sequence_parallel()
+
+ if input_image is not None:
+ latents[:,:,0:1] = image_emb["y"][:,4:,0:1]
+
+ # Denoise
+ self.load_models_to_device(["dit", "motion_controller", "vace"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ if visualize_attention:
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["step"] = progress_id
+ timestep = timestep.unsqueeze(0).unsqueeze(1).repeat(1, latents.shape[2]).to(dtype=self.torch_dtype, device=self.device)
+
+
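+            # Vectorized timesteps: conditioning frames have their timestep scaled by their noise multiplier,
+            # so they are treated as carrying proportionally less noise than the generated frames.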
+ if input_image is not None:
+ timestep[:,0] = 0
+ for latent_idx in cond_frame_latent_indices:
+ multiplier = noise_multipliers.get(latent_idx, 1.0)
+                    timestep[:,latent_idx] = timestep[:,latent_idx] * multiplier  # timestep = sigma * 1000, so this scales the frame's noise level
+                timestep = timestep.to(torch.long).to(dtype=self.torch_dtype, device=self.device)  # truncate to integer timesteps, then cast back to the model dtype
+
+ print("timestep", timestep[0])
+
+ # Inference
+ noise_pred_posi = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_posi, **image_emb, **extra_input,
+ **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ if cfg_scale != 1.0:
+ noise_pred_nega = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_nega, **image_emb, **extra_input,
+ **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ latents = self.scheduler.step(
+ noise_pred,
+ timestep,
+ latents,
+ cond_frame_latent_indices=cond_frame_latent_indices,
+ noise_multipliers=noise_multipliers
+ )
+
+
+ if visualize_attention:
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ _VISUALIZE_ATTENTION_CONFIG["path"] = None
+ print("Attention visualization finished.")
+
+ if vace_reference_image is not None:
+ latents = latents[:, :, 1:]
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ frames = self.decode_video(latents, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
+
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ self.coefficients_dict = {
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+ }
+        if model_id not in self.coefficients_dict:
+            supported_model_ids = ", ".join(self.coefficients_dict)
+            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose one of: {supported_model_ids}.")
+ self.coefficients = self.coefficients_dict[model_id]
+
+ def check(self, dit: WanModelPusa, x, t_mod):
+ modulated_inp = t_mod.clone()
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = self.coefficients
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = x.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
+def model_fn_wan_video(
+ dit: WanModelPusa,
+ motion_controller: WanMotionControllerModel = None,
+ vace: VaceWanModel = None,
+ x: torch.Tensor = None,
+ timestep: torch.Tensor = None,
+ context: torch.Tensor = None,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ vace_context = None,
+ vace_scale = 1.0,
+ tea_cache: TeaCache = None,
+ use_unified_sequence_parallel: bool = False,
+ motion_bucket_id: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if use_unified_sequence_parallel:
+ import torch.distributed as dist
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+ t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) #TODO
+ if motion_bucket_id is not None and motion_controller is not None:
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+ context = dit.text_embedding(context)
+
+
+ B, C, T, H, W = x.shape
+ pH, pW = H // dit.patch_size[1], W // dit.patch_size[2]
+
+
+ t = t.unsqueeze(2).unsqueeze(3).repeat(1, 1, pH, pW, 1)
+ t = rearrange(t, 'b f h w d -> b (f h w) d').contiguous()
+ t_mod = t_mod.unsqueeze(3).unsqueeze(4).repeat(1, 1, 1, pH, pW, 1)
+ t_mod = rearrange(t_mod, 'b f m h w d -> b m (f h w) d').contiguous()
+
+
+ model_dtype = next(iter(dit.parameters())).dtype
+ model_device = next(iter(dit.parameters())).device
+
+ # Convert inputs to the correct dtype and device
+ x = x.to(dtype=model_dtype, device=model_device)
+ t_mod = t_mod.to(dtype=model_dtype, device=model_device)
+ context = context.to(dtype=model_dtype, device=model_device)
+ if y is not None:
+ y = y.to(dtype=model_dtype, device=model_device)
+ if clip_feature is not None:
+ clip_feature = clip_feature.to(dtype=model_dtype, device=model_device)
+
+ x, (f, h, w) = dit.patchify(x)
+
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["grid_size"] = (f, h, w)
+
+ freqs = torch.cat([
+ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
+ else:
+ tea_cache_update = False
+
+ if vace_context is not None:
+ vace_hints = vace(x, vace_context, context, t_mod, freqs)
+
+ # blocks
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+ if tea_cache_update:
+ x = tea_cache.update(x)
+ else:
+ for block_id, block in enumerate(dit.blocks):
+ x = block(x, context, t_mod, freqs)
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
+ x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
+ if tea_cache is not None:
+ tea_cache.store(x)
+
+ x = dit.head(x, t)
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+ x = dit.unpatchify(x, (f, h, w))
+ return x
diff --git a/PusaV1/diffsynth/pipelines/wan_video_pusa_v2v.py b/PusaV1/diffsynth/pipelines/wan_video_pusa_v2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..64f97bc043bd15e885ab507a4fc6a1b9bf033063
--- /dev/null
+++ b/PusaV1/diffsynth/pipelines/wan_video_pusa_v2v.py
@@ -0,0 +1,690 @@
+import types
+from ..models import ModelManager
+from ..models.wan_video_pusa import WanModelPusa
+from ..models.wan_video_text_encoder import WanTextEncoder
+from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_image_encoder import WanImageEncoder
+from ..models.wan_video_vace import VaceWanModel
+from ..schedulers.flow_match_pusa_v2v import FlowMatchSchedulerPusaV2V
+from .base import BasePipeline
+from ..prompters import WanPrompter
+import torch, os
+from einops import rearrange
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from typing import Optional
+
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
+from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
+from ..models.wan_video_pusa import RMSNorm, sinusoidal_embedding_1d
+from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel
+
+
+
+class PusaV2VPipeline(BasePipeline):
+
+ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
+ super().__init__(device=device, torch_dtype=torch_dtype)
+ self.scheduler = FlowMatchSchedulerPusaV2V(shift=5, sigma_min=0.0, extra_one_step=True)
+ self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
+ self.text_encoder: WanTextEncoder = None
+ self.image_encoder: WanImageEncoder = None
+ self.dit: WanModelPusa = None
+ self.vae: WanVideoVAE = None
+ self.motion_controller: WanMotionControllerModel = None
+ self.vace: VaceWanModel = None
+ self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
+ self.height_division_factor = 16
+ self.width_division_factor = 16
+ self.use_unified_sequence_parallel = False
+
+
+ def enable_vram_management(self, num_persistent_param_in_dit=None):
+ dtype = next(iter(self.text_encoder.parameters())).dtype
+ enable_vram_management(
+ self.text_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Embedding: AutoWrappedModule,
+ T5RelativeEmbedding: AutoWrappedModule,
+ T5LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.dit.parameters())).dtype
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ max_num_param=num_persistent_param_in_dit,
+ overflow_module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ dtype = next(iter(self.vae.parameters())).dtype
+ enable_vram_management(
+ self.vae,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ RMS_norm: AutoWrappedModule,
+ CausalConv3d: AutoWrappedModule,
+ Upsample: AutoWrappedModule,
+ torch.nn.SiLU: AutoWrappedModule,
+ torch.nn.Dropout: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.image_encoder is not None:
+ dtype = next(iter(self.image_encoder.parameters())).dtype
+ enable_vram_management(
+ self.image_encoder,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv2d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.motion_controller is not None:
+ dtype = next(iter(self.motion_controller.parameters())).dtype
+ enable_vram_management(
+ self.motion_controller,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device="cpu",
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ )
+ if self.vace is not None:
+ enable_vram_management(
+ self.vace,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ torch.nn.Conv3d: AutoWrappedModule,
+ torch.nn.LayerNorm: AutoWrappedModule,
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=self.device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ )
+ self.enable_cpu_offload()
+
+
+ def fetch_models(self, model_manager: ModelManager):
+ text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
+ if text_encoder_model_and_path is not None:
+ self.text_encoder, tokenizer_path = text_encoder_model_and_path
+ self.prompter.fetch_models(self.text_encoder)
+ self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
+ self.dit = model_manager.fetch_model("wan_video_pusa")
+ self.vae = model_manager.fetch_model("wan_video_vae")
+ self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+ self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
+ self.vace = model_manager.fetch_model("wan_video_vace")
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
+ if device is None: device = model_manager.device
+ if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+ pipe = PusaV2VPipeline(device=device, torch_dtype=torch_dtype)
+ pipe.fetch_models(model_manager)
+ if use_usp:
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+
+ for block in pipe.dit.blocks:
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+ pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
+ pipe.sp_size = get_sequence_parallel_world_size()
+ pipe.use_unified_sequence_parallel = True
+ return pipe
+
+
+ def denoising_model(self):
+ return self.dit
+
+
+ def encode_prompt(self, prompt, positive=True):
+ prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
+ return {"context": prompt_emb}
+
+
+ def encode_single_image(self, image: Image.Image, height: int, width: int, tiled: bool, tile_size: tuple, tile_stride: tuple):
+ self.load_models_to_device(["vae"])
+ image = self.preprocess_image(image.resize((width, height), resample=Image.LANCZOS)).to(self.device)
+ image_tensor = image.unsqueeze(2)
+ image_tensor = image_tensor.to(dtype=self.torch_dtype, device=self.device)
+ latents = self.vae.encode(image_tensor, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def encode_image(self, image, end_image, num_frames, height, width):
+ image = self.preprocess_image(image.resize((width, height))).to(self.device)
+ clip_context = self.image_encoder.encode_image([image])
+ msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
+ msk[:, 1:] = 0
+ if end_image is not None:
+ end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+ vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+ msk[:, -1:] = 1
+ else:
+ vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+ msk = msk.transpose(1, 2)[0]
+
+ y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
+
+ y = torch.concat([msk, y])
+ y = y.unsqueeze(0)
+ clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
+ y = y.to(dtype=self.torch_dtype, device=self.device)
+ return {"clip_feature": clip_context, "y": y}
+
+ def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ control_video = self.preprocess_images(control_video)
+ control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ return latents
+
+
+ def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ if control_video is not None:
+ control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ if clip_feature is None or y is None:
+ clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+ y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+ else:
+ y = y[:, -16:]
+ y = torch.concat([control_latents, y], dim=1)
+ return {"clip_feature": clip_feature, "y": y}
+
+
+ def tensor2video(self, frames):
+ frames = rearrange(frames, "C T H W -> T H W C")
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+
+ def prepare_extra_input(self, latents=None):
+ return {}
+
+
+ def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+ return latents
+
+
+ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+ model_dtype = next(iter(self.vae.parameters())).dtype
+ model_device = next(iter(self.vae.parameters())).device
+
+ # Convert latents to the correct dtype and device
+ latents = latents.to(dtype=model_dtype, device=model_device)
+
+ frames = self.vae.decode(latents, device=self.device, tiled=tiled,
+ tile_size=tile_size, tile_stride=tile_stride)
+ return frames
+
+
+ def prepare_unified_sequence_parallel(self):
+ return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+ def prepare_motion_bucket_id(self, motion_bucket_id):
+ motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+ return {"motion_bucket_id": motion_bucket_id}
+
+
+ def prepare_vace_kwargs(
+ self,
+ latents,
+ vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
+ height=480, width=832, num_frames=81,
+ seed=None, rand_device="cpu",
+ tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
+ ):
+ if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
+ self.load_models_to_device(["vae"])
+ if vace_video is None:
+ vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
+ else:
+ vace_video = self.preprocess_images(vace_video)
+ vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ if vace_mask is None:
+ vace_mask = torch.ones_like(vace_video)
+ else:
+ vace_mask = self.preprocess_images(vace_mask)
+ vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
+
+ inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
+ reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
+ inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_video_latents = torch.concat((inactive, reactive), dim=1)
+
+ vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
+ vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
+
+ if vace_reference_image is None:
+ pass
+ else:
+ vace_reference_image = self.preprocess_images([vace_reference_image])
+ vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+ vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
+ vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
+ vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
+
+ noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+ latents = torch.concat((noise, latents), dim=2)
+
+ vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
+ return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
+ else:
+ return latents, {"vace_context": None, "vace_scale": vace_scale}
+
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt,
+ negative_prompt="",
+ input_image=None,
+ conditioning_video: Optional[list] = None,
+ conditioning_indices: Optional[list] = None,
+ conditioning_noise_multipliers: Optional[list] = None,
+ input_video=None,
+ control_video=None,
+ vace_video=None,
+ vace_video_mask=None,
+ vace_reference_image=None,
+ vace_scale=1.0,
+ denoising_strength=1.0,
+ seed=None,
+ rand_device="cpu",
+ height=480,
+ width=832,
+ num_frames=81,
+ cfg_scale=5.0,
+ num_inference_steps=50,
+ sigma_shift=5.0,
+ motion_bucket_id=None,
+ tiled=True,
+ tile_size=(30, 52),
+ tile_stride=(15, 26),
+ tea_cache_l1_thresh=None,
+ tea_cache_model_id="",
+ progress_bar_cmd=tqdm,
+ progress_bar_st=None,
+ visualize_attention=False,
+ output_dir=None,
+ ):
+ # Parameter check
+ height, width = self.check_resize_height_width(height, width)
+ if num_frames % 4 != 1:
+ num_frames = (num_frames + 2) // 4 * 4 + 1
+            print(f"`num_frames` must satisfy `num_frames % 4 == 1`. Rounding it up to {num_frames}.")
+
+ # import ipdb; ipdb.set_trace()
+ if visualize_attention:
+ import datetime
+ import os
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+
+ if output_dir:
+ vis_path = os.path.join(output_dir,"attention_maps")
+ else:
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ vis_path = os.path.join("attention_maps", timestamp)
+ os.makedirs(vis_path, exist_ok=True)
+
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = True
+ _VISUALIZE_ATTENTION_CONFIG["path"] = vis_path
+ print(f"Attention visualization enabled. Maps will be saved to {vis_path}")
+
+ # Tiler parameters
+ tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+
+ # Scheduler
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
+
+
+ # Initialize noise
+ latent_size = (1, 16, (num_frames - 1) // 4 + 1, height//8, width//8)
+ noise = self.generate_noise(latent_size, seed=seed, device=rand_device, dtype=torch.float32)
+ noise = noise.to(dtype=self.torch_dtype, device=self.device)
+
+ if input_video is not None:
+ self.load_models_to_device(['vae'])
+ input_video = self.preprocess_images(input_video)
+ input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+ else:
+ latents = noise
+
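+        # Video-to-video conditioning: encode the conditioning video once, copy the latent frames listed in
+        # conditioning_indices into the initial latents, and remember a per-frame noise multiplier that
+        # later scales those frames' timesteps in the denoising loop.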
+ cond_frame_latent_indices = []
+ noise_multipliers = {}
+ if conditioning_video is not None:
+ self.load_models_to_device(['vae'])
+ video_frames = self.preprocess_images(conditioning_video)
+ video_tensor = torch.stack(video_frames, dim=2).to(dtype=self.torch_dtype, device=self.device)
+ cond_latents = self.encode_video(video_tensor, **tiler_kwargs)
+
+ for i, frame_idx in enumerate(conditioning_indices):
+ latent_idx = frame_idx
+ cond_frame_latent_indices.append(latent_idx)
+ noise_multipliers[latent_idx] = conditioning_noise_multipliers[i]
+                latents[:, :, latent_idx:latent_idx+1] = cond_latents[:, :, latent_idx:latent_idx+1].to(latents.device)
+
+        # Encode prompts
+ self.load_models_to_device(["text_encoder"])
+ prompt_emb_posi = self.encode_prompt(prompt, positive=True)
+ if cfg_scale != 1.0:
+ prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+
+ # Encode image
+ if input_image is not None and self.image_encoder is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.encode_image(input_image, None, num_frames, height, width)
+ else:
+ image_emb = {}
+
+ # ControlNet
+ if control_video is not None:
+ self.load_models_to_device(["image_encoder", "vae"])
+ image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+        # Motion Controller
+ if self.motion_controller is not None and motion_bucket_id is not None:
+ motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+ else:
+ motion_kwargs = {}
+
+ # Extra input
+ extra_input = self.prepare_extra_input(latents)
+
+ # VACE
+ latents, vace_kwargs = self.prepare_vace_kwargs(
+ latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
+ height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
+ )
+
+ # TeaCache
+ tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+ tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+
+ # Unified Sequence Parallel
+ usp_kwargs = self.prepare_unified_sequence_parallel()
+
+ if input_image is not None:
+ latents[:,:,0:1] = image_emb["y"][:,4:,0:1]
+
+ # Denoise
+ self.load_models_to_device(["dit", "motion_controller", "vace"])
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+ if visualize_attention:
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["step"] = progress_id
+
+ timestep = timestep.unsqueeze(0).unsqueeze(1).repeat(1, latents.shape[2]).to(dtype=self.torch_dtype, device=self.device)
+
+ if input_image is not None:
+ timestep[:,0] = 0
+ for latent_idx in cond_frame_latent_indices:
+ multiplier = noise_multipliers.get(latent_idx, 1.0)
+                    timestep[:,latent_idx] = timestep[:,latent_idx] * multiplier  # timestep = sigma * 1000, so this scales the frame's noise level
+                timestep = timestep.to(torch.long).to(dtype=self.torch_dtype, device=self.device)  # truncate to integer timesteps, then cast back to the model dtype
+
+
+ print("timestep", timestep[0])
+
+ # Inference
+ noise_pred_posi = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_posi, **image_emb, **extra_input,
+ **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ if cfg_scale != 1.0:
+ noise_pred_nega = model_fn_wan_video(
+ self.dit, motion_controller=self.motion_controller, vace=self.vace,
+ x=latents, timestep=timestep,
+ **prompt_emb_nega, **image_emb, **extra_input,
+ **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
+ )
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+ else:
+ noise_pred = noise_pred_posi
+
+ # Scheduler
+ latents = self.scheduler.step(
+ noise_pred,
+ timestep,
+ latents,
+ cond_frame_latent_indices=cond_frame_latent_indices,
+ noise_multipliers=noise_multipliers
+ )
+
+
+ if visualize_attention:
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ _VISUALIZE_ATTENTION_CONFIG["enabled"] = False
+ _VISUALIZE_ATTENTION_CONFIG["path"] = None
+ print("Attention visualization finished.")
+
+ if vace_reference_image is not None:
+ latents = latents[:, :, 1:]
+
+ # Decode
+ self.load_models_to_device(['vae'])
+ frames = self.decode_video(latents, **tiler_kwargs)
+ self.load_models_to_device([])
+ frames = self.tensor2video(frames[0])
+
+ return frames
+
+
+
+class TeaCache:
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
+ self.num_inference_steps = num_inference_steps
+ self.step = 0
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = None
+ self.rel_l1_thresh = rel_l1_thresh
+ self.previous_residual = None
+ self.previous_hidden_states = None
+
+ self.coefficients_dict = {
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
+ }
+        if model_id not in self.coefficients_dict:
+            supported_model_ids = ", ".join(self.coefficients_dict)
+            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose one of: {supported_model_ids}.")
+ self.coefficients = self.coefficients_dict[model_id]
+
+ def check(self, dit: WanModelPusa, x, t_mod):
+ modulated_inp = t_mod.clone()
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ else:
+ coefficients = self.coefficients
+ rescale_func = np.poly1d(coefficients)
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+ should_calc = False
+ else:
+ should_calc = True
+ self.accumulated_rel_l1_distance = 0
+ self.previous_modulated_input = modulated_inp
+ self.step += 1
+ if self.step == self.num_inference_steps:
+ self.step = 0
+ if should_calc:
+ self.previous_hidden_states = x.clone()
+ return not should_calc
+
+ def store(self, hidden_states):
+ self.previous_residual = hidden_states - self.previous_hidden_states
+ self.previous_hidden_states = None
+
+ def update(self, hidden_states):
+ hidden_states = hidden_states + self.previous_residual
+ return hidden_states
+
+
+
+def model_fn_wan_video(
+ dit: WanModelPusa,
+ motion_controller: WanMotionControllerModel = None,
+ vace: VaceWanModel = None,
+ x: torch.Tensor = None,
+ timestep: torch.Tensor = None,
+ context: torch.Tensor = None,
+ clip_feature: Optional[torch.Tensor] = None,
+ y: Optional[torch.Tensor] = None,
+ vace_context = None,
+ vace_scale = 1.0,
+ tea_cache: TeaCache = None,
+ use_unified_sequence_parallel: bool = False,
+ motion_bucket_id: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if use_unified_sequence_parallel:
+ import torch.distributed as dist
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+
+ t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+ t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
+ if motion_bucket_id is not None and motion_controller is not None:
+ t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
+ context = dit.text_embedding(context)
+
+
+ B, C, T, H, W = x.shape
+ pH, pW = H // dit.patch_size[1], W // dit.patch_size[2]
+
+ t = t.unsqueeze(2).unsqueeze(3).repeat(1, 1, pH, pW, 1)
+ t = rearrange(t, 'b f h w d -> b (f h w) d').contiguous()
+ t_mod = t_mod.unsqueeze(3).unsqueeze(4).repeat(1, 1, 1, pH, pW, 1)
+ t_mod = rearrange(t_mod, 'b f m h w d -> b m (f h w) d').contiguous()
+
+
+ model_dtype = next(iter(dit.parameters())).dtype
+ model_device = next(iter(dit.parameters())).device
+
+ # Convert inputs to the correct dtype and device
+ x = x.to(dtype=model_dtype, device=model_device)
+ t_mod = t_mod.to(dtype=model_dtype, device=model_device)
+ context = context.to(dtype=model_dtype, device=model_device)
+ if y is not None:
+ y = y.to(dtype=model_dtype, device=model_device)
+ if clip_feature is not None:
+ clip_feature = clip_feature.to(dtype=model_dtype, device=model_device)
+
+ x, (f, h, w) = dit.patchify(x)
+
+ from ..models.wan_video_pusa import _VISUALIZE_ATTENTION_CONFIG
+ if _VISUALIZE_ATTENTION_CONFIG["enabled"]:
+ _VISUALIZE_ATTENTION_CONFIG["grid_size"] = (f, h, w)
+
+ freqs = torch.cat([
+ dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
+
+ # TeaCache
+ if tea_cache is not None:
+ tea_cache_update = tea_cache.check(dit, x, t_mod)
+ else:
+ tea_cache_update = False
+
+ if vace_context is not None:
+ vace_hints = vace(x, vace_context, context, t_mod, freqs)
+
+ # blocks
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+ if tea_cache_update:
+ x = tea_cache.update(x)
+ else:
+ for block_id, block in enumerate(dit.blocks):
+ x = block(x, context, t_mod, freqs)
+ if vace_context is not None and block_id in vace.vace_layers_mapping:
+ x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
+ if tea_cache is not None:
+ tea_cache.store(x)
+
+ x = dit.head(x, t)
+ if use_unified_sequence_parallel:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ x = get_sp_group().all_gather(x, dim=1)
+ x = dit.unpatchify(x, (f, h, w))
+ return x
diff --git a/PusaV1/diffsynth/processors/FastBlend.py b/PusaV1/diffsynth/processors/FastBlend.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed33f4fdd215c8c9dc46f3b07d9453a12cc6b98
--- /dev/null
+++ b/PusaV1/diffsynth/processors/FastBlend.py
@@ -0,0 +1,142 @@
+from PIL import Image
+import cupy as cp
+import numpy as np
+from tqdm import tqdm
+from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
+from ..extensions.FastBlend.runners.fast import TableManager
+from .base import VideoProcessor
+
+
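+# FastBlend deflickering: remaps neighboring stylized frames onto each target frame with pyramid patch matching and blends them, with fast, balanced, and accurate modes.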
+class FastBlendSmoother(VideoProcessor):
+ def __init__(
+ self,
+ inference_mode="fast", batch_size=8, window_size=60,
+ minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
+ ):
+ self.inference_mode = inference_mode
+ self.batch_size = batch_size
+ self.window_size = window_size
+ self.ebsynth_config = {
+ "minimum_patch_size": minimum_patch_size,
+ "threads_per_block": threads_per_block,
+ "num_iter": num_iter,
+ "gpu_id": gpu_id,
+ "guide_weight": guide_weight,
+ "initialize": initialize,
+ "tracking_window_size": tracking_window_size
+ }
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ # TODO: fetch GPU ID from model_manager
+ return FastBlendSmoother(**kwargs)
+
+ def inference_fast(self, frames_guide, frames_style):
+ table_manager = TableManager()
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **self.ebsynth_config
+ )
+ # left part
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
+ # right part
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
+ # merge
+ frames = []
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+ weight_m = -1
+ weight = weight_l + weight_m + weight_r
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+ frames.append(frame)
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
+ frames = [Image.fromarray(frame) for frame in frames]
+ return frames
+
+ def inference_balanced(self, frames_guide, frames_style):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ **self.ebsynth_config
+ )
+ output_frames = []
+ # tasks
+ n = len(frames_style)
+ tasks = []
+ for target in range(n):
+ for source in range(target - self.window_size, target + self.window_size + 1):
+ if source >= 0 and source < n and source != target:
+ tasks.append((source, target))
+ # run
+ frames = [(None, 1) for i in range(n)]
+ for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
+ tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ for (source, target), result in zip(tasks_batch, target_style):
+ frame, weight = frames[target]
+ if frame is None:
+ frame = frames_style[target]
+ frames[target] = (
+ frame * (weight / (weight + 1)) + result / (weight + 1),
+ weight + 1
+ )
+ if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
+ frame = frame.clip(0, 255).astype("uint8")
+ output_frames.append(Image.fromarray(frame))
+ frames[target] = (None, 1)
+ return output_frames
+
+ def inference_accurate(self, frames_guide, frames_style):
+ patch_match_engine = PyramidPatchMatcher(
+ image_height=frames_style[0].shape[0],
+ image_width=frames_style[0].shape[1],
+ channel=3,
+ use_mean_target_style=True,
+ **self.ebsynth_config
+ )
+ output_frames = []
+ # run
+ n = len(frames_style)
+ for target in tqdm(range(n), desc="Accurate Mode"):
+ l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
+ remapped_frames = []
+ for i in range(l, r, self.batch_size):
+ j = min(i + self.batch_size, r)
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+ target_guide = np.stack([frames_guide[target]] * (j - i))
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+ remapped_frames.append(target_style)
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+ frame = frame.clip(0, 255).astype("uint8")
+ output_frames.append(Image.fromarray(frame))
+ return output_frames
+
+ def release_vram(self):
+ mempool = cp.get_default_memory_pool()
+ pinned_mempool = cp.get_default_pinned_memory_pool()
+ mempool.free_all_blocks()
+ pinned_mempool.free_all_blocks()
+
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
+ rendered_frames = [np.array(frame) for frame in rendered_frames]
+ original_frames = [np.array(frame) for frame in original_frames]
+ if self.inference_mode == "fast":
+ output_frames = self.inference_fast(original_frames, rendered_frames)
+ elif self.inference_mode == "balanced":
+ output_frames = self.inference_balanced(original_frames, rendered_frames)
+ elif self.inference_mode == "accurate":
+ output_frames = self.inference_accurate(original_frames, rendered_frames)
+ else:
+ raise ValueError("inference_mode must be fast, balanced or accurate")
+ self.release_vram()
+ return output_frames
diff --git a/PusaV1/diffsynth/processors/PILEditor.py b/PusaV1/diffsynth/processors/PILEditor.py
new file mode 100644
index 0000000000000000000000000000000000000000..01011d8724f61283550d503c5c20ae6fd0375ec7
--- /dev/null
+++ b/PusaV1/diffsynth/processors/PILEditor.py
@@ -0,0 +1,28 @@
+from PIL import ImageEnhance
+from .base import VideoProcessor
+
+
+class ContrastEditor(VideoProcessor):
+ def __init__(self, rate=1.5):
+ self.rate = rate
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ return ContrastEditor(**kwargs)
+
+ def __call__(self, rendered_frames, **kwargs):
+ rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames]
+ return rendered_frames
+
+
+class SharpnessEditor(VideoProcessor):
+ def __init__(self, rate=1.5):
+ self.rate = rate
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ return SharpnessEditor(**kwargs)
+
+ def __call__(self, rendered_frames, **kwargs):
+ rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames]
+ return rendered_frames
diff --git a/PusaV1/diffsynth/processors/RIFE.py b/PusaV1/diffsynth/processors/RIFE.py
new file mode 100644
index 0000000000000000000000000000000000000000..4186eb31496e9a1bf38df06eb64921226f07ee09
--- /dev/null
+++ b/PusaV1/diffsynth/processors/RIFE.py
@@ -0,0 +1,77 @@
+import torch
+import numpy as np
+from PIL import Image
+from .base import VideoProcessor
+
+
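+# RIFE-based temporal smoother: interpolates each middle frame from its two neighbors and blends the result with the original frame to reduce flicker.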
+class RIFESmoother(VideoProcessor):
+ def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True):
+ self.model = model
+ self.device = device
+
+        # IFNet does not support float16, so keep computation in float32
+ self.torch_dtype = torch.float32
+
+ # Other parameters
+ self.scale = scale
+ self.batch_size = batch_size
+ self.interpolate = interpolate
+
+ @staticmethod
+ def from_model_manager(model_manager, **kwargs):
+ return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs)
+
+ def process_image(self, image):
+ width, height = image.size
+ if width % 32 != 0 or height % 32 != 0:
+            width = (width + 31) // 32 * 32
+            height = (height + 31) // 32 * 32
+ image = image.resize((width, height))
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+ return image
+
+ def process_images(self, images):
+ images = [self.process_image(image) for image in images]
+ images = torch.stack(images)
+ return images
+
+ def decode_images(self, images):
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+ images = [Image.fromarray(image) for image in images]
+ return images
+
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+ output_tensor = []
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+ output_tensor.append(merged[2].cpu())
+ output_tensor = torch.concat(output_tensor, dim=0)
+ return output_tensor
+
+ @torch.no_grad()
+ def __call__(self, rendered_frames, **kwargs):
+ # Preprocess
+ processed_images = self.process_images(rendered_frames)
+
+ # Input
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+ # Interpolate
+ output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+
+ if self.interpolate:
+ # Blend
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+ output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+ processed_images[1:-1] = output_tensor
+ else:
+ processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2
+
+ # To images
+ output_images = self.decode_images(processed_images)
+ if output_images[0].size != rendered_frames[0].size:
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+ return output_images
diff --git a/PusaV1/diffsynth/processors/__init__.py b/PusaV1/diffsynth/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PusaV1/diffsynth/processors/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/processors/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ac52a63d1533c4bb46a7b27fd356a7f3f8d6879
Binary files /dev/null and b/PusaV1/diffsynth/processors/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/processors/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/processors/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0910deb9c948f78470746e07891d838fd0f347f3
Binary files /dev/null and b/PusaV1/diffsynth/processors/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/processors/__pycache__/base.cpython-310.pyc b/PusaV1/diffsynth/processors/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f38bece072d97a4ec613f4fb0dd78cc1781b17b
Binary files /dev/null and b/PusaV1/diffsynth/processors/__pycache__/base.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/processors/__pycache__/base.cpython-312.pyc b/PusaV1/diffsynth/processors/__pycache__/base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce4d922d11f9e5a29c3097657c81d2acb0b8b5df
Binary files /dev/null and b/PusaV1/diffsynth/processors/__pycache__/base.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/processors/__pycache__/sequencial_processor.cpython-310.pyc b/PusaV1/diffsynth/processors/__pycache__/sequencial_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7af9d5db78e93437398bb6fb75eb200a8b48203f
Binary files /dev/null and b/PusaV1/diffsynth/processors/__pycache__/sequencial_processor.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/processors/__pycache__/sequencial_processor.cpython-312.pyc b/PusaV1/diffsynth/processors/__pycache__/sequencial_processor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8fa2af4d9a5615e0b0d9c721855b5745bf9ab0d4
Binary files /dev/null and b/PusaV1/diffsynth/processors/__pycache__/sequencial_processor.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/processors/base.py b/PusaV1/diffsynth/processors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..278a9c1b74044987cc116de35292a96de8b13737
--- /dev/null
+++ b/PusaV1/diffsynth/processors/base.py
@@ -0,0 +1,6 @@
+class VideoProcessor:
+ def __init__(self):
+ pass
+
+ def __call__(self):
+ raise NotImplementedError
diff --git a/PusaV1/diffsynth/processors/sequencial_processor.py b/PusaV1/diffsynth/processors/sequencial_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5bc9454f0b9d74f10bb4a6bff92db77f26325c
--- /dev/null
+++ b/PusaV1/diffsynth/processors/sequencial_processor.py
@@ -0,0 +1,41 @@
+from .base import VideoProcessor
+
+
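+# Factory that maps a processor_type string ("FastBlend", "Contrast", "Sharpness", "RIFE") to the corresponding VideoProcessor.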
+class AutoVideoProcessor(VideoProcessor):
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def from_model_manager(model_manager, processor_type, **kwargs):
+ if processor_type == "FastBlend":
+ from .FastBlend import FastBlendSmoother
+ return FastBlendSmoother.from_model_manager(model_manager, **kwargs)
+ elif processor_type == "Contrast":
+ from .PILEditor import ContrastEditor
+ return ContrastEditor.from_model_manager(model_manager, **kwargs)
+ elif processor_type == "Sharpness":
+ from .PILEditor import SharpnessEditor
+ return SharpnessEditor.from_model_manager(model_manager, **kwargs)
+ elif processor_type == "RIFE":
+ from .RIFE import RIFESmoother
+ return RIFESmoother.from_model_manager(model_manager, **kwargs)
+ else:
+ raise ValueError(f"invalid processor_type: {processor_type}")
+
+
+class SequencialProcessor(VideoProcessor):
+ def __init__(self, processors=[]):
+ self.processors = processors
+
+ @staticmethod
+ def from_model_manager(model_manager, configs):
+ processors = [
+ AutoVideoProcessor.from_model_manager(model_manager, config["processor_type"], **config["config"])
+ for config in configs
+ ]
+ return SequencialProcessor(processors)
+
+ def __call__(self, rendered_frames, **kwargs):
+ for processor in self.processors:
+ rendered_frames = processor(rendered_frames, **kwargs)
+ return rendered_frames
diff --git a/PusaV1/diffsynth/prompters/__init__.py b/PusaV1/diffsynth/prompters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f27c6f153b076de484c2b650e8bf16d7142d1099
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/__init__.py
@@ -0,0 +1,12 @@
+from .prompt_refiners import Translator, BeautifulPrompt, QwenPrompt
+from .sd_prompter import SDPrompter
+from .sdxl_prompter import SDXLPrompter
+from .sd3_prompter import SD3Prompter
+from .hunyuan_dit_prompter import HunyuanDiTPrompter
+from .kolors_prompter import KolorsPrompter
+from .flux_prompter import FluxPrompter
+from .omost import OmostPromter
+from .cog_prompter import CogPrompter
+from .hunyuan_video_prompter import HunyuanVideoPrompter
+from .stepvideo_prompter import StepVideoPrompter
+from .wan_prompter import WanPrompter
diff --git a/PusaV1/diffsynth/prompters/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4812d155fdc1486c424c01da5e9465f50422b0af
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8469a2dc2df325715f5cfa6e8e1a6c8d5e296f7
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/base_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/base_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08273f4039034f5a834daf895cd1565cd80253fc
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/base_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/base_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/base_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecc43124553ef6372ac342673ccedbda6444fde6
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/base_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/cog_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/cog_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb949d13796ae5d48f3c9cab60238e284d04d6c4
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/cog_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/cog_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/cog_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ce94f0abbd4599ffd20084bd859774d88e693ee
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/cog_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/flux_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/flux_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f6bbff9d5838c9cd115d185074274a94b21c9a5
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/flux_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/flux_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/flux_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b4b32140b6276ee290d048983c893f92c2a423c
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/flux_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..903dedef3c7f595585a40f10189e4dab8b3e77d9
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4fd59a3a43d080e942f51017a68a4ca6c0e7ea7
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_dit_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfc02664f1101606d6ff16c530b2a588a832fbbf
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4903ef06301bdce4e13f1ff00bf4a74b6bb90e78
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/hunyuan_video_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/kolors_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/kolors_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36b58714d417d43f6976cac1e347b1ccc0666c3b
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/kolors_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/kolors_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/kolors_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb17d20c8bdd9e1d3ffe4527f97c5e96462c9eda
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/kolors_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e733d3f822303724b0df2fbf5917677d43fcc8b0
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..187c8f8f1f6dd580183be6347a0cef2c0c6d0472
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/omnigen_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/omost.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/omost.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7cb8f8b208ae7999e4fea1e98ced0b7167460ed8
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/omost.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/omost.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/omost.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38b880f939a319eb3ecf266fbc319dd8777dc24c
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/omost.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/prompt_refiners.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/prompt_refiners.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8609b053ab5bc6d41489c49eb739a1d25841275e
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/prompt_refiners.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/prompt_refiners.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/prompt_refiners.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58e14bb9226464163b9f93cf807cadfc3e40a4af
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/prompt_refiners.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/sd3_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/sd3_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e97a02046ccac18ced93331587a94a318daf555c
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/sd3_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/sd3_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/sd3_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca8d3a9c0e21e857de32d0d84a38033f65419496
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/sd3_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/sd_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/sd_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d8d190b9b2d2d59149184aaaf196406609508d7
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/sd_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/sd_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/sd_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..acb5f5b278d02850600fb69edac87d727c26b644
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/sd_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78f646c80491c5ccc6793e58dd8359d3d1614ffd
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5c013520a746fc44f4037ff5dca2f90411781bf
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/sdxl_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20a8f23cc944f18edc8ebb4f1c5befc682441955
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5c0c0c5db91726d20243deabb571f4bc21b320c
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/stepvideo_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/wan_prompter.cpython-310.pyc b/PusaV1/diffsynth/prompters/__pycache__/wan_prompter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e089efc588b67f436aacfaeb48cdde47e1a2f8b
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/wan_prompter.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/prompters/__pycache__/wan_prompter.cpython-312.pyc b/PusaV1/diffsynth/prompters/__pycache__/wan_prompter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f74808369d54a5bc18162500668f4fd5d4f325d
Binary files /dev/null and b/PusaV1/diffsynth/prompters/__pycache__/wan_prompter.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/prompters/base_prompter.py b/PusaV1/diffsynth/prompters/base_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..136abd18fabdb04e618f59801420c9ce5fb94634
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/base_prompter.py
@@ -0,0 +1,70 @@
+from ..models.model_manager import ModelManager
+import torch
+
+
+
+def tokenize_long_prompt(tokenizer, prompt, max_length=None):
+    # Get model_max_length from the tokenizer.
+ length = tokenizer.model_max_length if max_length is None else max_length
+
+    # To avoid the length warning, temporarily set tokenizer.model_max_length to a very large value.
+ tokenizer.model_max_length = 99999999
+
+ # Tokenize it!
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+ # Determine the real length.
+ max_length = (input_ids.shape[1] + length - 1) // length * length
+
+ # Restore tokenizer.model_max_length
+ tokenizer.model_max_length = length
+
+ # Tokenize it again with fixed length.
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True
+ ).input_ids
+
+ # Reshape input_ids to fit the text encoder.
+ num_sentence = input_ids.shape[1] // length
+ input_ids = input_ids.reshape((num_sentence, length))
+
+ return input_ids
+
+
+
+class BasePrompter:
+ def __init__(self):
+ self.refiners = []
+ self.extenders = []
+
+
+ def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
+ for refiner_class in refiner_classes:
+ refiner = refiner_class.from_model_manager(model_manager)
+ self.refiners.append(refiner)
+
+    def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=[]):
+ for extender_class in extender_classes:
+ extender = extender_class.from_model_manager(model_manager)
+ self.extenders.append(extender)
+
+
+ @torch.no_grad()
+ def process_prompt(self, prompt, positive=True):
+ if isinstance(prompt, list):
+ prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
+ else:
+ for refiner in self.refiners:
+ prompt = refiner(prompt, positive=positive)
+ return prompt
+
+ @torch.no_grad()
+ def extend_prompt(self, prompt:str, positive=True):
+ extended_prompt = dict(prompt=prompt)
+ for extender in self.extenders:
+ extended_prompt = extender(extended_prompt)
+ return extended_prompt
\ No newline at end of file
diff --git a/PusaV1/diffsynth/prompters/cog_prompter.py b/PusaV1/diffsynth/prompters/cog_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ab84a69c32e681e087ba7ed0642a6177fe1f7a
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/cog_prompter.py
@@ -0,0 +1,46 @@
+from .base_prompter import BasePrompter
+from ..models.flux_text_encoder import FluxTextEncoder2
+from transformers import T5TokenizerFast
+import os
+
+
+class CogPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/cog/tokenizer")
+ super().__init__()
+ self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path)
+ self.text_encoder: FluxTextEncoder2 = None
+
+
+ def fetch_models(self, text_encoder: FluxTextEncoder2 = None):
+ self.text_encoder = text_encoder
+
+
+ def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ ).input_ids.to(device)
+ prompt_emb = text_encoder(input_ids)
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+
+ return prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+ prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder, self.tokenizer, 226, device)
+ return prompt_emb
diff --git a/PusaV1/diffsynth/prompters/flux_prompter.py b/PusaV1/diffsynth/prompters/flux_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a06ff8df29345f505873cf1b79c963229f3efb
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/flux_prompter.py
@@ -0,0 +1,74 @@
+from .base_prompter import BasePrompter
+from ..models.flux_text_encoder import FluxTextEncoder2
+from ..models.sd3_text_encoder import SD3TextEncoder1
+from transformers import CLIPTokenizer, T5TokenizerFast
+import os, torch
+
+
+class FluxPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ tokenizer_2_path=None
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_1")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/flux/tokenizer_2")
+ super().__init__()
+ self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+ self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: FluxTextEncoder2 = None
+
+
+ def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+
+
+ def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True
+ ).input_ids.to(device)
+ pooled_prompt_emb, _ = text_encoder(input_ids)
+ return pooled_prompt_emb
+
+
+ def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ ).input_ids.to(device)
+ prompt_emb = text_encoder(input_ids)
+ return prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ positive=True,
+ device="cuda",
+ t5_sequence_length=512,
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # CLIP
+ pooled_prompt_emb = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
+
+ # T5
+ prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder_2, self.tokenizer_2, t5_sequence_length, device)
+
+ # text_ids
+ text_ids = torch.zeros(prompt_emb.shape[0], prompt_emb.shape[1], 3).to(device=device, dtype=prompt_emb.dtype)
+
+ return prompt_emb, pooled_prompt_emb, text_ids
diff --git a/PusaV1/diffsynth/prompters/hunyuan_dit_prompter.py b/PusaV1/diffsynth/prompters/hunyuan_dit_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..52a22ed72ab77ef668183119fff67db3141ee561
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/hunyuan_dit_prompter.py
@@ -0,0 +1,69 @@
+from .base_prompter import BasePrompter
+from ..models.model_manager import ModelManager
+from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
+from transformers import BertTokenizer, AutoTokenizer
+import warnings, os
+
+
+class HunyuanDiTPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None,
+ tokenizer_t5_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
+ if tokenizer_t5_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_t5_path = os.path.join(base_path, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
+ super().__init__()
+ self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
+ self.text_encoder: HunyuanDiTCLIPTextEncoder = None
+ self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
+
+
+ def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
+ self.text_encoder = text_encoder
+ self.text_encoder_t5 = text_encoder_t5
+
+
+ def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ clip_skip=clip_skip
+ )
+ return prompt_embeds, attention_mask
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ clip_skip=1,
+ clip_skip_2=1,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # CLIP
+ prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
+
+ # T5
+ prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
+
+ return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5
diff --git a/PusaV1/diffsynth/prompters/hunyuan_video_prompter.py b/PusaV1/diffsynth/prompters/hunyuan_video_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b97356cacd4b9ccd9d0912b5694e1c1b4868ae9
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/hunyuan_video_prompter.py
@@ -0,0 +1,275 @@
+from .base_prompter import BasePrompter
+from ..models.sd3_text_encoder import SD3TextEncoder1
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
+from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
+import os, torch
+from typing import Union
+
+PROMPT_TEMPLATE_ENCODE = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
+
+PROMPT_TEMPLATE_ENCODE_VIDEO = (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
+
+PROMPT_TEMPLATE_ENCODE_I2V = (
+ "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
+ "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
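+# crop_start is the number of template tokens to strip from the encoder output so that only the user prompt (and, for i2v, the image embeddings) is kept.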
+PROMPT_TEMPLATE = {
+ "dit-llm-encode": {
+ "template": PROMPT_TEMPLATE_ENCODE,
+ "crop_start": 36,
+ },
+ "dit-llm-encode-video": {
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
+ "crop_start": 95,
+ },
+ "dit-llm-encode-i2v": {
+ "template": PROMPT_TEMPLATE_ENCODE_I2V,
+ "crop_start": 36,
+ "image_emb_start": 5,
+ "image_emb_end": 581,
+ "image_emb_len": 576,
+ "double_return_token_id": 271
+ },
+ "dit-llm-encode-video-i2v": {
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
+ "crop_start": 103,
+ "image_emb_start": 5,
+ "image_emb_end": 581,
+ "image_emb_len": 576,
+ "double_return_token_id": 271
+ },
+}
+
+NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
+
+
+class HunyuanVideoPrompter(BasePrompter):
+
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ tokenizer_2_path=None,
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(
+ base_path, "tokenizer_configs/hunyuan_video/tokenizer_1")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(
+ base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
+ super().__init__()
+ self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+ self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: HunyuanVideoLLMEncoder = None
+
+ self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
+ self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
+
+ def fetch_models(self,
+ text_encoder_1: SD3TextEncoder1 = None,
+ text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+ if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
+ # processor
+ # TODO: may need to replace processor with local implementation
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
+ self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
+ # template
+ self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
+ self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
+
+ def apply_text_to_template(self, text, template):
+ assert isinstance(template, str)
+ if isinstance(text, list):
+            return [self.apply_text_to_template(text_, template) for text_ in text]
+ elif isinstance(text, str):
+ # Will send string to tokenizer. Used for llm
+ return template.format(text)
+ else:
+ raise TypeError(f"Unsupported prompt type: {type(text)}")
+
+ def encode_prompt_using_clip(self, prompt, max_length, device):
+ tokenized_result = self.tokenizer_1(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True
+ )
+ input_ids = tokenized_result.input_ids.to(device)
+ attention_mask = tokenized_result.attention_mask.to(device)
+ return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
+
+ def encode_prompt_using_llm(self,
+ prompt,
+ max_length,
+ device,
+ crop_start,
+ hidden_state_skip_layer=2,
+ use_attention_mask=True):
+ max_length += crop_start
+ inputs = self.tokenizer_2(prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True)
+ input_ids = inputs.input_ids.to(device)
+ attention_mask = inputs.attention_mask.to(device)
+ last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
+
+ # crop out
+ if crop_start > 0:
+ last_hidden_state = last_hidden_state[:, crop_start:]
+ attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
+
+ return last_hidden_state, attention_mask
+
+ def encode_prompt_using_mllm(self,
+ prompt,
+ images,
+ max_length,
+ device,
+ crop_start,
+ hidden_state_skip_layer=2,
+ use_attention_mask=True,
+ image_embed_interleave=4):
+ image_outputs = self.processor(images, return_tensors="pt")["pixel_values"].to(device)
+ max_length += crop_start
+ inputs = self.tokenizer_2(prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True)
+ input_ids = inputs.input_ids.to(device)
+ attention_mask = inputs.attention_mask.to(device)
+ last_hidden_state = self.text_encoder_2(input_ids=input_ids,
+ attention_mask=attention_mask,
+ hidden_state_skip_layer=hidden_state_skip_layer,
+ pixel_values=image_outputs)
+
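+        # Split the LLM hidden states into image-embedding and text spans using the template offsets, locating the assistant header via the double-return token so it can be dropped.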
+ text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
+ image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
+ image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
+ batch_indices, last_double_return_token_indices = torch.where(
+ input_ids == self.prompt_template_video.get("double_return_token_id", 271))
+ if last_double_return_token_indices.shape[0] == 3:
+ # in case the prompt is too long
+ last_double_return_token_indices = torch.cat((
+ last_double_return_token_indices,
+ torch.tensor([input_ids.shape[-1]]),
+ ))
+ batch_indices = torch.cat((batch_indices, torch.tensor([0])))
+ last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
+ batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
+ assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
+ assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
+ attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
+ attention_mask_assistant_crop_end = last_double_return_token_indices
+ text_last_hidden_state = []
+ text_attention_mask = []
+ image_last_hidden_state = []
+ image_attention_mask = []
+ for i in range(input_ids.shape[0]):
+ text_last_hidden_state.append(
+ torch.cat([
+ last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
+ last_hidden_state[i, assistant_crop_end[i].item():],
+ ]))
+ text_attention_mask.append(
+ torch.cat([
+ attention_mask[
+ i,
+ crop_start:attention_mask_assistant_crop_start[i].item(),
+ ],
+ attention_mask[i, attention_mask_assistant_crop_end[i].item():],
+ ]) if use_attention_mask else None)
+ image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
+ image_attention_mask.append(
+ torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
+ to(attention_mask.dtype) if use_attention_mask else None)
+
+ text_last_hidden_state = torch.stack(text_last_hidden_state)
+ text_attention_mask = torch.stack(text_attention_mask)
+ image_last_hidden_state = torch.stack(image_last_hidden_state)
+ image_attention_mask = torch.stack(image_attention_mask)
+
+ image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
+ image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
+
+ assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
+ image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
+
+ last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
+ attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
+
+ return last_hidden_state, attention_mask
+
+ def encode_prompt(self,
+ prompt,
+ images=None,
+ positive=True,
+ device="cuda",
+ clip_sequence_length=77,
+ llm_sequence_length=256,
+ data_type='video',
+ use_template=True,
+ hidden_state_skip_layer=2,
+ use_attention_mask=True,
+ image_embed_interleave=4):
+
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # apply template
+ if use_template:
+ template = self.prompt_template_video if data_type == 'video' else self.prompt_template
+ prompt_formated = self.apply_text_to_template(prompt, template['template'])
+ else:
+ prompt_formated = prompt
+ # Text encoder
+ if data_type == 'video':
+ crop_start = self.prompt_template_video.get("crop_start", 0)
+ else:
+ crop_start = self.prompt_template.get("crop_start", 0)
+
+ # CLIP
+ pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
+
+ # LLM
+ if images is None:
+ prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
+ hidden_state_skip_layer, use_attention_mask)
+ else:
+ prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
+ crop_start, hidden_state_skip_layer, use_attention_mask,
+ image_embed_interleave)
+
+ return prompt_emb, pooled_prompt_emb, attention_mask
diff --git a/PusaV1/diffsynth/prompters/kolors_prompter.py b/PusaV1/diffsynth/prompters/kolors_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3d5d58a9dbb816ea8c8e0e3b4f0433bd11d3306
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/kolors_prompter.py
@@ -0,0 +1,354 @@
+from .base_prompter import BasePrompter
+from ..models.model_manager import ModelManager
+import json, os, re
+from typing import List, Optional, Union, Dict
+from sentencepiece import SentencePieceProcessor
+from transformers import PreTrainedTokenizer
+from transformers.utils import PaddingStrategy
+from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
+from ..models.kolors_text_encoder import ChatGLMModel
+
+
+class SPTokenizer:
+ def __init__(self, model_path: str):
+ # reload tokenizer
+ assert os.path.isfile(model_path), model_path
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+ # BOS / EOS token IDs
+ self.n_words: int = self.sp_model.vocab_size()
+ self.bos_id: int = self.sp_model.bos_id()
+ self.eos_id: int = self.sp_model.eos_id()
+ self.pad_id: int = self.sp_model.unk_id()
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+ role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
+ special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
+ self.special_tokens = {}
+ self.index_special_tokens = {}
+ for token in special_tokens:
+ self.special_tokens[token] = self.n_words
+ self.index_special_tokens[self.n_words] = token
+ self.n_words += 1
+ self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
+
+ def tokenize(self, s: str, encode_special_tokens=False):
+ if encode_special_tokens:
+ last_index = 0
+ t = []
+ for match in re.finditer(self.role_special_token_expression, s):
+ if last_index < match.start():
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
+ t.append(s[match.start():match.end()])
+ last_index = match.end()
+ if last_index < len(s):
+ t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
+ return t
+ else:
+ return self.sp_model.EncodeAsPieces(s)
+
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+ assert type(s) is str
+ t = self.sp_model.encode(s)
+ if bos:
+ t = [self.bos_id] + t
+ if eos:
+ t = t + [self.eos_id]
+ return t
+
+ def decode(self, t: List[int]) -> str:
+ text, buffer = "", []
+ for token in t:
+ if token in self.index_special_tokens:
+ if buffer:
+ text += self.sp_model.decode(buffer)
+ buffer = []
+ text += self.index_special_tokens[token]
+ else:
+ buffer.append(token)
+ if buffer:
+ text += self.sp_model.decode(buffer)
+ return text
+
+ def decode_tokens(self, tokens: List[str]) -> str:
+ text = self.sp_model.DecodePieces(tokens)
+ return text
+
+ def convert_token_to_id(self, token):
+        """ Converts a token (str) into an id using the vocab. """
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ return self.sp_model.PieceToId(token)
+
+ def convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+ if index in self.index_special_tokens:
+ return self.index_special_tokens[index]
+ if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+ return ""
+ return self.sp_model.IdToPiece(index)
+
+
+
+class ChatGLMTokenizer(PreTrainedTokenizer):
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
+
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
+ **kwargs):
+ self.name = "GLMTokenizer"
+
+ self.vocab_file = vocab_file
+ self.tokenizer = SPTokenizer(vocab_file)
+ self.special_tokens = {
+            "<bos>": self.tokenizer.bos_id,
+            "<eos>": self.tokenizer.eos_id,
+            "<pad>": self.tokenizer.pad_id
+ }
+ self.encode_special_tokens = encode_special_tokens
+ super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ encode_special_tokens=encode_special_tokens,
+ **kwargs)
+
+ def get_command(self, token):
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
+ return self.tokenizer.special_tokens[token]
+
+ @property
+ def unk_token(self) -> str:
+        return "<unk>"
+
+ @property
+ def pad_token(self) -> str:
+        return "<unk>"
+
+ @property
+ def pad_token_id(self):
+        return self.get_command("<pad>")
+
+ @property
+ def eos_token(self) -> str:
+        return "</s>"
+
+ @property
+ def eos_token_id(self):
+        return self.get_command("<eos>")
+
+ @property
+ def vocab_size(self):
+ return self.tokenizer.n_words
+
+ def get_vocab(self):
+ """ Returns vocab as a dict """
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text, **kwargs):
+ return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
+
+ def _convert_token_to_id(self, token):
+        """ Converts a token (str) into an id using the vocab. """
+ return self.tokenizer.convert_token_to_id(token)
+
+ def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+ return self.tokenizer.convert_id_to_token(index)
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ return self.tokenizer.decode_tokens(tokens)
+
+ def save_vocabulary(self, save_directory, filename_prefix=None):
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+                An optional prefix to add to the names of the saved files.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, self.vocab_files_names["vocab_file"]
+ )
+ else:
+ vocab_file = save_directory
+
+ with open(self.vocab_file, 'rb') as fin:
+ proto_str = fin.read()
+
+ with open(vocab_file, "wb") as writer:
+ writer.write(proto_str)
+
+ return (vocab_file,)
+
+ def get_prefix_tokens(self):
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
+ return prefix_tokens
+
+ def build_single_message(self, role, metadata, message):
+ assert role in ["system", "user", "assistant", "observation"], role
+ role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
+ message_tokens = self.tokenizer.encode(message)
+ tokens = role_tokens + message_tokens
+ return tokens
+
+ def build_chat_input(self, query, history=None, role="user"):
+ if history is None:
+ history = []
+ input_ids = []
+ for item in history:
+ content = item["content"]
+ if item["role"] == "system" and "tools" in item:
+ content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
+ input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
+ input_ids.extend(self.build_single_message(role, "", query))
+ input_ids.extend([self.get_command("<|assistant|>")])
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A BERT sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ prefix_tokens = self.get_prefix_tokens()
+ token_ids_0 = prefix_tokens + token_ids_0
+ if token_ids_1 is not None:
+            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
+ return token_ids_0
+
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ padding_side: Optional[str] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ assert self.padding_side == "left"
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ seq_length = len(required_input)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * seq_length
+
+ if "position_ids" not in encoded_inputs:
+ encoded_inputs["position_ids"] = list(range(seq_length))
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+
+ if "attention_mask" in encoded_inputs:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "position_ids" in encoded_inputs:
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+ return encoded_inputs
+
+
+
+class KolorsPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
+ super().__init__()
+ self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
+ self.text_encoder: ChatGLMModel = None
+
+
+ def fetch_models(self, text_encoder: ChatGLMModel = None):
+ self.text_encoder = text_encoder
+
+
+ def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).to(device)
+ output = text_encoder(
+ input_ids=text_inputs['input_ids'],
+ attention_mask=text_inputs['attention_mask'],
+ position_ids=text_inputs['position_ids'],
+ output_hidden_states=True
+ )
+ prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
+ pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
+ return prompt_emb, pooled_prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ clip_skip=1,
+ clip_skip_2=2,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+ prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, self.text_encoder, self.tokenizer, 256, clip_skip_2, device)
+
+ return pooled_prompt_emb, prompt_emb
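+
+# Usage sketch (illustrative; `text_encoder` is a ChatGLMModel loaded elsewhere):
+#   prompter = KolorsPrompter()
+#   prompter.fetch_models(text_encoder)
+#   pooled_emb, prompt_emb = prompter.encode_prompt("a cat in the snow", device="cuda")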
diff --git a/PusaV1/diffsynth/prompters/omnigen_prompter.py b/PusaV1/diffsynth/prompters/omnigen_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..616efabebb7d327ecf968165dd12341ab8f83894
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/omnigen_prompter.py
@@ -0,0 +1,356 @@
+import os
+import re
+from typing import Dict, List
+
+import torch
+from PIL import Image
+from torchvision import transforms
+from transformers import AutoTokenizer
+from huggingface_hub import snapshot_download
+import numpy as np
+
+
+
+def crop_arr(pil_image, max_image_size):
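+ # Rescale so the longer side is at most max_image_size and the shorter side is at
+ # least 16, then center-crop both dimensions down to multiples of 16.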
+ while min(*pil_image.size) >= 2 * max_image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ if max(*pil_image.size) > max_image_size:
+ scale = max_image_size / max(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ if min(*pil_image.size) < 16:
+ scale = 16 / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y1 = (arr.shape[0] % 16) // 2
+ crop_y2 = arr.shape[0] % 16 - crop_y1
+
+ crop_x1 = (arr.shape[1] % 16) // 2
+ crop_x2 = arr.shape[1] % 16 - crop_x1
+
+ arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
+ return Image.fromarray(arr)
+
+
+
+class OmniGenPrompter:
+ def __init__(self,
+ text_tokenizer,
+ max_image_size: int=1024):
+ self.text_tokenizer = text_tokenizer
+ self.max_image_size = max_image_size
+
+ self.image_transform = transforms.Compose([
+ transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ self.collator = OmniGenCollator()
+ self.separate_collator = OmniGenSeparateCollator()
+
+ @classmethod
+ def from_pretrained(cls, model_name):
+ if not os.path.exists(model_name):
+ cache_folder = os.getenv('HF_HUB_CACHE')
+ model_name = snapshot_download(repo_id=model_name,
+ cache_dir=cache_folder,
+ allow_patterns="*.json")
+ text_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ return cls(text_tokenizer)
+
+
+ def process_image(self, image):
+ return self.image_transform(image)
+
+ def process_multi_modal_prompt(self, text, input_images):
+ text = self.add_prefix_instruction(text)
+ if input_images is None or len(input_images) == 0:
+ model_inputs = self.text_tokenizer(text)
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+ pattern = r"<\|image_\d+\|>"
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
+ for i in range(1, len(prompt_chunks)):
+ if prompt_chunks[i][0] == 1:
+ prompt_chunks[i] = prompt_chunks[i][1:]
+
+ image_tags = re.findall(pattern, text)
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+ unique_image_ids = sorted(list(set(image_ids)))
+ assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must be consecutive integers starting from 1, e.g. [1, 2, 3], got {unique_image_ids}"
+ # the number of input images must match the number of image tags
+ assert len(unique_image_ids) == len(input_images), f"the number of input images must match the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+
+ input_images = [input_images[x-1] for x in image_ids]
+
+ all_input_ids = []
+ img_inx = []
+ idx = 0
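+ # Interleave the tokenized text chunks with zero-filled placeholder spans, one per
+ # image; each span length equals the image's number of 16x16 latent patches.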
+ for i in range(len(prompt_chunks)):
+ all_input_ids.extend(prompt_chunks[i])
+ if i != len(prompt_chunks) -1:
+ start_inx = len(all_input_ids)
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+ img_inx.append([start_inx, start_inx+size])
+ all_input_ids.extend([0]*size)
+
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+
+ def add_prefix_instruction(self, prompt):
+ user_prompt = '<|user|>\n'
+ generation_prompt = 'Generate an image according to the following instructions\n'
+ assistant_prompt = '<|assistant|>\n<|diffusion|>'
+ prompt_suffix = "<|end|>\n"
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+ return prompt
+
+
+ def __call__(self,
+ instructions: List[str],
+ input_images: List[List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+ use_img_cfg: bool = True,
+ separate_cfg_input: bool = False,
+ use_input_image_size_as_output: bool=False,
+ ) -> Dict:
+
+ if input_images is None:
+ use_img_cfg = False
+ if isinstance(instructions, str):
+ instructions = [instructions]
+ input_images = [input_images]
+
+ input_data = []
+ for i in range(len(instructions)):
+ cur_instruction = instructions[i]
+ cur_input_images = None if input_images is None else input_images[i]
+ if cur_input_images is not None and len(cur_input_images) > 0:
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
+ else:
+ cur_input_images = None
+ assert "
<|image_1|>" not in cur_instruction
+
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+
+ neg_mllm_input, img_cfg_mllm_input = None, None
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+ if use_img_cfg:
+ if cur_input_images is not None and len(cur_input_images) >= 1:
+ img_cfg_prompt = [f"<|image_{i+1}|>" for i in range(len(cur_input_images))]
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+ else:
+ img_cfg_mllm_input = neg_mllm_input
+
+ if use_input_image_size_as_output:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
+ else:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+ if separate_cfg_input:
+ return self.separate_collator(input_data)
+ return self.collator(input_data)
+
+
+
+
+class OmniGenCollator:
+ def __init__(self, pad_token_id=2, hidden_size=3072):
+ self.pad_token_id = pad_token_id
+ self.hidden_size = hidden_size
+
+ def create_position(self, attention_mask, num_tokens_for_output_images):
+ position_ids = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
+ position_ids.append(temp_position)
+ return torch.LongTensor(position_ids)
+
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
+ extended_mask = []
+ padding_images = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
+ inx = 0
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ pad_l = text_length - temp_l
+
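+ # Causal (lower-triangular) mask over the text tokens plus one time-embedding token;
+ # output-image tokens may attend to the full sequence, while left padding and unused
+ # image slots are masked out below.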
+ temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
+
+ image_mask = torch.zeros(size=(temp_l+1, img_length))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+ image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+ if pad_l > 0:
+ pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+ pad_mask = torch.ones(size=(pad_l, seq_len))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+ true_img_length = num_tokens_for_output_images[inx]
+ pad_img_length = img_length - true_img_length
+ if pad_img_length > 0:
+ temp_mask[:, -pad_img_length:] = 0
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+ else:
+ temp_padding_imgs = None
+
+ extended_mask.append(temp_mask.unsqueeze(0))
+ padding_images.append(temp_padding_imgs)
+ inx += 1
+ return torch.cat(extended_mask, dim=0), padding_images
+
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
+ for b_inx in image_sizes.keys():
+ for start_inx, end_inx in image_sizes[b_inx]:
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+ return attention_mask
+
+ def pad_input_ids(self, input_ids, image_sizes):
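+ # Left-pad every sequence to the longest length in the batch and shift the recorded
+ # image-token spans by the same amount so they still point at the right positions.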
+ max_l = max([len(x) for x in input_ids])
+ padded_ids = []
+ attention_mask = []
+ new_image_sizes = []
+
+ for i in range(len(input_ids)):
+ temp_ids = input_ids[i]
+ temp_l = len(temp_ids)
+ pad_l = max_l - temp_l
+ if pad_l == 0:
+ attention_mask.append([1]*max_l)
+ padded_ids.append(temp_ids)
+ else:
+ attention_mask.append([0]*pad_l+[1]*temp_l)
+ padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
+
+ if i in image_sizes:
+ new_inx = []
+ for old_inx in image_sizes[i]:
+ new_inx.append([x+pad_l for x in old_inx])
+ image_sizes[i] = new_inx
+
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+
+ def process_mllm_input(self, mllm_inputs, target_img_size):
+ num_tokens_for_output_images = []
+ for img_size in target_img_size:
+ num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
+
+ pixel_values, image_sizes = [], {}
+ b_inx = 0
+ for x in mllm_inputs:
+ if x['pixel_values'] is not None:
+ pixel_values.extend(x['pixel_values'])
+ for size in x['image_sizes']:
+ if b_inx not in image_sizes:
+ image_sizes[b_inx] = [size]
+ else:
+ image_sizes[b_inx].append(size)
+ b_inx += 1
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+
+ input_ids = [x['input_ids'] for x in mllm_inputs]
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+
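+ # Stack the conditional, negative-prompt (text CFG) and, when present, image-CFG
+ # inputs into a single batch so all guidance branches can be processed together.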
+ if img_cfg_mllm_input[0] is not None:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+ target_img_size = target_img_size + target_img_size + target_img_size
+ else:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
+ target_img_size = target_img_size + target_img_size
+
+
+ all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+
+ data = {"input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ "padding_images": all_padding_images,
+ }
+ return data
+
+
+class OmniGenSeparateCollator(OmniGenCollator):
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+ all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
+
+
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+
+ if cfg_mllm_inputs[0] is not None:
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+ if img_cfg_mllm_input[0] is not None:
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
+ all_padded_input_ids.append(padded_input_ids)
+ all_attention_mask.append(attention_mask)
+ all_position_ids.append(position_ids)
+ all_pixel_values.append(pixel_values)
+ all_image_sizes.append(image_sizes)
+ all_padding_images.append(padding_images)
+
+ data = {"input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ "padding_images": all_padding_images,
+ }
+ return data
diff --git a/PusaV1/diffsynth/prompters/omost.py b/PusaV1/diffsynth/prompters/omost.py
new file mode 100644
index 0000000000000000000000000000000000000000..81828ad79978103eea42389d439847c0877cbd85
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/omost.py
@@ -0,0 +1,323 @@
+from transformers import AutoTokenizer, TextIteratorStreamer
+import difflib
+import torch
+import numpy as np
+import re
+from ..models.model_manager import ModelManager
+from PIL import Image
+
+valid_colors = { # r, g, b
+ 'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255),
+ 'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220),
+ 'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255),
+ 'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135),
+ 'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 105, 30),
+ 'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220),
+ 'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139),
+ 'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169),
+ 'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139),
+ 'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204),
+ 'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143),
+ 'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79),
+ 'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147),
+ 'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105),
+ 'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240),
+ 'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220),
+ 'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32),
+ 'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47),
+ 'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92),
+ 'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250),
+ 'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205),
+ 'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255),
+ 'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211),
+ 'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122),
+ 'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153),
+ 'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 224),
+ 'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255),
+ 'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205),
+ 'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113),
+ 'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154),
+ 'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112),
+ 'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181),
+ 'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128),
+ 'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35),
+ 'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214),
+ 'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238),
+ 'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185),
+ 'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230),
+ 'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0),
+ 'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19),
+ 'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87),
+ 'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192),
+ 'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144),
+ 'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127),
+ 'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216),
+ 'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238),
+ 'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245),
+ 'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50)
+}
+
+valid_locations = { # x, y in 90*90
+ 'in the center': (45, 45),
+ 'on the left': (15, 45),
+ 'on the right': (75, 45),
+ 'on the top': (45, 15),
+ 'on the bottom': (45, 75),
+ 'on the top-left': (15, 15),
+ 'on the top-right': (75, 15),
+ 'on the bottom-left': (15, 75),
+ 'on the bottom-right': (75, 75)
+}
+
+valid_offsets = { # x, y in 90*90
+ 'no offset': (0, 0),
+ 'slightly to the left': (-10, 0),
+ 'slightly to the right': (10, 0),
+ 'slightly to the upper': (0, -10),
+ 'slightly to the lower': (0, 10),
+ 'slightly to the upper-left': (-10, -10),
+ 'slightly to the upper-right': (10, -10),
+ 'slightly to the lower-left': (-10, 10),
+ 'slightly to the lower-right': (10, 10)}
+
+valid_areas = { # w, h in 90*90
+ "a small square area": (50, 50),
+ "a small vertical area": (40, 60),
+ "a small horizontal area": (60, 40),
+ "a medium-sized square area": (60, 60),
+ "a medium-sized vertical area": (50, 80),
+ "a medium-sized horizontal area": (80, 50),
+ "a large square area": (70, 70),
+ "a large vertical area": (60, 90),
+ "a large horizontal area": (90, 60)
+}
+
+def safe_str(x):
+ return x.strip(',. ') + '.'
+
+def closest_name(input_str, options):
+ input_str = input_str.lower()
+
+ closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5)
+ assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!'
+ result = closest_match[0]
+
+ if result != input_str:
+ print(f'Automatically corrected [{input_str}] -> [{result}].')
+
+ return result
+
+class Canvas:
+ @staticmethod
+ def from_bot_response(response: str):
+
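+ # Pull the fenced Python code block out of the LLM response and execute it in an
+ # isolated namespace; the generated code is expected to create `canvas = Canvas()`.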
+ matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
+ assert matched, 'Response does not contain a Python code block!'
+ code_content = matched.group(1)
+ assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!'
+ local_vars = {'Canvas': Canvas}
+ exec(code_content, {}, local_vars)
+ canvas = local_vars.get('canvas', None)
+ assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!'
+ return canvas
+
+ def __init__(self):
+ self.components = []
+ self.color = None
+ self.record_tags = True
+ self.prefixes = []
+ self.suffixes = []
+ return
+
+ def set_global_description(self, description: str, detailed_descriptions: list, tags: str,
+ HTML_web_color_name: str):
+ assert isinstance(description, str), 'Global description is not valid!'
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
+ 'Global detailed_descriptions is not valid!'
+ assert isinstance(tags, str), 'Global tags is not valid!'
+
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
+ self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
+
+ self.prefixes = [description]
+ self.suffixes = detailed_descriptions
+
+ if self.record_tags:
+ self.suffixes = self.suffixes + [tags]
+
+ self.prefixes = [safe_str(x) for x in self.prefixes]
+ self.suffixes = [safe_str(x) for x in self.suffixes]
+
+ return
+
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str,
+ detailed_descriptions: list, tags: str, atmosphere: str, style: str,
+ quality_meta: str, HTML_web_color_name: str):
+ assert isinstance(description, str), 'Local description is wrong!'
+ assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \
+ f'The distance_to_viewer for [{description}] is not a positive number!'
+ assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \
+ f'The detailed_descriptions for [{description}] is not valid!'
+ assert isinstance(tags, str), f'The tags for [{description}] is not valid!'
+ assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!'
+ assert isinstance(style, str), f'The style for [{description}] is not valid!'
+ assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!'
+
+ location = closest_name(location, valid_locations)
+ offset = closest_name(offset, valid_offsets)
+ area = closest_name(area, valid_areas)
+ HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors)
+
+ xb, yb = valid_locations[location]
+ xo, yo = valid_offsets[offset]
+ w, h = valid_areas[area]
+ rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2)
+ rect = [max(0, min(90, i)) for i in rect]
+ color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8)
+
+ prefixes = self.prefixes + [description]
+ suffixes = detailed_descriptions
+
+ if self.record_tags:
+ suffixes = suffixes + [tags, atmosphere, style, quality_meta]
+
+ prefixes = [safe_str(x) for x in prefixes]
+ suffixes = [safe_str(x) for x in suffixes]
+
+ self.components.append(dict(
+ rect=rect,
+ distance_to_viewer=distance_to_viewer,
+ color=color,
+ prefixes=prefixes,
+ suffixes=suffixes,
+ location=location,
+ ))
+
+ return
+
+ def process(self):
+ # sort components
+ self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True)
+
+ # compute initial latent
+ # print(self.color)
+ initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color
+
+ for component in self.components:
+ a, b, c, d = component['rect']
+ initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d]
+
+ initial_latent = initial_latent.clip(0, 255).astype(np.uint8)
+
+ # compute conditions
+
+ bag_of_conditions = [
+ dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes,location= "full")
+ ]
+
+ for i, component in enumerate(self.components):
+ a, b, c, d = component['rect']
+ m = np.zeros(shape=(90, 90), dtype=np.float32)
+ m[a:b, c:d] = 1.0
+ bag_of_conditions.append(dict(
+ mask = m,
+ prefixes = component['prefixes'],
+ suffixes = component['suffixes'],
+ location = component['location'],
+ ))
+
+ return dict(
+ initial_latent = initial_latent,
+ bag_of_conditions = bag_of_conditions,
+ )
+
+
+class OmostPromter(torch.nn.Module):
+
+ def __init__(self, model=None, tokenizer=None, template="", device="cpu"):
+ super().__init__()
+ self.model = model
+ self.tokenizer = tokenizer
+ self.device = device
+ if template == "":
+ template = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`:
+ ```python
+ class Canvas:
+ def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str):
+ pass
+
+ def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str):
+ assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"]
+ assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"]
+ assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"]
+ assert distance_to_viewer > 0
+ pass
+ ```'''
+ self.template = template
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager):
+ model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ omost = OmostPromter(
+ model= model,
+ tokenizer = tokenizer,
+ device = model_manager.device
+ )
+ return omost
+
+
+ def __call__(self, prompt_dict: dict):
+ raw_prompt = prompt_dict["prompt"]
+ conversation = [{"role": "system", "content": self.template}]
+ conversation.append({"role": "user", "content": raw_prompt})
+
+ input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device)
+ streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.bfloat16, device=self.device)
+
+ generate_kwargs = dict(
+ input_ids = input_ids,
+ streamer = streamer,
+ # stopping_criteria=stopping_criteria,
+ # max_new_tokens=max_new_tokens,
+ do_sample = True,
+ attention_mask = attention_mask,
+ pad_token_id = self.tokenizer.eos_token_id,
+ # temperature=temperature,
+ # top_p=top_p,
+ )
+ self.model.generate(**generate_kwargs)
+ outputs = []
+ for text in streamer:
+ outputs.append(text)
+ llm_outputs = "".join(outputs)
+
+ canvas = Canvas.from_bot_response(llm_outputs)
+ canvas_output = canvas.process()
+
+ prompts = [" ".join(_["prefixes"]+_["suffixes"][:2]) for _ in canvas_output["bag_of_conditions"]]
+ canvas_output["prompt"] = prompts[0]
+ canvas_output["prompts"] = prompts[1:]
+
+ raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]]
+ masks = []
+ for mask in raw_masks:
+ mask[mask > 0.5] = 255
+ mask = np.stack([mask] * 3, axis=-1).astype("uint8")
+ masks.append(Image.fromarray(mask))
+
+ canvas_output["masks"] = masks
+ prompt_dict.update(canvas_output)
+ print(f"Your prompt is extended by Omost:\n")
+ cnt = 0
+ for component,pmt in zip(canvas_output["bag_of_conditions"],prompts):
+ loc = component["location"]
+ cnt += 1
+ print(f"Component {cnt} - Location : {loc}\nPrompt:{pmt}\n")
+
+ return prompt_dict
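+
+# Usage sketch (illustrative; assumes a ModelManager that has already loaded an "omost_prompt" model):
+#   omost = OmostPromter.from_model_manager(model_manager)
+#   result = omost({"prompt": "a quiet harbor at dawn"})
+#   result["prompt"], result["prompts"], result["masks"]  # global prompt, per-region prompts, region masks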
+
+
+
+
\ No newline at end of file
diff --git a/PusaV1/diffsynth/prompters/prompt_refiners.py b/PusaV1/diffsynth/prompters/prompt_refiners.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac19f565b076cccb21d9e05149b604e4bb55854
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/prompt_refiners.py
@@ -0,0 +1,130 @@
+from transformers import AutoTokenizer
+from ..models.model_manager import ModelManager
+import torch
+from .omost import OmostPromter
+
+class BeautifulPrompt(torch.nn.Module):
+ def __init__(self, tokenizer_path=None, model=None, template=""):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = model
+ self.template = template
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager):
+ model, model_path = model_manager.fetch_model("beautiful_prompt", require_model_path=True)
+ template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
+ if model_path.endswith("v2"):
+ template = """Converts a simple image description into a prompt. \
+Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
+or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
+but make sure there is a correlation between the input and output.\n\
+### Input: {raw_prompt}\n### Output:"""
+ beautiful_prompt = BeautifulPrompt(
+ tokenizer_path=model_path,
+ model=model,
+ template=template
+ )
+ return beautiful_prompt
+
+
+ def __call__(self, raw_prompt, positive=True, **kwargs):
+ if positive:
+ model_input = self.template.format(raw_prompt=raw_prompt)
+ input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
+ outputs = self.model.generate(
+ input_ids,
+ max_new_tokens=384,
+ do_sample=True,
+ temperature=0.9,
+ top_k=50,
+ top_p=0.95,
+ repetition_penalty=1.1,
+ num_return_sequences=1
+ )
+ prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
+ outputs[:, input_ids.size(1):],
+ skip_special_tokens=True
+ )[0].strip()
+ print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
+ return prompt
+ else:
+ return raw_prompt
+
+
+
+class QwenPrompt(torch.nn.Module):
+ # Uses the open-source Qwen model to translate Chinese prompts into English,
+ # while also refining them to better suit image generation.
+ def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = model
+ self.system_prompt = system_prompt
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager):
+ model, model_path = model_manager.fetch_model("qwen_prompt", require_model_path=True)
+ system_prompt = """You are an English image describer. Here are some example image styles:\n\n1. Extreme close-up: Clear focus on a single object with a blurred background, highlighted under natural sunlight.\n2. Vintage: A photograph of a historical scene, using techniques such as Daguerreotype or cyanotype.\n3. Anime: A stylized cartoon image, emphasizing hyper-realistic portraits and luminous brushwork.\n4. Candid: A natural, unposed shot capturing spontaneous moments, often with cinematic qualities.\n5. Landscape: A photorealistic image of natural scenery, such as a sunrise over the sea.\n6. Design: Colorful and detailed illustrations, often in the style of 2D game art or botanical illustrations.\n7. Urban: An ultrarealistic scene in a modern setting, possibly a cityscape viewed from indoors.\n\nYour task is to translate a given Chinese image description into a concise and precise English description. Ensure that the imagery is vivid and descriptive, and include stylistic elements to enrich the description.\nPlease note the following points:\n\n1. Capture the essence and mood of the Chinese description without including direct phrases or words from the examples provided.\n2. You should add appropriate words to make the images described in the prompt more aesthetically pleasing. If the Chinese description does not specify a style, you need to add some stylistic descriptions based on the essence of the Chinese text.\n3. The generated English description should not exceed 200 words.\n\n"""
+ qwen_prompt = QwenPrompt(
+ tokenizer_path=model_path,
+ model=model,
+ system_prompt=system_prompt
+ )
+ return qwen_prompt
+
+
+ def __call__(self, raw_prompt, positive=True, **kwargs):
+ if positive:
+ messages = [{
+ 'role': 'system',
+ 'content': self.system_prompt
+ }, {
+ 'role': 'user',
+ 'content': raw_prompt
+ }]
+ text = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+
+ generated_ids = self.model.generate(
+ model_inputs.input_ids,
+ max_new_tokens=512
+ )
+ generated_ids = [
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+ ]
+
+ prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ print(f"Your prompt is refined by Qwen: {prompt}")
+ return prompt
+ else:
+ return raw_prompt
+
+
+
+class Translator(torch.nn.Module):
+ def __init__(self, tokenizer_path=None, model=None):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = model
+
+
+ @staticmethod
+ def from_model_manager(model_manager: ModelManager):
+ model, model_path = model_manager.fetch_model("translator", require_model_path=True)
+ translator = Translator(tokenizer_path=model_path, model=model)
+ return translator
+
+
+ def __call__(self, prompt, **kwargs):
+ input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
+ output_ids = self.model.generate(input_ids)
+ prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+ print(f"Your prompt is translated: {prompt}")
+ return prompt
diff --git a/PusaV1/diffsynth/prompters/sd3_prompter.py b/PusaV1/diffsynth/prompters/sd3_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecf9bca30ae53e78822d06d769a65a6c79e8b5d8
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/sd3_prompter.py
@@ -0,0 +1,93 @@
+from .base_prompter import BasePrompter
+from ..models.model_manager import ModelManager
+from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
+from transformers import CLIPTokenizer, T5TokenizerFast
+import os, torch
+
+
+class SD3Prompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ tokenizer_2_path=None,
+ tokenizer_3_path=None
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_1")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_2")
+ if tokenizer_3_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
+ super().__init__()
+ self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
+ self.text_encoder_1: SD3TextEncoder1 = None
+ self.text_encoder_2: SD3TextEncoder2 = None
+ self.text_encoder_3: SD3TextEncoder3 = None
+
+
+ def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+ self.text_encoder_3 = text_encoder_3
+
+
+ def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True
+ ).input_ids.to(device)
+ pooled_prompt_emb, prompt_emb = text_encoder(input_ids)
+ return pooled_prompt_emb, prompt_emb
+
+
+ def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
+ input_ids = tokenizer(
+ prompt,
+ return_tensors="pt",
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ ).input_ids.to(device)
+ prompt_emb = text_encoder(input_ids)
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+
+ return prompt_emb
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ positive=True,
+ device="cuda",
+ t5_sequence_length=77,
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # CLIP
+ pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
+ pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, 77, device)
+
+ # T5
+ if self.text_encoder_3 is None:
+ prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_sequence_length, 4096), dtype=prompt_emb_1.dtype, device=device)
+ else:
+ prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_sequence_length, device)
+ prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
+
+ # Merge
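+ # The two CLIP embeddings (768 + 1280 = 2048 channels) are zero-padded up to the T5
+ # hidden size of 4096, then concatenated with the T5 embeddings along the sequence axis.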
+ prompt_emb = torch.cat([
+ torch.nn.functional.pad(torch.cat([prompt_emb_1, prompt_emb_2], dim=-1), (0, 4096 - 768 - 1280)),
+ prompt_emb_3
+ ], dim=-2)
+ pooled_prompt_emb = torch.cat([pooled_prompt_emb_1, pooled_prompt_emb_2], dim=-1)
+
+ return prompt_emb, pooled_prompt_emb
diff --git a/PusaV1/diffsynth/prompters/sd_prompter.py b/PusaV1/diffsynth/prompters/sd_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b31ea2836b3b02edab37d7f610c13f2cf6cead
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/sd_prompter.py
@@ -0,0 +1,73 @@
+from .base_prompter import BasePrompter, tokenize_long_prompt
+from ..models.utils import load_state_dict, search_for_embeddings
+from ..models import SDTextEncoder
+from transformers import CLIPTokenizer
+import torch, os
+
+
+
+class SDPrompter(BasePrompter):
+ def __init__(self, tokenizer_path=None):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
+ self.text_encoder: SDTextEncoder = None
+ self.textual_inversion_dict = {}
+ self.keyword_dict = {}
+
+
+ def fetch_models(self, text_encoder: SDTextEncoder = None):
+ self.text_encoder = text_encoder
+
+
+ def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
+ dtype = next(iter(text_encoder.parameters())).dtype
+ state_dict = text_encoder.token_embedding.state_dict()
+ token_embeddings = [state_dict["weight"]]
+ for keyword in textual_inversion_dict:
+ _, embeddings = textual_inversion_dict[keyword]
+ token_embeddings.append(embeddings.to(dtype=dtype, device=token_embeddings[0].device))
+ token_embeddings = torch.concat(token_embeddings, dim=0)
+ state_dict["weight"] = token_embeddings
+ text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
+ text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=token_embeddings[0].device)
+ text_encoder.token_embedding.load_state_dict(state_dict)
+
+
+ def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
+ additional_tokens = []
+ for keyword in textual_inversion_dict:
+ tokens, _ = textual_inversion_dict[keyword]
+ additional_tokens += tokens
+ self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
+ tokenizer.add_tokens(additional_tokens)
+
+
+ def load_textual_inversions(self, model_paths):
+ for model_path in model_paths:
+ keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
+ state_dict = load_state_dict(model_path)
+
+ # Search for embeddings
+ for embeddings in search_for_embeddings(state_dict):
+ if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
+ tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
+ self.textual_inversion_dict[keyword] = (tokens, embeddings)
+
+ self.add_textual_inversions_to_model(self.textual_inversion_dict, self.text_encoder)
+ self.add_textual_inversions_to_tokenizer(self.textual_inversion_dict, self.tokenizer)
+
+
+ def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
+ prompt = self.process_prompt(prompt, positive=positive)
+ for keyword in self.keyword_dict:
+ if keyword in prompt:
+ print(f"Textual inversion {keyword} is enabled.")
+ prompt = prompt.replace(keyword, self.keyword_dict[keyword])
+ input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
+ prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+
+ return prompt_emb
\ No newline at end of file
diff --git a/PusaV1/diffsynth/prompters/sdxl_prompter.py b/PusaV1/diffsynth/prompters/sdxl_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d84145402538b89b23d39a98271cbad64c2d9fc3
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/sdxl_prompter.py
@@ -0,0 +1,61 @@
+from .base_prompter import BasePrompter, tokenize_long_prompt
+from ..models.model_manager import ModelManager
+from ..models import SDXLTextEncoder, SDXLTextEncoder2
+from transformers import CLIPTokenizer
+import torch, os
+
+
+
+class SDXLPrompter(BasePrompter):
+ def __init__(
+ self,
+ tokenizer_path=None,
+ tokenizer_2_path=None
+ ):
+ if tokenizer_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
+ if tokenizer_2_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_xl/tokenizer_2")
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
+ self.text_encoder: SDXLTextEncoder = None
+ self.text_encoder_2: SDXLTextEncoder2 = None
+
+
+ def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
+ self.text_encoder = text_encoder
+ self.text_encoder_2 = text_encoder_2
+
+
+ def encode_prompt(
+ self,
+ prompt,
+ clip_skip=1,
+ clip_skip_2=2,
+ positive=True,
+ device="cuda"
+ ):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ # Text encoder 1
+ input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
+ prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)
+
+ # Text encoder 2
+ input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
+ add_text_embeds, prompt_emb_2 = self.text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
+
+ # Merge
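+ # tokenize_long_prompt splits prompts longer than 77 tokens into several chunks along
+ # the batch dimension; the reshape below flattens the chunk embeddings into one sequence.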
+ if prompt_emb_1.shape[0] != prompt_emb_2.shape[0]:
+ max_batch_size = min(prompt_emb_1.shape[0], prompt_emb_2.shape[0])
+ prompt_emb_1 = prompt_emb_1[: max_batch_size]
+ prompt_emb_2 = prompt_emb_2[: max_batch_size]
+ prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
+
+ # For very long prompts, only the first 77 tokens are used to compute `add_text_embeds`.
+ add_text_embeds = add_text_embeds[0:1]
+ prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
+ return add_text_embeds, prompt_emb
diff --git a/PusaV1/diffsynth/prompters/stepvideo_prompter.py b/PusaV1/diffsynth/prompters/stepvideo_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d374b1f8a4be2a2298520fcbf87800e0ca91d9
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/stepvideo_prompter.py
@@ -0,0 +1,56 @@
+from .base_prompter import BasePrompter
+from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
+from ..models.stepvideo_text_encoder import STEP1TextEncoder
+from transformers import BertTokenizer
+import os, torch
+
+
+class StepVideoPrompter(BasePrompter):
+
+ def __init__(
+ self,
+ tokenizer_1_path=None,
+ ):
+ if tokenizer_1_path is None:
+ base_path = os.path.dirname(os.path.dirname(__file__))
+ tokenizer_1_path = os.path.join(
+ base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
+ super().__init__()
+ self.tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_1_path)
+
+ def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = None, text_encoder_2: STEP1TextEncoder = None):
+ self.text_encoder_1 = text_encoder_1
+ self.text_encoder_2 = text_encoder_2
+
+ def encode_prompt_using_clip(self, prompt, max_length, device):
+ text_inputs = self.tokenizer_1(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ prompt_embeds = self.text_encoder_1(
+ text_inputs.input_ids.to(device),
+ attention_mask=text_inputs.attention_mask.to(device),
+ )
+ return prompt_embeds
+
+ def encode_prompt_using_llm(self, prompt, max_length, device):
+ y, y_mask = self.text_encoder_2(prompt, max_length=max_length, device=device)
+ return y, y_mask
+
+ def encode_prompt(self,
+ prompt,
+ positive=True,
+ device="cuda"):
+
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ clip_embeds = self.encode_prompt_using_clip(prompt, max_length=77, device=device)
+ llm_embeds, llm_mask = self.encode_prompt_using_llm(prompt, max_length=320, device=device)
+
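+ # Left-pad the LLM attention mask with ones so the positions corresponding to the
+ # CLIP embeddings are treated as valid tokens.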
+ llm_mask = torch.nn.functional.pad(llm_mask, (clip_embeds.shape[1], 0), value=1)
+
+ return clip_embeds, llm_embeds, llm_mask
diff --git a/PusaV1/diffsynth/prompters/wan_prompter.py b/PusaV1/diffsynth/prompters/wan_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..01a765d3cb3bf2ee4d06553fd061ed7dd75443b2
--- /dev/null
+++ b/PusaV1/diffsynth/prompters/wan_prompter.py
@@ -0,0 +1,109 @@
+from .base_prompter import BasePrompter
+from ..models.wan_video_text_encoder import WanTextEncoder
+from transformers import AutoTokenizer
+import os, torch
+import ftfy
+import html
+import string
+import regex as re
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+def canonicalize(text, keep_punctuation_exact_string=None):
+ text = text.replace('_', ' ')
+ if keep_punctuation_exact_string:
+ text = keep_punctuation_exact_string.join(
+ part.translate(str.maketrans('', '', string.punctuation))
+ for part in text.split(keep_punctuation_exact_string))
+ else:
+ text = text.translate(str.maketrans('', '', string.punctuation))
+ text = text.lower()
+ text = re.sub(r'\s+', ' ', text)
+ return text.strip()
+
+
+class HuggingfaceTokenizer:
+
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
+ self.name = name
+ self.seq_len = seq_len
+ self.clean = clean
+
+ # init tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+ self.vocab_size = self.tokenizer.vocab_size
+
+ def __call__(self, sequence, **kwargs):
+ return_mask = kwargs.pop('return_mask', False)
+
+ # arguments
+ _kwargs = {'return_tensors': 'pt'}
+ if self.seq_len is not None:
+ _kwargs.update({
+ 'padding': 'max_length',
+ 'truncation': True,
+ 'max_length': self.seq_len
+ })
+ _kwargs.update(**kwargs)
+
+ # tokenization
+ if isinstance(sequence, str):
+ sequence = [sequence]
+ if self.clean:
+ sequence = [self._clean(u) for u in sequence]
+ ids = self.tokenizer(sequence, **_kwargs)
+
+ # output
+ if return_mask:
+ return ids.input_ids, ids.attention_mask
+ else:
+ return ids.input_ids
+
+ def _clean(self, text):
+ if self.clean == 'whitespace':
+ text = whitespace_clean(basic_clean(text))
+ elif self.clean == 'lower':
+ text = whitespace_clean(basic_clean(text)).lower()
+ elif self.clean == 'canonicalize':
+ text = canonicalize(basic_clean(text))
+ return text
+
+
+class WanPrompter(BasePrompter):
+
+ def __init__(self, tokenizer_path=None, text_len=512):
+ super().__init__()
+ self.text_len = text_len
+ self.text_encoder = None
+ self.fetch_tokenizer(tokenizer_path)
+
+ def fetch_tokenizer(self, tokenizer_path=None):
+ if tokenizer_path is not None:
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
+
+ def fetch_models(self, text_encoder: WanTextEncoder = None):
+ self.text_encoder = text_encoder
+
+ def encode_prompt(self, prompt, positive=True, device="cuda"):
+ prompt = self.process_prompt(prompt, positive=positive)
+
+ ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
+ ids = ids.to(device)
+ mask = mask.to(device)
+ seq_lens = mask.gt(0).sum(dim=1).long()
+ prompt_emb = self.text_encoder(ids, mask)
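+ # Zero out embedding positions beyond each prompt's true token length so padding
+ # contributes nothing downstream.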
+ for i, v in enumerate(seq_lens):
+ prompt_emb[:, v:] = 0
+ return prompt_emb
diff --git a/PusaV1/diffsynth/schedulers/__init__.py b/PusaV1/diffsynth/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..851f0f35d593bc43e9ea59688b5b0eadf11c27cf
--- /dev/null
+++ b/PusaV1/diffsynth/schedulers/__init__.py
@@ -0,0 +1,6 @@
+from .ddim import EnhancedDDIMScheduler
+from .continuous_ode import ContinuousODEScheduler
+from .flow_match import FlowMatchScheduler
+from .flow_match_pusa import FlowMatchSchedulerPusa
+from .flow_match_pusa_multi_frames import FlowMatchSchedulerPusaMultiFrames
+from .flow_match_pusa_v2v import FlowMatchSchedulerPusaV2V
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/__init__.cpython-310.pyc b/PusaV1/diffsynth/schedulers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fd1e5aedb4e3d8c39b2676dea2d663224fefcb3
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/__init__.cpython-312.pyc b/PusaV1/diffsynth/schedulers/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b7405e0b06281cf2b3f5a107740fa6e3daf4768
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/__init__.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/continuous_ode.cpython-310.pyc b/PusaV1/diffsynth/schedulers/__pycache__/continuous_ode.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72a621593ec4de8872057c765365a86e5b4d827b
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/continuous_ode.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/continuous_ode.cpython-312.pyc b/PusaV1/diffsynth/schedulers/__pycache__/continuous_ode.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f4f041d7bb2bf843a23e47731d85fa2e4e697c2
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/continuous_ode.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/ddim.cpython-310.pyc b/PusaV1/diffsynth/schedulers/__pycache__/ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f38c543872f07534fc456e2f9703a1426053c2d1
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/ddim.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/ddim.cpython-312.pyc b/PusaV1/diffsynth/schedulers/__pycache__/ddim.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d328f336b96fdb52d118f7a180c68a78447e828
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/ddim.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/flow_match.cpython-310.pyc b/PusaV1/diffsynth/schedulers/__pycache__/flow_match.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..689e3debcc9197dfe56159ce9b25c8dfea1743fc
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/flow_match.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/flow_match.cpython-312.pyc b/PusaV1/diffsynth/schedulers/__pycache__/flow_match.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..382224074e86cac14b8ae1ce8d700ea8a9520f1a
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/flow_match.cpython-312.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa.cpython-310.pyc b/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0da804b11d7f0aa731d88a361667b586ce5fb70
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa_multi_frames.cpython-310.pyc b/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa_multi_frames.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8902c40dc5bf9e4f2e7ea106d969230e4ad70cb2
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa_multi_frames.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa_v2v.cpython-310.pyc b/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa_v2v.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69972944961eb0ee1891604a92d28d30fdc65426
Binary files /dev/null and b/PusaV1/diffsynth/schedulers/__pycache__/flow_match_pusa_v2v.cpython-310.pyc differ
diff --git a/PusaV1/diffsynth/schedulers/continuous_ode.py b/PusaV1/diffsynth/schedulers/continuous_ode.py
new file mode 100644
index 0000000000000000000000000000000000000000..c73b9e221aa54a8385322b42012c30c598550fcd
--- /dev/null
+++ b/PusaV1/diffsynth/schedulers/continuous_ode.py
@@ -0,0 +1,59 @@
+import torch
+
+
+class ContinuousODEScheduler():
+
+ def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.rho = rho
+ self.set_timesteps(num_inference_steps)
+
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
+ ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
+ min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
+ max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
+ self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
+ self.timesteps = torch.log(self.sigmas) * 0.25
+
+
+ def step(self, model_output, timestep, sample, to_final=False):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ sample *= (sigma*sigma + 1).sqrt()
+ estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ prev_sample = estimated_sample
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ derivative = 1 / sigma * (sample - estimated_sample)
+ prev_sample = sample + derivative * (sigma_ - sigma)
+ prev_sample /= (sigma_*sigma_ + 1).sqrt()
+ return prev_sample
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ # This scheduler doesn't support this function.
+ pass
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
+ return sample
+
+
+ def training_target(self, sample, noise, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
+ return target
+
+
+ def training_weight(self, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ weight = (1 + sigma*sigma).sqrt() / sigma
+ return weight
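
Note: the ContinuousODEScheduler added above builds a Karras-style rho-interpolated sigma schedule and exposes the usual step/add_noise/training helpers. A minimal sampling sketch follows; the `denoiser` stub, latent shape, and 30-step schedule are illustrative assumptions, not part of this diff.

    # Hypothetical usage sketch for ContinuousODEScheduler (placeholder model and shapes).
    import torch
    from diffsynth.schedulers.continuous_ode import ContinuousODEScheduler  # assuming PusaV1/ is on PYTHONPATH

    scheduler = ContinuousODEScheduler(num_inference_steps=30)

    def denoiser(x, t):
        # Stand-in for a real diffusion model; returns a tensor shaped like x.
        return torch.zeros_like(x)

    sample = torch.randn(1, 4, 64, 64)  # illustrative latent shape, starting from noise
    for timestep in scheduler.timesteps:
        model_output = denoiser(sample, timestep)
        sample = scheduler.step(model_output, timestep, sample)
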
diff --git a/PusaV1/diffsynth/schedulers/ddim.py b/PusaV1/diffsynth/schedulers/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..da524963c62f662016b1429d5047ebe7b5922604
--- /dev/null
+++ b/PusaV1/diffsynth/schedulers/ddim.py
@@ -0,0 +1,105 @@
+import torch, math
+
+
+class EnhancedDDIMScheduler():
+
+ def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
+ self.num_train_timesteps = num_train_timesteps
+ if beta_schedule == "scaled_linear":
+ betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
+ elif beta_schedule == "linear":
+ betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented")
+ self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+ if rescale_zero_terminal_snr:
+ self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
+ self.alphas_cumprod = self.alphas_cumprod.tolist()
+ self.set_timesteps(10)
+ self.prediction_type = prediction_type
+
+
+ def rescale_zero_terminal_snr(self, alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
+
+ return alphas_bar
+
+
+ def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
+        # The timesteps are aligned to 999...0, which differs from some other implementations,
+        # but this alignment is arguably more reasonable in theory.
+ max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
+ num_inference_steps = min(num_inference_steps, max_timestep + 1)
+ if num_inference_steps == 1:
+ self.timesteps = torch.Tensor([max_timestep])
+ else:
+ step_length = max_timestep / (num_inference_steps - 1)
+ self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
+
+
+ def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
+ if self.prediction_type == "epsilon":
+ weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
+ weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
+ prev_sample = sample * weight_x + model_output * weight_e
+ elif self.prediction_type == "v_prediction":
+ weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
+ weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
+ prev_sample = sample * weight_x + model_output * weight_e
+ else:
+ raise NotImplementedError(f"{self.prediction_type} is not implemented")
+ return prev_sample
+
+
+ def step(self, model_output, timestep, sample, to_final=False):
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ alpha_prod_t_prev = 1.0
+ else:
+ timestep_prev = int(self.timesteps[timestep_id + 1])
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
+
+ return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
+ noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
+ return noise_pred
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+
+ def training_target(self, sample, noise, timestep):
+ if self.prediction_type == "epsilon":
+ return noise
+ else:
+ sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+ target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return target
+
+
+ def training_weight(self, timestep):
+ return 1.0
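
Note: a hypothetical training-step sketch for the EnhancedDDIMScheduler above, showing how add_noise, training_target, and training_weight fit together. The tensor shapes, the v-prediction/zero-terminal-SNR configuration, and the commented loss line are illustrative assumptions, not part of this diff.

    # Hypothetical DDIM training-step sketch (placeholder tensors only).
    import torch
    from diffsynth.schedulers.ddim import EnhancedDDIMScheduler  # assuming PusaV1/ is on PYTHONPATH

    scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", rescale_zero_terminal_snr=True)

    clean_latents = torch.randn(1, 4, 64, 64)   # stand-in for encoded training data
    noise = torch.randn_like(clean_latents)
    timestep = torch.randint(0, scheduler.num_train_timesteps, (1,))

    noisy_latents = scheduler.add_noise(clean_latents, noise, timestep)
    target = scheduler.training_target(clean_latents, noise, timestep)   # v-prediction target
    # loss = mse(model(noisy_latents, timestep), target) * scheduler.training_weight(timestep)
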
diff --git a/PusaV1/diffsynth/schedulers/flow_match.py b/PusaV1/diffsynth/schedulers/flow_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a9e961a3a4903b0f502e687ce4afced51bbfcf
--- /dev/null
+++ b/PusaV1/diffsynth/schedulers/flow_match.py
@@ -0,0 +1,79 @@
+import torch
+
+
+
+class FlowMatchScheduler():
+
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
+ self.num_train_timesteps = num_train_timesteps
+ self.shift = shift
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.inverse_timesteps = inverse_timesteps
+ self.extra_one_step = extra_one_step
+ self.reverse_sigmas = reverse_sigmas
+ self.set_timesteps(num_inference_steps)
+
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
+ if shift is not None:
+ self.shift = shift
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+ if self.extra_one_step:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+ else:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+ if self.inverse_timesteps:
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+ if self.reverse_sigmas:
+ self.sigmas = 1 - self.sigmas
+ self.timesteps = self.sigmas * self.num_train_timesteps
+ if training:
+ x = self.timesteps
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+ y_shifted = y - y.min()
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+ self.linear_timesteps_weights = bsmntw_weighing
+
+
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ prev_sample = sample + model_output * (sigma_ - sigma)
+ return prev_sample
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ model_output = (sample - sample_stablized) / sigma
+ return model_output
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ sample = (1 - sigma) * original_samples + sigma * noise
+ return sample
+
+
+ def training_target(self, sample, noise, timestep):
+ target = noise - sample
+ return target
+
+
+ def training_weight(self, timestep):
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+ weights = self.linear_timesteps_weights[timestep_id]
+ return weights
\ No newline at end of file
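
Note: a minimal denoising-loop sketch for the FlowMatchScheduler above. The velocity-model stub, the shift/sigma_min/extra_one_step settings, and the video-latent shape are illustrative assumptions, not part of this diff.

    # Hypothetical rectified-flow sampling sketch (placeholder model and shapes).
    import torch
    from diffsynth.schedulers.flow_match import FlowMatchScheduler  # assuming PusaV1/ is on PYTHONPATH

    scheduler = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=30)

    def velocity_model(x, t):
        return torch.zeros_like(x)  # stand-in for the model's velocity prediction

    latents = torch.randn(1, 16, 21, 60, 104)  # illustrative video-latent shape
    for timestep in scheduler.timesteps:
        v = velocity_model(latents, timestep)
        latents = scheduler.step(v, timestep, latents)
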
diff --git a/PusaV1/diffsynth/schedulers/flow_match_pusa.py b/PusaV1/diffsynth/schedulers/flow_match_pusa.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5a76b7a39608342ab86ffdd8270e8c2d6a909c7
--- /dev/null
+++ b/PusaV1/diffsynth/schedulers/flow_match_pusa.py
@@ -0,0 +1,128 @@
+import torch
+
+
+
+class FlowMatchSchedulerPusa():
+
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
+ self.num_train_timesteps = num_train_timesteps
+ self.shift = shift
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.inverse_timesteps = inverse_timesteps
+ self.extra_one_step = extra_one_step
+ self.reverse_sigmas = reverse_sigmas
+ self.set_timesteps(num_inference_steps)
+
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
+ if shift is not None:
+ self.shift = shift
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+ if self.extra_one_step:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+ else:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+ if self.inverse_timesteps:
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+ if self.reverse_sigmas:
+ self.sigmas = 1 - self.sigmas
+ self.timesteps = self.sigmas * self.num_train_timesteps
+ if training:
+ x = self.timesteps
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+ y_shifted = y - y.min()
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+ self.linear_timesteps_weights = bsmntw_weighing
+
+
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ model_output = model_output.to(timestep.device)
+ sample = sample.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ prev_sample = sample + model_output * (sigma_ - sigma)
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+ # Handle sigma_ calculation for each timestep_id element
+ if to_final or torch.any(timestep_id + 1 >= len(self.timesteps)):
+ default_value = 1.0 if (self.inverse_timesteps or self.reverse_sigmas) else 0.0
+ # Create sigma_ with the same dtype as self.sigmas
+ sigma_ = torch.ones_like(timestep_id, dtype=self.sigmas.dtype, device=sample.device) * default_value
+ valid_indices = timestep_id + 1 < len(self.timesteps)
+ if torch.any(valid_indices):
+ # Convert indices to the appropriate type for indexing
+ valid_timestep_ids = timestep_id[valid_indices]
+ sigma_[valid_indices] = self.sigmas[(valid_timestep_ids + 1).to(torch.long)]
+ else:
+ sigma_ = self.sigmas[(timestep_id + 1).to(torch.long)]
+
+
+ # Reshape sigma_ to match sigma's dimensions for the operation
+ sigma_ = sigma_.unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+ if torch.any(timestep == 0):
+ zero_indices = torch.where(timestep == 0)[1].to(torch.long)
+ sigma[:,:,zero_indices] = 0
+ print("sigma", sigma[0,0,:,0,0], '\n', "sigma_", sigma_[0,0,:,0,0])
+
+ prev_sample = sample + model_output * (sigma_ - sigma)
+ return prev_sample
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+ model_output = (sample - sample_stablized) / sigma
+ return model_output
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(original_samples.device)
+ sample = (1 - sigma) * original_samples + sigma * noise
+
+ return sample
+
+
+ def training_target(self, sample, noise, timestep):
+ target = noise - sample
+ return target
+
+
+ def training_weight(self, timestep):
+ if isinstance(timestep, torch.Tensor):
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep.to(self.timesteps.device)).abs(), dim=0)
+ weights = self.linear_timesteps_weights[timestep_id].to(self.timesteps.device)
+ return weights
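
Note: FlowMatchSchedulerPusa extends the flow-match step to frame-wise timesteps, which appears to be the basis of Pusa's vectorized timestep conditioning. A hypothetical sketch of that vectorized path follows; the [batch, channels, frames, height, width] latent layout and the clean first frame are assumptions inferred from the indexing above.

    # Hypothetical per-frame timestep sketch for FlowMatchSchedulerPusa.
    import torch
    from diffsynth.schedulers.flow_match_pusa import FlowMatchSchedulerPusa  # assuming PusaV1/ is on PYTHONPATH

    scheduler = FlowMatchSchedulerPusa(shift=5.0, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=30)

    num_frames = 5
    latents = torch.randn(1, 16, num_frames, 8, 8)   # assumed [B, C, F, H, W] latent layout
    model_output = torch.zeros_like(latents)         # placeholder velocity prediction

    # Frame-wise timesteps: frame 0 acts as a clean condition frame (t=0),
    # the remaining frames all sit at the current denoising step.
    timestep = scheduler.timesteps[0].repeat(1, num_frames)   # shape [1, num_frames]
    timestep[0, 0] = 0

    latents = scheduler.step(model_output, timestep, latents)
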
diff --git a/PusaV1/diffsynth/schedulers/flow_match_pusa_multi_frames.py b/PusaV1/diffsynth/schedulers/flow_match_pusa_multi_frames.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9df0b447a31da24667cbd45f0afc341562ac601
--- /dev/null
+++ b/PusaV1/diffsynth/schedulers/flow_match_pusa_multi_frames.py
@@ -0,0 +1,130 @@
+import torch
+
+class FlowMatchSchedulerPusaMultiFrames():
+
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
+ self.num_train_timesteps = num_train_timesteps
+ self.shift = shift
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.inverse_timesteps = inverse_timesteps
+ self.extra_one_step = extra_one_step
+ self.reverse_sigmas = reverse_sigmas
+ self.set_timesteps(num_inference_steps)
+
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
+ if shift is not None:
+ self.shift = shift
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+ if self.extra_one_step:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+ else:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+ if self.inverse_timesteps:
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+ if self.reverse_sigmas:
+ self.sigmas = 1 - self.sigmas
+ self.timesteps = self.sigmas * self.num_train_timesteps
+ if training:
+ x = self.timesteps
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+ y_shifted = y - y.min()
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+ self.linear_timesteps_weights = bsmntw_weighing
+
+
+ def step(self, model_output, timestep, sample, to_final=False, cond_frame_latent_indices=None, noise_multipliers=None, **kwargs):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ model_output = model_output.to(timestep.device)
+ sample = sample.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ prev_sample = sample + model_output * (sigma_ - sigma)
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+
+ # Handle sigma_ calculation for each timestep_id element
+ next_timestep_id = (timestep_id + 1).to(torch.long)
+ # Create sigma_ with the same dtype as self.sigmas
+ sigma_ = torch.zeros_like(timestep_id, dtype=self.sigmas.dtype, device=sample.device)
+
+ valid_indices = next_timestep_id < len(self.timesteps)
+ sigma_[valid_indices] = self.sigmas[next_timestep_id[valid_indices]]
+
+ invalid_indices = ~valid_indices
+ default_value = 1.0 if (self.inverse_timesteps or self.reverse_sigmas) else 0.0
+ sigma_[invalid_indices] = default_value
+
+ # Reshape sigma_ to match sigma's dimensions for the operation
+ sigma_ = sigma_.unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+
+
+ if torch.any(timestep == 0):
+ zero_indices = torch.where(timestep == 0)[1].to(torch.long)
+ sigma[:,:,zero_indices] = 0
+ print("sigma", sigma[0,0,:,0,0], '\n', "sigma_", sigma_[0,0,:,0,0])
+
+ prev_sample = sample + model_output * (sigma_ - sigma)
+
+ return prev_sample
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+ model_output = (sample - sample_stablized) / sigma
+ return model_output
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(original_samples.device)
+
+ sample = (1 - sigma) * original_samples + sigma * noise
+
+ return sample
+
+
+ def training_target(self, sample, noise, timestep):
+ target = noise - sample
+ return target
+
+
+ def training_weight(self, timestep):
+ if isinstance(timestep, torch.Tensor):
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep.to(self.timesteps.device)).abs(), dim=0)
+ weights = self.linear_timesteps_weights[timestep_id].to(self.timesteps.device)
+
+ return weights
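
Note: FlowMatchSchedulerPusaMultiFrames mirrors the scheduler above but handles the next-sigma lookup per frame index inside step(). Its frame-wise add_noise is one way frame-dependent noise levels could be constructed, as sketched below; the shapes and the choice of low-noise start/end frames are illustrative assumptions.

    # Hypothetical frame-dependent add_noise sketch for FlowMatchSchedulerPusaMultiFrames.
    import torch
    from diffsynth.schedulers.flow_match_pusa_multi_frames import FlowMatchSchedulerPusaMultiFrames  # assumed import path

    scheduler = FlowMatchSchedulerPusaMultiFrames(shift=5.0, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=30)

    num_frames = 5
    clean_latents = torch.randn(1, 16, num_frames, 8, 8)
    noise = torch.randn_like(clean_latents)

    # Per-frame timesteps: first/last frames at the lowest scheduled noise level,
    # middle frames at the highest.
    per_frame_t = scheduler.timesteps[-1].repeat(1, num_frames)
    per_frame_t[0, 1:-1] = scheduler.timesteps[0]

    noisy_latents = scheduler.add_noise(clean_latents, noise, per_frame_t)
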
diff --git a/PusaV1/diffsynth/schedulers/flow_match_pusa_v2v.py b/PusaV1/diffsynth/schedulers/flow_match_pusa_v2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..190499ec4c4cdc807f32b4a336cb9280ed41660e
--- /dev/null
+++ b/PusaV1/diffsynth/schedulers/flow_match_pusa_v2v.py
@@ -0,0 +1,136 @@
+import torch
+
+class FlowMatchSchedulerPusaV2V():
+
+ def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
+ self.num_train_timesteps = num_train_timesteps
+ self.shift = shift
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.inverse_timesteps = inverse_timesteps
+ self.extra_one_step = extra_one_step
+ self.reverse_sigmas = reverse_sigmas
+ self.set_timesteps(num_inference_steps)
+
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
+ if shift is not None:
+ self.shift = shift
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+ if self.extra_one_step:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+ else:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+ if self.inverse_timesteps:
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+ if self.reverse_sigmas:
+ self.sigmas = 1 - self.sigmas
+ self.timesteps = self.sigmas * self.num_train_timesteps
+ if training:
+ x = self.timesteps
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+ y_shifted = y - y.min()
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+ self.linear_timesteps_weights = bsmntw_weighing
+
+
+ def step(self, model_output, timestep, sample, to_final=False, cond_frame_latent_indices=None, noise_multipliers=None, **kwargs):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ model_output = model_output.to(timestep.device)
+ sample = sample.to(timestep.device)
+
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ prev_sample = sample + model_output * (sigma_ - sigma)
+ else:
+ timestep = torch.ones_like(timestep) * timestep.max()
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+ # Handle sigma_ calculation for each timestep_id element
+ if to_final or torch.any(timestep_id + 1 >= len(self.timesteps)):
+ default_value = 1.0 if (self.inverse_timesteps or self.reverse_sigmas) else 0.0
+ # Create sigma_ with the same dtype as self.sigmas
+ sigma_ = torch.ones_like(timestep_id, dtype=self.sigmas.dtype, device=sample.device) * default_value
+ valid_indices = timestep_id + 1 < len(self.timesteps)
+ if torch.any(valid_indices):
+ # Convert indices to the appropriate type for indexing
+ valid_timestep_ids = timestep_id[valid_indices]
+ sigma_[valid_indices] = self.sigmas[(valid_timestep_ids + 1).to(torch.long)]
+ else:
+ sigma_ = self.sigmas[(timestep_id + 1).to(torch.long)]
+
+ if cond_frame_latent_indices is not None and noise_multipliers is not None:
+ for latent_idx in cond_frame_latent_indices:
+ multiplier = noise_multipliers.get(latent_idx, 1.0)
+ sigma[:,:,latent_idx] = sigma[:,:,latent_idx] * multiplier # timestep = sigma * 1000, equivalent, so directly use multiplier here
+ sigma_[latent_idx] = sigma_[latent_idx] * multiplier
+
+ sigma_ = sigma_.unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+
+ if torch.any(timestep == 0):
+ zero_indices = torch.where(timestep == 0)[1].to(torch.long)
+ sigma[:,:,zero_indices] = 0
+ print("sigma", sigma[0,0,:,0,0], '\n', "sigma_", sigma_[0,0,:,0,0])
+
+ prev_sample = sample + model_output * (sigma_ - sigma)
+
+
+
+ return prev_sample
+
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(sample.device)
+ model_output = (sample - sample_stablized) / sigma
+ return model_output
+
+
+ def add_noise(self, original_samples, noise, timestep):
+ if isinstance(timestep, torch.Tensor):
+ # timestep = timestep.cpu()
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.sigmas = self.sigmas.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep).abs(), dim=0)
+ sigma = self.sigmas[timestep_id].unsqueeze(0).unsqueeze(1).unsqueeze(3).unsqueeze(4).to(original_samples.device)
+ sample = (1 - sigma) * original_samples + sigma * noise
+
+ return sample
+
+
+ def training_target(self, sample, noise, timestep):
+ target = noise - sample
+ return target
+
+
+ def training_weight(self, timestep):
+ if isinstance(timestep, torch.Tensor):
+ self.timesteps = self.timesteps.to(timestep.device)
+ self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device)
+ if len(timestep.shape) == 1:
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+ else:
+ timestep_id = torch.argmin((self.timesteps.unsqueeze(1) - timestep.to(self.timesteps.device)).abs(), dim=0)
+ weights = self.linear_timesteps_weights[timestep_id].to(self.timesteps.device)
+ return weights
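
Note: the V2V variant's step() additionally accepts cond_frame_latent_indices and noise_multipliers, scaling the effective sigma on the selected frame latents (a multiplier of 0 leaves that frame's latent unchanged by the update). A hypothetical call sketch follows; the shapes, the two condition frames, and the multiplier values are illustrative assumptions.

    # Hypothetical conditioned-step sketch for FlowMatchSchedulerPusaV2V.
    import torch
    from diffsynth.schedulers.flow_match_pusa_v2v import FlowMatchSchedulerPusaV2V  # assuming PusaV1/ is on PYTHONPATH

    scheduler = FlowMatchSchedulerPusaV2V(shift=5.0, sigma_min=0.0, extra_one_step=True)
    scheduler.set_timesteps(num_inference_steps=30)

    num_frames = 5
    latents = torch.randn(1, 16, num_frames, 8, 8)
    model_output = torch.zeros_like(latents)

    timestep = scheduler.timesteps[0].repeat(1, num_frames)   # per-frame timesteps, shape [1, num_frames]
    latents = scheduler.step(
        model_output, timestep, latents,
        cond_frame_latent_indices=[0, 1],      # frame latents used as conditions
        noise_multipliers={0: 0.0, 1: 0.3},    # 0.0 freezes frame 0; 0.3 lets frame 1 drift slightly
    )
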
diff --git a/PusaV1/diffsynth/tokenizer_configs/__init__.py b/PusaV1/diffsynth/tokenizer_configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
new file mode 100644
index 0000000000000000000000000000000000000000..3f5132007c4fcf42b75b65c8b6aa49c7098bcdf4
--- /dev/null
+++ b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
@@ -0,0 +1,102 @@
+{
+ "": 32099,
+ "": 32089,
+ "": 32088,
+ "": 32087,
+ "": 32086,
+ "": 32085,
+ "": 32084,
+ "": 32083,
+ "": 32082,
+ "": 32081,
+ "": 32080,
+ "": 32098,
+ "": 32079,
+ "": 32078,
+ "": 32077,
+ "": 32076,
+ "": 32075,
+ "": 32074,
+ "": 32073,
+ "": 32072,
+ "": 32071,
+ "": 32070,
+ "": 32097,
+ "": 32069,
+ "": 32068,
+ "": 32067,
+ "": 32066,
+ "": 32065,
+ "": 32064,
+ "": 32063,
+ "": 32062,
+ "": 32061,
+ "": 32060,
+ "": 32096,
+ "": 32059,
+ "": 32058,
+ "": 32057,
+ "": 32056,
+ "": 32055,
+ "": 32054,
+ "": 32053,
+ "": 32052,
+ "": 32051,
+ "": 32050,
+ "": 32095,
+ "": 32049,
+ "": 32048,
+ "": 32047,
+ "": 32046,
+ "": 32045,
+ "": 32044,
+ "": 32043,
+ "": 32042,
+ "": 32041,
+ "": 32040,
+ "": 32094,
+ "": 32039,
+ "": 32038,
+ "": 32037,
+ "": 32036,
+ "": 32035,
+ "": 32034,
+ "": 32033,
+ "": 32032,
+ "": 32031,
+ "": 32030,
+ "": 32093,
+ "": 32029,
+ "": 32028,
+ "": 32027,
+ "": 32026,
+ "": 32025,
+ "": 32024,
+ "": 32023,
+ "": 32022,
+ "": 32021,
+ "": 32020,
+ "": 32092,
+ "": 32019,
+ "": 32018,
+ "": 32017,
+ "": 32016,
+ "": 32015,
+ "": 32014,
+ "": 32013,
+ "": 32012,
+ "": 32011,
+ "": 32010,
+ "": 32091,
+ "": 32009,
+ "": 32008,
+ "": 32007,
+ "": 32006,
+ "": 32005,
+ "": 32004,
+ "": 32003,
+ "": 32002,
+ "": 32001,
+ "": 32000,
+ "": 32090
+}
diff --git a/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..17ade346a1042cbe0c1436f5bedcbd85c099d582
--- /dev/null
+++ b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
@@ -0,0 +1,125 @@
+{
+ "additional_special_tokens": [
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ""
+ ],
+ "eos_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
new file mode 100644
index 0000000000000000000000000000000000000000..317a5ccbde45300f5d1d970d4d449af2108b147e
--- /dev/null
+++ b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+size 791656
diff --git a/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..161715af5ee99558c9fcce7b31d3d547a72c349b
--- /dev/null
+++ b/PusaV1/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
@@ -0,0 +1,940 @@
+{
+ "add_prefix_space": true,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "32000": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32001": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32002": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32003": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32004": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32005": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32006": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32007": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32008": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32009": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32010": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32011": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32012": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32013": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32014": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32015": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32016": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32017": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32018": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32019": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32020": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32021": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32022": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32023": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32024": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32025": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32026": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32027": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32028": {
+ "content": "",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": true,
+ "single_word": false,
+ "special": true
+ },
+ "32029": {
+ "content": "