diff --git a/PyTorch/contrib/cv/video/CogVideoX/LICENSE b/PyTorch/contrib/cv/video/CogVideoX/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f98e413cf55b56dbabdf693559055d79800f63c9 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. 
Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 CogVideo Model Team @ Zhipu AI + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/PyTorch/contrib/cv/video/CogVideoX/MODEL_LICENSE b/PyTorch/contrib/cv/video/CogVideoX/MODEL_LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3ca0c74848a189e77f466c542122bc09aab94381 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/MODEL_LICENSE @@ -0,0 +1,71 @@ +The CogVideoX License + +1. Definitions + +“Licensor” means the CogVideoX Model Team that distributes its Software. + +“Software” means the CogVideoX model parameters made available under this license. + +2. License Grant + +Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws. +This license allows you to freely use all open-source models in this repository for academic research. Users who wish to use the models for commercial purposes must register and obtain a basic commercial license in https://open.bigmodel.cn/mla/form . +Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month. +If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses. +The above copyright statement and this license statement should be included in all copies or significant portions of this software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. 
Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn. + +1. 定义 + +“许可方”是指分发其软件的 CogVideoX 模型团队。 + +“软件”是指根据本许可提供的 CogVideoX 模型参数。 + +2. 许可授予 + +根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。生成内容的知识产权所属,可根据适用当地法律的规定,在法律允许的范围内由用户享有生成内容的知识产权或其他权利。 +本许可允许您免费使用本仓库中的所有开源模型进行学术研究。对于希望将模型用于商业目的的用户,需在 https://open.bigmodel.cn/mla/form 完成登记并获得基础商用授权。 + +经过登记并获得基础商用授权的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。 +在本许可证下,您的商业活动的服务用户数量(访问量)不得超过100万人次访问 / 每月。如果超过,您需要与我们的商业团队联系以获得更多的商业许可。 +上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 + +3.限制 + +您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 + +您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 + +4.免责声明 + +本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 +在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 + +5. 责任限制 + +除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 + +6.争议解决 + +本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 + +请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/README.md b/PyTorch/contrib/cv/video/CogVideoX/README.md new file mode 100644 index 0000000000000000000000000000000000000000..33a1ea3f59387944566c1113a2393dd142bcf699 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/README.md @@ -0,0 +1,168 @@ +# CogVideoX for PyTorch + +- [概述](概述.md) +- [准备训练环境](准备训练环境.md) +- [开始训练](开始训练.md) +- [训练结果展示](训练结果展示.md) +- [版本说明](版本说明.md) + +# 概述 + +## 简述 + +CogVideoX-5B是一款最先进的文本到视频模型,可以从文本提示生成高质量的视频。利用3D因果VAE和专家Transformer架构,该模型确保时间上连贯和流畅的视频序列,使其非常适合复杂运动和详细语义生成。 + +- 参考实现: + + ``` + url=https://github.com/THUDM/CogVideo + commit_id=6a162073217bd41aa5436f11fc29ff69ef0b8623 + ``` +- 适配昇腾 AI 处理器的实现: + + ``` + url=https://gitee.com/ascend/ModelZoo-PyTorch.git + code_path=PyTorch/contrib/cv/video + ``` + +## 准备训练环境 + +## 准备环境 + +- 当前模型支持的 PyTorch 版本和已知三方依赖库如下表所示。 + + **表 1** 版本支持表 + + | Torch_Version | 三方库依赖版本 | + | :-----------: | :-------------------------: | + | Pytorch 2.4 | SwissArmyTransformer 0.4.12 | + | Pytorch 2.1 | SwissArmyTransformer 0.4.12 | +- 环境准备指导。 + + 请参考《[Pytorch框架训练环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes)》。 +- 安装依赖。 + + ``` + pip install -r requirements.txt + ``` + +> 注意arm机器安装triton和decord需要按照源码安装方式安装,以下为对应的源码仓库请根据对应readme进行安装。 +> +> *另外* ``CogVideoX/sat/arguments.py`` 及 ``CogVideoX/sat/vae_modules/regularizers.py`` 请到参考实现源码路径获取。 +> + + ``` + https://github.com/triton-lang/triton + https://github.com/dmlc/decord + ``` + +## 准备数据集 + +获取数据集。 + + 用户自行准备480 720分辨率、fps为8的视频,以及其对应的文本描述。在模型根目录下创建 `dataset` 目录,并放入数据集。 + + 数据集目录结构参考如下所示。 + + ``` + |-- dataset + |-- labels + |-- 1.txt + |-- 2.txt + |-- 3.txt + | ... + |-- videos + |-- 1.mp4 + |-- 2.mp4 + |-- 3.mp4 + | ... 
+ ``` + + > **说明:** + > 该数据集的训练过程脚本只作为一种参考示例 + > + +## 下载预训练模型 + +下载 `CogVideoX-5B` 和 `3d-vae` 模型权重,在模型根目录下创建 `CogVideoX-5b-sat` 目录,并将模型权重放置在该目录下。 + + ``` + wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1 + mv 'index.html?dl=1' vae.zip + unzip vae.zip + wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1 + mv 'index.html?dl=1' transformer.zip + unzip transformer.zip + ``` + +模型结构如下: + + ``` + |-- CogVideoX-5B-sat + |-- transformer + |-- 1 + |-- mp_rank_00_model_states.pt + |-- latest + |-- vae + |-- 3d-vae.pt + ``` + +下载 `t5` 模型用于编码文本,在模型根目录下创建 `t5-v1_1-xxl` 目录,并将预训练模型放置在该目录下。 + + ``` + git clone https://huggingface.co/THUDM/CogVideoX-2b.git + mkdir t5-v1_1-xxl + mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl + ``` + +模型结构如下: + + ``` + |-- t5-v1_1-xxl + |-- added_tokens.json + |-- config.json + |-- model-00001-of-00002.safetensors + |-- model-00002-of-00002.safetensors + |-- model.safetensors.index.json + |-- special_tokens_map.json + |-- spiece.model + |-- tokenizer_config.json + ``` + +# 开始训练 + +## 训练模型 + +1. 进入解压后的源码包根目录下的sat目录。 + ``` + cd /${模型文件夹名称}/sat + ``` +2. 运行训练脚本。 + + 该模型支持单机8卡训练。 + + - 单机8卡训练 + + ``` + bash ./finetune_multi_gpus.sh + ``` + + 训练完成后,权重文件保存在当前路径下。 + +# 训练结果展示 + +**表2** 训练结果展示表 + +| NAME | elapsed time per iteration (ms) | Torch_Version | +| :----: | :-----------------------------: | :-----------: | +| 8p-NPU | 25059.4 | 2.1 | + +# 版本说明 + +## 变更 + +2024.10.16:首次发布。 + +## FAQ + +无。 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/pyproject.toml b/PyTorch/contrib/cv/video/CogVideoX/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..09bc849926ce8a8868cf7853a4f58e9229176bfa --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/pyproject.toml @@ -0,0 +1,27 @@ +[tool.ruff] +line-length = 119 + +[tool.ruff.lint] +# Never enforce `E501` (line length violations). +ignore = ["C901", "E501", "E741", "F402", "F823"] +select = ["C", "E", "F", "I", "W"] + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" diff --git a/PyTorch/contrib/cv/video/CogVideoX/requirements.txt b/PyTorch/contrib/cv/video/CogVideoX/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..582888d2eb4d37626a449774387f2f799a942c62 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/requirements.txt @@ -0,0 +1,15 @@ +diffusers>=0.30.3 +accelerate>=0.34.2 +transformers>=4.44.2 +numpy==1.26.0 +torch>=2.1.0 +torchvision>=0.16.0 +sentencepiece>=0.2.0 +SwissArmyTransformer>=0.4.12 +gradio>=4.44.0 +imageio>=2.35.1 +imageio-ffmpeg>=0.5.1 +openai>=1.45.0 +moviepy>=1.0.3 +pillow==9.5.0 +scikit-video diff --git a/PyTorch/contrib/cv/video/CogVideoX/resources/WECHAT.md b/PyTorch/contrib/cv/video/CogVideoX/resources/WECHAT.md new file mode 100644 index 0000000000000000000000000000000000000000..7f9620d24141549da2a9a8c0edcb7fa3ada96b73 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/resources/WECHAT.md @@ -0,0 +1,7 @@ +
+ + +

扫码关注公众号,加入「 CogVideoX 交流群」

+

Scan the QR code to follow the official account and join the "CogVideoX Discussion Group"

+
+ diff --git a/PyTorch/contrib/cv/video/CogVideoX/resources/contribute.md b/PyTorch/contrib/cv/video/CogVideoX/resources/contribute.md new file mode 100644 index 0000000000000000000000000000000000000000..0d3640f5981004a6af19385325fa9edb8fe40590 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/resources/contribute.md @@ -0,0 +1,49 @@ +# Contribution Guide + +There may still be many incomplete aspects in this project. + +We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above +and are willing to submit a PR and share it with the community, upon review, we +will acknowledge your contribution on the project homepage. + +## Model Algorithms + +- Support for model quantization inference (Int4 quantization project) +- Optimization of model fine-tuning data loading (replacing the existing decord tool) + +## Model Engineering + +- Model fine-tuning examples / Best prompt practices +- Inference adaptation on different devices (e.g., MLX framework) +- Any tools related to the model +- Any minimal fully open-source project using the CogVideoX open-source model + +## Code Standards + +Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code +style. You can organize the code according to the following specifications: + +1. Install the `ruff` tool + +```shell +pip install ruff +``` + +Then, run the `ruff` tool + +```shell +ruff check tools sat inference +``` + +Check the code style. If there are issues, you can automatically fix them using the `ruff format` command. + +```shell +ruff format tools sat inference +``` + +Once your code meets the standard, there should be no errors. + +## Naming Conventions +1. Please use English names, do not use Pinyin or other language names. All comments should be in English. +2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c. + diff --git a/PyTorch/contrib/cv/video/CogVideoX/resources/contribute_ja.md b/PyTorch/contrib/cv/video/CogVideoX/resources/contribute_ja.md new file mode 100644 index 0000000000000000000000000000000000000000..80ddc275f3f40fde043ceacd474b12c805cabed4 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/resources/contribute_ja.md @@ -0,0 +1,47 @@ +# コントリビューションガイド + +本プロジェクトにはまだ多くの未完成の部分があります。 + +以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。 + +## モデルアルゴリズム + +- モデル量子化推論のサポート (Int4量子化プロジェクト) +- モデルのファインチューニングデータロードの最適化(既存のdecordツールの置き換え) + +## モデルエンジニアリング + +- モデルのファインチューニング例 / 最適なプロンプトの実践 +- 異なるデバイスでの推論適応(例: MLXフレームワーク) +- モデルに関連するツール +- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト + +## コード標準 + +良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml` +設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。 + +1. `ruff` ツールをインストールする + +```shell +pip install ruff +``` + +次に、`ruff` ツールを実行します + +```shell +ruff check tools sat inference +``` + +コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。 + +```shell +ruff format tools sat inference +``` + +コードが標準に準拠したら、エラーはなくなるはずです。 + +## 命名規則 + +1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。 +2. 
PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。 diff --git a/PyTorch/contrib/cv/video/CogVideoX/resources/contribute_zh.md b/PyTorch/contrib/cv/video/CogVideoX/resources/contribute_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..4b95254cd8b00b350a52322211bfe606f471f8ba --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/resources/contribute_zh.md @@ -0,0 +1,44 @@ +# 贡献指南 + +本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。 + +## 模型算法 + +- 模型量化推理支持 (Int4量化工程) +- 模型微调数据载入优化支持(替换现有的decord工具) + +## 模型工程 + +- 模型微调示例 / 最佳提示词实践 +- 不同设备上的推理适配(MLX等框架) +- 任何模型周边工具 +- 任何使用CogVideoX开源模型制作的最小完整开源项目 + +## 代码规范 + +良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码: + +1. 安装`ruff`工具 + +```shell +pip install ruff +``` + +接着,运行`ruff`工具 + +```shell +ruff check tools sat inference +``` + +检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。 + +```shell +ruff format tools sat inference +``` + +如果您的代码符合规范,应该不会出现任何的错误。 + +## 命名规范 + +- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。 +- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/resources/galary_prompt.md b/PyTorch/contrib/cv/video/CogVideoX/resources/galary_prompt.md new file mode 100644 index 0000000000000000000000000000000000000000..a738bb2a4c9f0dbfd2265ad03ea42120279d36d0 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/resources/galary_prompt.md @@ -0,0 +1,31 @@ +## CogVideoX-5B + +Videos 1-8: + +1. A garden comes to life as a kaleidoscope of butterflies flutters amidst the blossoms, their delicate wings casting shadows on the petals below. In the background, a grand fountain cascades water with a gentle splendor, its rhythmic sound providing a soothing backdrop. Beneath the cool shade of a mature tree, a solitary wooden chair invites solitude and reflection, its smooth surface worn by the touch of countless visitors seeking a moment of tranquility in nature's embrace. + +2. A small boy, head bowed and determination etched on his face, sprints through the torrential downpour as闪电 crackles and 雷鸣 rumbles in the distance. The relentless rain pounds the ground, creating a chaotic dance of water droplets that mirror the Dramatic sky's anger. In the far background, the silhouette of a cozy home beckons, a faint beacon of safety and warmth amidst the fierce weather. The scene is one of perseverance and the unyielding spirit of a child braving the elements. + +3. A suited astronaut, with the red dust of Mars clinging to their boots, reaches out to shake hands with an alien being, their skin a shimmering blue, under the pink-tinged sky of the fourth planet. In the background, a sleek silver rocket, a beacon of human ingenuity, stands tall, its engines powered down, as the two representatives of different worlds exchange a historic greeting amidst the desolate beauty of the Martian landscape. + +4. An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea. + +5. 
In a dimly lit bar, purplish light bathes the face of a mature man, his eyes blinking thoughtfully as he ponders in close-up, the background artfully blurred to focus on his introspective expression, the ambiance of the bar a mere suggestion of shadows and soft lighting. + +6. A golden retriever, sporting sleek black sunglasses, with its lengthy fur flowing in the breeze, sprints playfully across a rooftop terrace, recently refreshed by a light rain. The scene unfolds from a distance, the dog's energetic bounds growing larger as it approaches the camera, its tail wagging with unrestrained joy, while droplets of water glisten on the concrete behind it. The overcast sky provides a dramatic backdrop, emphasizing the vibrant golden coat of the canine as it dashes towards the viewer. + +7. On a brilliant sunny day, the lakeshore is lined with an array of willow trees, their slender branches swaying gently in the soft breeze. The tranquil surface of the lake reflects the clear blue sky, while several elegant swans glide gracefully through the still water, leaving behind delicate ripples that disturb the mirror-like quality of the lake. The scene is one of serene beauty, with the willows' greenery providing a picturesque frame for the peaceful avian visitors. + +8. A Chinese mother, draped in a soft, pastel-colored robe, gently rocks back and forth in a cozy rocking chair positioned in the tranquil setting of a nursery. The dimly lit bedroom is adorned with whimsical mobiles dangling from the ceiling, casting shadows that dance on the walls. Her baby, swaddled in a delicate, patterned blanket, rests against her chest, the child's earlier cries now replaced by contented coos as the mother's soothing voice lulls the little one to sleep. The scent of lavender fills the air, adding to the serene atmosphere, while a warm, orange glow from a nearby nightlight illuminates the scene with a gentle hue, capturing a moment of tender love and comfort. + +## CogVideoX-2B + +Videos 1-4: + +1. A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting. + +2. The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds. + +3. A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall. + +4. 
In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.

diff --git a/PyTorch/contrib/cv/video/CogVideoX/resources/logo.svg b/PyTorch/contrib/cv/video/CogVideoX/resources/logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..d0b8d44233658ea1293185b92a5d77e02fbe5484 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/resources/logo.svg @@ -0,0 +1,142 @@
(SVG markup of the CogVideoX logo, 142 added lines, omitted)

diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/README.md b/PyTorch/contrib/cv/video/CogVideoX/sat/README.md new file mode 100644 index 0000000000000000000000000000000000000000..48c45521d1eb7bdc9f9efc97ba54579569668937 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/README.md @@ -0,0 +1,440 @@

# SAT CogVideoX-2B

[中文阅读](./README_zh.md)

[日本語で読む](./README_ja.md)

This folder contains the inference code that uses [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, as well as the fine-tuning code for SAT weights.

This code is the framework the team used to train the model. It has few comments and requires careful study.

## Inference Model

### 1. Ensure that you have correctly installed the dependencies required by this folder.

```shell
pip install -r requirements.txt
```

### 2. Download the model weights

First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, download as follows:

```shell
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
mv 'index.html?dl=1' vae.zip
unzip vae.zip
wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```

For the CogVideoX-5B model, download the `transformer` files from the following links (the VAE files are the same as for 2B):

+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)

Next, arrange the model files into the following structure:

```
.
├── transformer
│   ├── 1000 (or 1)
│   │   └── mp_rank_00_model_states.pt
│   └── latest
└── vae
    └── 3d-vae.pt
```

Because the model weight files are large, using `git lfs` is recommended. Installation instructions for `git lfs` can be found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing).

Next, clone the T5 model. It is not trained or fine-tuned, but it is required as the text encoder.
> The T5 model is also available on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).

```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```

Following the steps above gives you a T5 folder in safetensors format; make sure it loads without errors during DeepSpeed fine-tuning. The expected directory contents are listed after the loading check sketched below.
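As a quick sanity check that the assembled `t5-v1_1-xxl` folder is readable, the minimal sketch below loads the tokenizer and the T5 encoder with the `transformers` library from `requirements.txt`. This is only an illustrative check, not part of the official pipeline; the local folder name and the test prompt are assumptions.

```python
# Minimal sketch (not part of the official pipeline): check that the assembled
# t5-v1_1-xxl folder loads. Assumes transformers, sentencepiece, and accelerate
# from requirements.txt; loading the 11B encoder needs roughly 20 GB of host memory.
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-v1_1-xxl")
print(tokenizer("a short smoke-test prompt").input_ids[:8])  # token ids from spiece.model

encoder = T5EncoderModel.from_pretrained(
    "t5-v1_1-xxl", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
print(encoder.config.d_model)  # expected: 4096, matching text_hidden_size in the YAML config
```

If both calls succeed, the directory layout shown next is what DeepSpeed will read during fine-tuning.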
+ +``` +├── added_tokens.json +├── config.json +├── model-00001-of-00002.safetensors +├── model-00002-of-00002.safetensors +├── model.safetensors.index.json +├── special_tokens_map.json +├── spiece.model +└── tokenizer_config.json + +0 directories, 8 files +``` + +### 3. Modify the file in `configs/cogvideox_2b.yaml`. + +```yaml +model: + scale_factor: 1.15258426 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 30 + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 1920 + adm_in_channels: 256 + num_attention_heads: 30 + + transformer_args: + checkpoint_activations: True ## using gradient checkpointing + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Basic3DPositionEmbeddingMixin + params: + text_length: 226 + height_interpolation: 1.875 + width_interpolation: 1.875 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + 
sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 +``` + +### 4. Modify the file in `configs/inference.yaml`. + +```yaml +args: + latent_channels: 16 + mode: inference + load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder + # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter + + batch_size: 1 + input_type: txt # You can choose txt for pure text input, or change to cli for command line input + input_file: configs/test.txt # Pure text file, which can be edited + sampling_num_frames: 13 # Must be 13, 11 or 9 + sampling_fps: 8 + fp16: True # For CogVideoX-2B + # bf16: True # For CogVideoX-5B + output_dir: outputs/ + force_inference: True +``` + ++ Modify `configs/test.txt` if multiple prompts is required, in which each line makes a prompt. ++ For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), for which you should set the + OPENAI_API_KEY as your environmental variable. ++ Modify `input_type` in `configs/inference.yaml` if use command line as prompt iuput. + +```yaml +input_type: cli +``` + +This allows input from the command line as prompts. + +Change `output_dir` if you wish to modify the address of the output video + +```yaml +output_dir: outputs/ +``` + +It is saved by default in the `.outputs/` folder. + +### 5. Run the inference code to perform inference. + +```shell +bash inference.sh +``` + +## Fine-tuning the Model + +### Preparing the Dataset + +The dataset format should be as follows: + +``` +. +├── labels +│   ├── 1.txt +│   ├── 2.txt +│   ├── ... +└── videos + ├── 1.mp4 + ├── 2.mp4 + ├── ... +``` + +Each text file shares the same name as its corresponding video, serving as the label for that video. Videos and labels +should be matched one-to-one. Generally, a single video should not be associated with multiple labels. + +For style fine-tuning, please prepare at least 50 videos and labels with similar styles to ensure proper fitting. + +### Modifying Configuration Files + +We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Please note that both methods only fine-tune +the `transformer` part and do not modify the `VAE` section. `T5` is used solely as an Encoder. 
Please modify +the `configs/sft.yaml` (for full-parameter fine-tuning) file as follows: + +``` + # checkpoint_activations: True ## Using gradient checkpointing (Both checkpoint_activations in the config file need to be set to True) + model_parallel_size: 1 # Model parallel size + experiment_name: lora-disney # Experiment name (do not modify) + mode: finetune # Mode (do not modify) + load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path + no_load_rng: True # Whether to load random seed + train_iters: 1000 # Training iterations + eval_iters: 1 # Evaluation iterations + eval_interval: 100 # Evaluation interval + eval_batch_size: 1 # Evaluation batch size + save: ckpts # Model save path + save_interval: 100 # Model save interval + log_interval: 20 # Log output interval + train_data: [ "your train data path" ] + valid_data: [ "your val data path" ] # Training and validation datasets can be the same + split: 1,0,0 # Training, validation, and test set ratio + num_workers: 8 # Number of worker threads for data loader + force_train: True # Allow missing keys when loading checkpoint (T5 and VAE are loaded separately) + only_log_video_latents: True # Avoid memory overhead caused by VAE decode + deepspeed: + bf16: + enabled: False # For CogVideoX-2B set to False and for CogVideoX-5B set to True + fp16: + enabled: True # For CogVideoX-2B set to True and for CogVideoX-5B set to False +``` + +If you wish to use Lora fine-tuning, you also need to modify the `cogvideox__lora` file: + +Here, take `CogVideoX-2B` as a reference: + +``` +model: + scale_factor: 1.15258426 + disable_first_stage_autocast: true + not_trainable_prefixes: [ 'all' ] ## Uncomment + log_keys: + - txt' + + lora_config: ## Uncomment + target: sat.model.finetune.lora2.LoraMixin + params: + r: 256 +``` + +### Modifying Run Scripts + +Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the configuration file. Below are two examples: + +1. If you want to use the `CogVideoX-2B` model and the `Lora` method, you need to modify `finetune_single_gpu.sh` + or `finetune_multi_gpus.sh`: + +``` +run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM" +``` + +2. If you want to use the `CogVideoX-2B` model and the `full-parameter fine-tuning` method, you need to + modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`: + +``` +run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM" +``` + +### Fine-Tuning and Evaluation + +Run the inference code to start fine-tuning. + +``` +bash finetune_single_gpu.sh # Single GPU +bash finetune_multi_gpus.sh # Multi GPUs +``` + +### Using the Fine-Tuned Model + +The fine-tuned model cannot be merged; here is how to modify the inference configuration file `inference.sh`: + +``` +run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42" +``` + +Then, execute the code: + +``` +bash inference.sh +``` + +### Converting to Huggingface Diffusers Supported Weights + +The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run: + +```shell +python ../tools/convert_weight_sat2hf.py +``` + +### Exporting Huggingface Diffusers lora LoRA Weights from SAT Checkpoints + +After completing the training using the above steps, we get a SAT checkpoint with LoRA weights. 
You can find the file +at `{args.save}/1000/1000/mp_rank_00_model_states.pt`. + +The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`. +After exporting, you can use `load_cogvideox_lora.py` for inference. + +Export command: + +```bash +python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ +``` + +This training mainly modified the following model structures. The table below lists the corresponding structure mappings +for converting to the HF (Hugging Face) format LoRA structure. As you can see, LoRA adds a low-rank weight to the +model's attention structure. + +``` +'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', +'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight', +'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight', +'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight', +'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight', +'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight', +'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight', +'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' +``` + +Using export_sat_lora_weight.py, you can convert the SAT checkpoint into the HF LoRA format. +![alt text](../resources/hf_lora_weights.png) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/README_ja.md b/PyTorch/contrib/cv/video/CogVideoX/sat/README_ja.md new file mode 100644 index 0000000000000000000000000000000000000000..ee1abcdb7d16c565f3c61bda26284eb3485b4b5e --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/README_ja.md @@ -0,0 +1,438 @@ +# SAT CogVideoX-2B + +[Read this in English.](./README_zh) + +[中文阅读](./README_zh.md) + +このフォルダには、[SAT](https://github.com/THUDM/SwissArmyTransformer) ウェイトを使用した推論コードと、SAT +ウェイトのファインチューニングコードが含まれています。 + +このコードは、チームがモデルをトレーニングするために使用したフレームワークです。コメントが少なく、注意深く研究する必要があります。 + +## 推論モデル + +### 1. このフォルダに必要な依存関係が正しくインストールされていることを確認してください。 + +```shell +pip install -r requirements.txt +``` + +### 2. モデルウェイトをダウンロードします + +まず、SAT ミラーに移動してモデルの重みをダウンロードします。 CogVideoX-2B モデルの場合は、次のようにダウンロードしてください。 + +```shell +mkdir CogVideoX-2b-sat +cd CogVideoX-2b-sat +wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1 +mv 'index.html?dl=1' vae.zip +unzip vae.zip +wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1 +mv 'index.html?dl=1' transformer.zip +unzip transformer.zip +``` + +CogVideoX-5B モデルの `transformers` ファイルを以下のリンクからダウンロードしてください (VAE ファイルは 2B と同じです): + ++ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list) ++ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list) + +次に、モデルファイルを以下の形式にフォーマットする必要があります: + +``` +. 
+├── transformer +│ ├── 1000 (or 1) +│ │ └── mp_rank_00_model_states.pt +│ └── latest +└── vae + └── 3d-vae.pt +``` + +モデルの重みファイルが大きいため、`git lfs`を使用することをお勧めいたします。`git lfs` +のインストールについては、[こちら](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)をご参照ください。 + +```shell +git lfs install +``` + +次に、T5 モデルをクローンします。これはトレーニングやファインチューニングには使用されませんが、使用する必要があります。 +> モデルを複製する際には、[Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)のモデルファイルの場所もご使用いただけます。 + +```shell +git clone https://huggingface.co/THUDM/CogVideoX-2b.git #ハギングフェイス(huggingface.org)からモデルをダウンロードいただきます +# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git #Modelscopeからモデルをダウンロードいただきます +mkdir t5-v1_1-xxl +mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl +``` + +上記の方法に従うことで、safetensor 形式の T5 ファイルを取得できます。これにより、Deepspeed でのファインチューニング中にエラーが発生しないようにします。 + +``` +├── added_tokens.json +├── config.json +├── model-00001-of-00002.safetensors +├── model-00002-of-00002.safetensors +├── model.safetensors.index.json +├── special_tokens_map.json +├── spiece.model +└── tokenizer_config.json + +0 directories, 8 files +``` + +### 3. `configs/cogvideox_2b.yaml` ファイルを変更します。 + +```yaml +model: + scale_factor: 1.15258426 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 30 + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 1920 + adm_in_channels: 256 + num_attention_heads: 30 + + transformer_args: + checkpoint_activations: True ## グラデーション チェックポイントを使用する + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Basic3DPositionEmbeddingMixin + params: + text_length: 226 + height_interpolation: 1.875 + width_interpolation: 1.875 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxlフォルダの絶対パス + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.ptフォルダの絶対パス + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + 
ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 +``` + +### 4. `configs/inference.yaml` ファイルを変更します。 + +```yaml +args: + latent_channels: 16 + mode: inference + load: "{absolute_path/to/your}/transformer" # CogVideoX-2b-sat/transformerフォルダの絶対パス + # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter + + batch_size: 1 + input_type: txt #TXTのテキストファイルを入力として選択されたり、CLIコマンドラインを入力として変更されたりいただけます + input_file: configs/test.txt #テキストファイルのパスで、これに対して編集がさせていただけます + sampling_num_frames: 13 # Must be 13, 11 or 9 + sampling_fps: 8 + fp16: True # For CogVideoX-2B + # bf16: True # For CogVideoX-5B + output_dir: outputs/ + force_inference: True +``` + ++ 複数のプロンプトを保存するために txt を使用する場合は、`configs/test.txt` + を参照して変更してください。1行に1つのプロンプトを記述します。プロンプトの書き方がわからない場合は、最初に [このコード](../inference/convert_demo.py) + を使用して LLM によるリファインメントを呼び出すことができます。 ++ コマンドラインを入力として使用する場合は、次のように変更します。 + +```yaml +input_type: cli +``` + +これにより、コマンドラインからプロンプトを入力できます。 + +出力ビデオのディレクトリを変更したい場合は、次のように変更できます: + +```yaml +output_dir: outputs/ +``` + +デフォルトでは `.outputs/` フォルダに保存されます。 + +### 5. 推論コードを実行して推論を開始します。 + +```shell +bash inference.sh +``` + +## モデルのファインチューニング + +### データセットの準備 + +データセットの形式は次のようになります: + +``` +. +├── labels +│ ├── 1.txt +│ ├── 2.txt +│ ├── ... +└── videos + ├── 1.mp4 + ├── 2.mp4 + ├── ... 
+``` + +各 txt ファイルは対応するビデオファイルと同じ名前であり、そのビデオのラベルを含んでいます。各ビデオはラベルと一対一で対応する必要があります。通常、1つのビデオに複数のラベルを持たせることはありません。 + +スタイルファインチューニングの場合、少なくとも50本のスタイルが似たビデオとラベルを準備し、フィッティングを容易にします。 + +### 設定ファイルの変更 + +`Lora` とフルパラメータ微調整の2つの方法をサポートしています。両方の微調整方法は、`transformer` 部分のみを微調整し、`VAE` +部分には変更を加えないことに注意してください。`T5` はエンコーダーとしてのみ使用されます。以下のように `configs/sft.yaml` ( +フルパラメータ微調整用) ファイルを変更してください。 + +``` + # checkpoint_activations: True ## 勾配チェックポイントを使用する場合 (設定ファイル内の2つの checkpoint_activations を True に設定する必要があります) + model_parallel_size: 1 # モデル並列サイズ + experiment_name: lora-disney # 実験名 (変更しないでください) + mode: finetune # モード (変更しないでください) + load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer モデルのパス + no_load_rng: True # 乱数シードを読み込むかどうか + train_iters: 1000 # トレーニングイテレーション数 + eval_iters: 1 # 評価イテレーション数 + eval_interval: 100 # 評価間隔 + eval_batch_size: 1 # 評価バッチサイズ + save: ckpts # モデル保存パス + save_interval: 100 # モデル保存間隔 + log_interval: 20 # ログ出力間隔 + train_data: [ "your train data path" ] + valid_data: [ "your val data path" ] # トレーニングデータと評価データは同じでも構いません + split: 1,0,0 # トレーニングセット、評価セット、テストセットの割合 + num_workers: 8 # データローダーのワーカースレッド数 + force_train: True # チェックポイントをロードするときに欠落したキーを許可 (T5 と VAE は別々にロードされます) + only_log_video_latents: True # VAE のデコードによるメモリオーバーヘッドを回避 + deepspeed: + bf16: + enabled: False # CogVideoX-2B の場合は False に設定し、CogVideoX-5B の場合は True に設定 + fp16: + enabled: True # CogVideoX-2B の場合は True に設定し、CogVideoX-5B の場合は False に設定 +``` + +Lora 微調整を使用したい場合は、`cogvideox__lora` ファイルも変更する必要があります。 + +ここでは、`CogVideoX-2B` を参考にします。 + +``` +model: + scale_factor: 1.15258426 + disable_first_stage_autocast: true + not_trainable_prefixes: [ 'all' ] ## コメントを解除 + log_keys: + - txt' + + lora_config: ## コメントを解除 + target: sat.model.finetune.lora2.LoraMixin + params: + r: 256 +``` + +### 実行スクリプトの変更 + +設定ファイルを選択するために `finetune_single_gpu.sh` または `finetune_multi_gpus.sh` を編集します。以下に2つの例を示します。 + +1. `CogVideoX-2B` モデルを使用し、`Lora` 手法を利用する場合は、`finetune_single_gpu.sh` または `finetune_multi_gpus.sh` + を変更する必要があります。 + +``` +run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM" +``` + +2. 
`CogVideoX-2B` モデルを使用し、`フルパラメータ微調整` 手法を利用する場合は、`finetune_single_gpu.sh` + または `finetune_multi_gpus.sh` を変更する必要があります。 + +``` +run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM" +``` + +### 微調整と評価 + +推論コードを実行して微調整を開始します。 + +``` +bash finetune_single_gpu.sh # シングルGPU +bash finetune_multi_gpus.sh # マルチGPU +``` + +### 微調整後のモデルの使用 + +微調整されたモデルは統合できません。ここでは、推論設定ファイル `inference.sh` を変更する方法を示します。 + +``` +run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42" +``` + +その後、次のコードを実行します。 + +``` +bash inference.sh +``` + +### Huggingface Diffusers サポートのウェイトに変換 + +SAT ウェイト形式は Huggingface のウェイト形式と異なり、変換が必要です。次のコマンドを実行してください: + +```shell +python ../tools/convert_weight_sat2hf.py +``` + +### SATチェックポイントからHuggingface Diffusers lora LoRAウェイトをエクスポート + +上記のステップを完了すると、LoRAウェイト付きのSATチェックポイントが得られます。ファイルは `{args.save}/1000/1000/mp_rank_00_model_states.pt` にあります。 + +LoRAウェイトをエクスポートするためのスクリプトは、CogVideoXリポジトリの `tools/export_sat_lora_weight.py` にあります。エクスポート後、`load_cogvideox_lora.py` を使用して推論を行うことができます。 + +エクスポートコマンド: + +```bash +python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ +``` + +このトレーニングでは主に以下のモデル構造が変更されました。以下の表は、HF (Hugging Face) 形式のLoRA構造に変換する際の対応関係を示しています。ご覧の通り、LoRAはモデルの注意メカニズムに低ランクの重みを追加しています。 + +``` +'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', +'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight', +'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight', +'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight', +'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight', +'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight', +'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight', +'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' +``` + +export_sat_lora_weight.py を使用して、SATチェックポイントをHF LoRA形式に変換できます。 + + +![alt text](../resources/hf_lora_weights.png) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/README_zh.md b/PyTorch/contrib/cv/video/CogVideoX/sat/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..c605da80d05f5828fb08836c2385c7d1f9af1a93 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/README_zh.md @@ -0,0 +1,436 @@ +# SAT CogVideoX-2B + +[Read this in English.](./README_zh) + +[日本語で読む](./README_ja.md) + +本文件夹包含了使用 [SAT](https://github.com/THUDM/SwissArmyTransformer) 权重的推理代码,以及 SAT 权重的微调代码。 + +该代码是团队训练模型时使用的框架。注释较少,需要认真研究。 + +## 推理模型 + +### 1. 确保你已经正确安装本文件夹中的要求的依赖 + +```shell +pip install -r requirements.txt +``` + +### 2. 下载模型权重 + +首先,前往 SAT 镜像下载模型权重。 + +对于 CogVideoX-2B 模型,请按照如下方式下载: + +```shell +mkdir CogVideoX-2b-sat +cd CogVideoX-2b-sat +wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1 +mv 'index.html?dl=1' vae.zip +unzip vae.zip +wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1 +mv 'index.html?dl=1' transformer.zip +unzip transformer.zip +``` + +请按如下链接方式下载 CogVideoX-5B 模型的 `transformers` 文件(VAE 文件与 2B 相同): + ++ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list) ++ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list) + +接着,你需要将模型文件排版成如下格式: + +``` +. 
+├── transformer +│ ├── 1000 (or 1) +│ │ └── mp_rank_00_model_states.pt +│ └── latest +└── vae + └── 3d-vae.pt +``` + +由于模型的权重档案较大,建议使用`git lfs`。`git lfs` +安装参见[这里](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) + +```shell +git lfs install +``` + +接着,克隆 T5 模型,该模型不用做训练和微调,但是必须使用。 +> 克隆模型的时候也可以使用[Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)上的模型文件位置。 + +```shell +git clone https://huggingface.co/THUDM/CogVideoX-2b.git #从huggingface下载模型 +# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git #从modelscope下载模型 +mkdir t5-v1_1-xxl +mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl +``` + +通过上述方案,你将会得到一个 safetensor 格式的T5文件,确保在 Deepspeed微调过程中读入的时候不会报错。 + +``` +├── added_tokens.json +├── config.json +├── model-00001-of-00002.safetensors +├── model-00002-of-00002.safetensors +├── model.safetensors.index.json +├── special_tokens_map.json +├── spiece.model +└── tokenizer_config.json + +0 directories, 8 files +``` + +### 3. 修改`configs/cogvideox_2b.yaml`中的文件。 + +```yaml +model: + scale_factor: 1.15258426 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 30 + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 1920 + adm_in_channels: 256 + num_attention_heads: 30 + + transformer_args: + checkpoint_activations: True ## using gradient checkpointing + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Basic3DPositionEmbeddingMixin + params: + text_length: 226 + height_interpolation: 1.875 + width_interpolation: 1.875 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxl 权重文件夹的绝对路径 + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.pt文件夹的绝对路径 + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: 
vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 +``` + +### 4. 修改`configs/inference.yaml`中的文件。 + +```yaml +args: + latent_channels: 16 + mode: inference + load: "{absolute_path/to/your}/transformer" # CogVideoX-2b-sat/transformer文件夹的绝对路径 + # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter + + batch_size: 1 + input_type: txt #可以选择txt纯文字档作为输入,或者改成cli命令行作为输入 + input_file: configs/test.txt #纯文字档,可以对此做编辑 + sampling_num_frames: 13 # Must be 13, 11 or 9 + sampling_fps: 8 + fp16: True # For CogVideoX-2B + # bf16: True # For CogVideoX-5B + output_dir: outputs/ + force_inference: True +``` + ++ 如果使用 txt 保存多个提示词,请参考`configs/test.txt` + 进行修改。每一行一个提示词。如果您不知道如何书写提示词,可以先使用[此代码](../inference/convert_demo.py)调用 LLM进行润色。 ++ 如果使用命令行作为输入,请修改 + +```yaml +input_type: cli +``` + +这样就可以从命令行输入提示词。 + +如果你希望修改输出视频的地址,你可以修改: + +```yaml +output_dir: outputs/ +``` + +默认保存在`.outputs/`文件夹下。 + +### 5. 运行推理代码, 即可推理 + +```shell +bash inference.sh +``` + +## 微调模型 + +### 准备数据集 + +数据集格式应该如下: + +``` +. +├── labels +│ ├── 1.txt +│ ├── 2.txt +│ ├── ... +└── videos + ├── 1.mp4 + ├── 2.mp4 + ├── ... 
+```
+
+每个 txt 文件与对应的视频同名，内容为该视频的标签。视频与标签应当一一对应，通常不使用一个视频对应多个标签。
+
+如果为风格微调，请准备至少 50 条风格相似的视频和标签，以利于拟合。
+
+### 修改配置文件
+
+我们支持 `Lora` 和全参数微调两种方式。请注意，两种微调方式都仅对 `transformer` 部分进行微调，不改动 `VAE` 部分，`T5` 仅作为 Encoder 使用。请按照以下方式修改 `configs/sft.yaml`（全量微调）中的文件。
+
+```yaml
+  # checkpoint_activations: True ## using gradient checkpointing (配置文件中的两个 checkpoint_activations 都需要设置为 True)
+  model_parallel_size: 1 # 模型并行大小
+  experiment_name: lora-disney # 实验名称(不要改动)
+  mode: finetune # 模式(不要改动)
+  load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
+  no_load_rng: True # 是否加载随机数种子
+  train_iters: 1000 # 训练迭代次数
+  eval_iters: 1 # 验证迭代次数
+  eval_interval: 100 # 验证间隔
+  eval_batch_size: 1 # 验证集 batch size
+  save: ckpts # 模型保存路径
+  save_interval: 100 # 模型保存间隔
+  log_interval: 20 # 日志输出间隔
+  train_data: [ "your train data path" ]
+  valid_data: [ "your val data path" ] # 训练集和验证集可以相同
+  split: 1,0,0 # 训练集，验证集，测试集比例
+  num_workers: 8 # 数据加载器的工作线程数
+  force_train: True # 在加载 checkpoint 时允许 missing keys (T5 和 VAE 单独加载)
+  only_log_video_latents: True # 避免 VAE decode 带来的显存开销
+  deepspeed:
+    bf16:
+      enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
+    fp16:
+      enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
+```
+
+如果你希望使用 Lora 微调，你还需要修改 `cogvideox_<模型参数>_lora` 文件：
+
+这里以 `CogVideoX-2B` 为参考：
+
+```yaml
+model:
+  scale_factor: 1.15258426
+  disable_first_stage_autocast: true
+  not_trainable_prefixes: [ 'all' ] ## 解除注释
+  log_keys:
+    - txt
+
+  lora_config: ## 解除注释
+    target: sat.model.finetune.lora2.LoraMixin
+    params:
+      r: 256
+```
+
+### 修改运行脚本
+
+编辑 `finetune_single_gpu.sh` 或者 `finetune_multi_gpus.sh`，选择配置文件。下面是两个例子：
+
+1. 如果您想使用 `CogVideoX-2B` 模型并使用 `Lora` 方案，您需要修改 `finetune_single_gpu.sh` 或者 `finetune_multi_gpus.sh`：
+
+```
+run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
+```
+
+2. 
如果您想使用 `CogVideoX-2B` 模型并使用`全量微调`方案,您需要修改`finetune_single_gpu.sh` + 或者 `finetune_multi_gpus.sh`: + +``` +run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM" +``` + +### 微调和验证 + +运行推理代码,即可开始微调。 + +```shell +bash finetune_single_gpu.sh # Single GPU +bash finetune_multi_gpus.sh # Multi GPUs +``` + +### 使用微调后的模型 + +微调后的模型无法合并,这里展现了如何修改推理配置文件 `inference.sh` + +``` +run_cmd="$environs python sample_video.py --base configs/cogvideox_<模型参数>_lora.yaml configs/inference.yaml --seed 42" +``` + +然后,执行代码: + +``` +bash inference.sh +``` + +### 转换到 Huggingface Diffusers 库支持的权重 + +SAT 权重格式与 Huggingface 的权重格式不同,需要转换。请运行 + +```shell +python ../tools/convert_weight_sat2hf.py +``` + +### 从SAT权重文件 导出Huggingface Diffusers lora权重 + +支持了从SAT权重文件 +在经过上面这些步骤训练之后,我们得到了一个sat带lora的权重,在{args.save}/1000/1000/mp_rank_00_model_states.pt你可以看到这个文件 + +导出的lora权重脚本在CogVideoX仓库 tools/export_sat_lora_weight.py ,导出后使用 load_cogvideox_lora.py 推理 + +导出命令: + +``` +python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ +``` + +这次训练主要修改了下面几个模型结构,下面列出了 转换为HF格式的lora结构对应关系,可以看到lora将模型注意力结构上增加一个低秩权重, + +``` +'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', +'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight', +'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight', +'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight', +'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight', +'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight', +'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight', +'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' +``` + +通过export_sat_lora_weight.py将它转换为HF格式的lora结构 +![alt text](../resources/hf_lora_weights.png) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_2b.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_2b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f142b6272171711630c9f297486db9a793919e27 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_2b.yaml @@ -0,0 +1,154 @@ +model: + scale_factor: 1.15258426 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 30 + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 1920 + adm_in_channels: 256 + num_attention_heads: 30 + + transformer_args: + checkpoint_activations: True ## using gradient checkpointing + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Basic3DPositionEmbeddingMixin + params: + text_length: 226 + height_interpolation: 1.875 + 
width_interpolation: 1.875 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-2b-sat/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_2b_lora.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_2b_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af04479aef685df6730c75ee0fa34af96b392971 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_2b_lora.yaml @@ -0,0 +1,160 @@ +model: + scale_factor: 1.15258426 + disable_first_stage_autocast: true + not_trainable_prefixes: ['all'] ## Using Lora + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 30 + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 1920 + adm_in_channels: 256 + num_attention_heads: 30 + 
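+      # Note: the 2B DiT uses 30 transformer layers of width 1920 with 30 attention heads; the 49
+      # input frames are compressed 4x in time (time_compressed_rate) to 13 latent frames over a
+      # 60x90 latent grid before patchification (patch_size 2).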
+ transformer_args: + checkpoint_activations: True ## using gradient checkpointing + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Basic3DPositionEmbeddingMixin + params: + text_length: 226 + height_interpolation: 1.875 + width_interpolation: 1.875 + + lora_config: + target: sat.model.finetune.lora2.LoraMixin + params: + r: 128 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-2b-sat/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22ba6949433aeb6e4a9169a30a44449130c4e4d2 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b.yaml @@ -0,0 +1,153 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + 
params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 # different from cogvideox_2b_infer.yaml + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 3072 # different from cogvideox_2b_infer.yaml + adm_in_channels: 256 + num_attention_heads: 48 # different from cogvideox_2b_infer.yaml + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml + params: + hidden_size_head: 64 + text_length: 226 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_i2v.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_i2v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4baf9637d782584e880e2fd559c62e50688e3192 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_i2v.yaml 
@@ -0,0 +1,159 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + latent_input: false + noised_image_input: true + noised_image_dropout: 0.05 + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 + patch_size: 2 + in_channels: 32 #different from cogvideox_5b_infer.yaml + out_channels: 16 + hidden_size: 3072 + adm_in_channels: 256 + num_attention_heads: 48 + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin + params: + learnable_pos_embed: True + hidden_size_head: 64 + text_length: 226 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt" + ignore_keys: ['loss'] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + fixed_frames: 0 + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + fixed_frames: 0 + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + guider_config: + target: 
sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_i2v_lora.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_i2v_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e36aee7b463aa98ce95f3485f490bb12369f80b6 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_i2v_lora.yaml @@ -0,0 +1,165 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + latent_input: false + noised_image_input: true + noised_image_dropout: 0.05 + not_trainable_prefixes: ['all'] ## Using Lora + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 + patch_size: 2 + in_channels: 32 + out_channels: 16 + hidden_size: 3072 + adm_in_channels: 256 + num_attention_heads: 48 + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin + params: + learnable_pos_embed: True + hidden_size_head: 64 + text_length: 226 + + lora_config: + target: sat.model.finetune.lora2.LoraMixin + params: + r: 256 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + fixed_frames: 0 + offset_noise_level: 0 + 
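+      # During fine-tuning, noise levels are sampled uniformly over the 1000-index
+      # zero-terminal-SNR DDPM schedule configured below.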
sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + fixed_frames: 0 + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_lora.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f3ac66c35023e1c9377b501fcebe51a172a2ba58 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_lora.yaml @@ -0,0 +1,159 @@ +model: + scale_factor: 0.7 # different from cogvideox_2b_infer.yaml + disable_first_stage_autocast: true + not_trainable_prefixes: ['all'] # Using Lora + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 # different from cogvideox_2b_infer.yaml + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 3072 # different from cogvideox_2b_infer.yaml + adm_in_channels: 256 + num_attention_heads: 48 # different from cogvideox_2b_infer.yaml + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml + params: + hidden_size_head: 64 + text_length: 226 + + lora_config: # Using Lora + target: sat.model.finetune.lora2.LoraMixin + params: + r: 128 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: 
vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_sft.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_sft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9994f3165ef1437fec277e2c625c85a952f983d3 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/cogvideox_5b_sft.yaml @@ -0,0 +1,214 @@ +args: + checkpoint_activations: True # using gradient checkpointing + model_parallel_size: 1 + experiment_name: finetune-5b-npu + mode: finetune + load: "../CogVideoX-5b-sat/transformer" + no_load_rng: True + train_iters: 1000 # Suggest more than 1000 For Lora and SFT For 500 is enough + eval_iters: 1 + eval_interval: 100 + eval_batch_size: 1 + save: ckpts_5b_sft + save_interval: 500 + log_interval: 1 + train_data: [ "../dataset" ] # Train data path + valid_data: [ "../dataset" ] # Validation data path, can be the same as train_data(not recommended) + split: 1,0,0 + num_workers: 8 + force_train: True + only_log_video_latents: True + +data: + target: data_video.SFTDataset + params: + video_size: [ 480, 720 ] + fps: 8 + max_num_frames: 49 + skip_frms_num: 3. 
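+    # SFTDataset (data_video.py) walks the data path for videos/*.mp4 and reads each caption from the
+    # same-named file under labels/; skip_frms_num drops the first/last frames to avoid scene transitions.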
+ +deepspeed: + # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs + train_micro_batch_size_per_gpu: 1 + gradient_accumulation_steps: 1 + steps_per_print: 1 + gradient_clipping: 0.1 + zero_optimization: + stage: 2 + cpu_offload: false + contiguous_gradients: false + overlap_comm: true + reduce_scatter: true + reduce_bucket_size: 1000000000 + allgather_bucket_size: 1000000000 + load_from_fp32_weights: false + zero_allow_untested_optimizer: true + bf16: + enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True + fp16: + enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False + loss_scale: 0 + loss_scale_window: 400 + hysteresis: 2 + min_loss_scale: 1 + + optimizer: + type: AdamW #sat.ops.FusedEmaAdam + params: + lr: 0.00002 # Between 1E-3 and 5E-4 For Lora and 1E-5 For SFT + betas: [ 0.9, 0.95 ] + eps: 1e-8 + weight_decay: 1e-4 + activation_checkpointing: + partition_activations: false + contiguous_memory_optimization: false + wall_clock_breakdown: false + +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 # different from cogvideox_2b_infer.yaml + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 3072 # different from cogvideox_2b_infer.yaml + adm_in_channels: 256 + num_attention_heads: 48 # different from cogvideox_2b_infer.yaml + + transformer_args: + use_gpu_initialization: True + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml + params: + hidden_size_head: 64 + text_length: 226 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "../t5-v1_1-xxl" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "../CogVideox-5b-sat/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + 
+ decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/inference.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a93bb997862a1844456df3a0b6a67c39143c6f22 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/inference.yaml @@ -0,0 +1,16 @@ +args: + image2video: False # True for image2video, False for text2video + latent_channels: 16 + mode: inference + load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter + # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter + batch_size: 1 + input_type: txt + input_file: configs/test.txt + sampling_image_size: [480, 720] + sampling_num_frames: 13 # Must be 13, 11 or 9 + sampling_fps: 8 +# fp16: True # For CogVideoX-2B + bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V + output_dir: outputs/ + force_inference: True \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/sft.yaml b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/sft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..971c52142f533b3b06dc28cae4e7057ea517072b --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/sft.yaml @@ -0,0 +1,65 @@ +args: + checkpoint_activations: True # using gradient checkpointing + model_parallel_size: 1 + experiment_name: lora-disney + mode: finetune + load: "{your CogVideoX SAT folder}/transformer" + no_load_rng: True + train_iters: 1000 # Suggest more than 1000 For Lora and SFT For 500 is enough + eval_iters: 1 + eval_interval: 100 + eval_batch_size: 1 + save: ckpts_5b_lora + save_interval: 500 + log_interval: 20 + train_data: [ "disney" ] # Train data path + valid_data: [ "disney" ] # Validation data path, can be the same as train_data(not recommended) + split: 1,0,0 + num_workers: 8 + force_train: True + only_log_video_latents: True + +data: + target: data_video.SFTDataset + params: + video_size: [ 480, 720 ] + fps: 8 + max_num_frames: 49 + skip_frms_num: 3. 
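+    # Each sample is 49 frames at 8 fps (about 6 s of video); shorter clips are padded by repeating the
+    # last frame (pad_last_frame in data_video.py), since the 3D VAE expects a 4k+1 frame count.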
+ +deepspeed: + # Minimum for 16 videos per batch for ALL GPUs, This setting is for 8 x A100 GPUs + train_micro_batch_size_per_gpu: 2 + gradient_accumulation_steps: 1 + steps_per_print: 50 + gradient_clipping: 0.1 + zero_optimization: + stage: 2 + cpu_offload: false + contiguous_gradients: false + overlap_comm: true + reduce_scatter: true + reduce_bucket_size: 1000000000 + allgather_bucket_size: 1000000000 + load_from_fp32_weights: false + zero_allow_untested_optimizer: true + bf16: + enabled: True # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True + fp16: + enabled: False # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False + loss_scale: 0 + loss_scale_window: 400 + hysteresis: 2 + min_loss_scale: 1 + + optimizer: + type: sat.ops.FusedEmaAdam + params: + lr: 0.00001 # Between 1E-3 and 5E-4 For Lora and 1E-5 For SFT + betas: [ 0.9, 0.95 ] + eps: 1e-8 + weight_decay: 1e-4 + activation_checkpointing: + partition_activations: false + contiguous_memory_optimization: false + wall_clock_breakdown: false \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/configs/test.txt b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/test.txt new file mode 100644 index 0000000000000000000000000000000000000000..8d035c06d5129db4fef4fb491f3454525a4e9964 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/configs/test.txt @@ -0,0 +1,4 @@ +In the haunting backdrop of a warIn the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict. +The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds. +A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting. +A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. 
Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict. \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/data_video.py b/PyTorch/contrib/cv/video/CogVideoX/sat/data_video.py new file mode 100644 index 0000000000000000000000000000000000000000..b572d8349ad3d35ef838043a18b89e802c6818e6 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/data_video.py @@ -0,0 +1,452 @@ +import io +import os +import sys +from functools import partial +import math +import torchvision.transforms as TT +from sgm.webds import MetaDistributedWebDataset +import random +from fractions import Fraction +from typing import Union, Optional, Dict, Any, Tuple +from torchvision.io.video import av +import numpy as np +import torch +from torchvision.io import _video_opt +from torchvision.io.video import _check_av_available, _read_from_stream, _align_audio_frames +from torchvision.transforms.functional import center_crop, resize +from torchvision.transforms import InterpolationMode +import decord +from decord import VideoReader +from torch.utils.data import Dataset + + +def read_video( + filename: str, + start_pts: Union[float, Fraction] = 0, + end_pts: Optional[Union[float, Fraction]] = None, + pts_unit: str = "pts", + output_format: str = "THWC", +) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + """ + Reads a video from a file, returning both the video frames and the audio frames + + Args: + filename (str): path to the video file + start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The start presentation time of the video + end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional): + The end presentation time + pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted, + either 'pts' or 'sec'. Defaults to 'pts'. + output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW". + + Returns: + vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames + aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points + info (Dict): metadata for the video and audio. 
Can contain the fields video_fps (float) and audio_fps (int) + """ + + output_format = output_format.upper() + if output_format not in ("THWC", "TCHW"): + raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") + + _check_av_available() + + if end_pts is None: + end_pts = float("inf") + + if end_pts < start_pts: + raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") + + info = {} + audio_frames = [] + audio_timebase = _video_opt.default_timebase + + with av.open(filename, metadata_errors="ignore") as container: + if container.streams.audio: + audio_timebase = container.streams.audio[0].time_base + if container.streams.video: + video_frames = _read_from_stream( + container, + start_pts, + end_pts, + pts_unit, + container.streams.video[0], + {"video": 0}, + ) + video_fps = container.streams.video[0].average_rate + # guard against potentially corrupted files + if video_fps is not None: + info["video_fps"] = float(video_fps) + + if container.streams.audio: + audio_frames = _read_from_stream( + container, + start_pts, + end_pts, + pts_unit, + container.streams.audio[0], + {"audio": 0}, + ) + info["audio_fps"] = container.streams.audio[0].rate + + aframes_list = [frame.to_ndarray() for frame in audio_frames] + + vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) + + if aframes_list: + aframes = np.concatenate(aframes_list, 1) + aframes = torch.as_tensor(aframes) + if pts_unit == "sec": + start_pts = int(math.floor(start_pts * (1 / audio_timebase))) + if end_pts != float("inf"): + end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) + aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) + else: + aframes = torch.empty((1, 0), dtype=torch.float32) + + if output_format == "TCHW": + # [T,H,W,C] --> [T,C,H,W] + vframes = vframes.permute(0, 3, 1, 2) + + return vframes, aframes, info + + +def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + +def pad_last_frame(tensor, num_frames): + # T, H, W, C + if len(tensor) < num_frames: + pad_length = num_frames - len(tensor) + # Use the last frame to pad instead of zero + last_frame = tensor[-1] + pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:]) + padded_tensor = torch.cat([tensor, pad_tensor], dim=0) + return padded_tensor + else: + return tensor[:num_frames] + + +def load_video( + video_data, + sampling="uniform", + duration=None, + num_frames=4, + wanted_fps=None, + actual_fps=None, + skip_frms_num=0.0, + nb_read_frames=None, +): + decord.bridge.set_bridge("torch") + vr = VideoReader(uri=video_data, height=-1, width=-1) + if nb_read_frames is not None: + ori_vlen 
= nb_read_frames + else: + ori_vlen = min(int(duration * actual_fps) - 1, len(vr)) + + max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps) + start = random.randint(skip_frms_num, max_seek + 1) + end = int(start + num_frames / wanted_fps * actual_fps) + n_frms = num_frames + + if sampling == "uniform": + indices = np.arange(start, end, (end - start) / n_frms).astype(int) + else: + raise NotImplementedError + + # get_batch -> T, H, W, C + temp_frms = vr.get_batch(np.arange(start, end)) + assert temp_frms is not None + tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms + tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] + + return pad_last_frame(tensor_frms, num_frames) + + +import threading + + +def load_video_with_timeout(*args, **kwargs): + video_container = {} + + def target_function(): + video = load_video(*args, **kwargs) + video_container["video"] = video + + thread = threading.Thread(target=target_function) + thread.start() + timeout = 20 + thread.join(timeout) + + if thread.is_alive(): + print("Loading video timed out") + raise TimeoutError + return video_container.get("video", None).contiguous() + + +def process_video( + video_path, + image_size=None, + duration=None, + num_frames=4, + wanted_fps=None, + actual_fps=None, + skip_frms_num=0.0, + nb_read_frames=None, +): + """ + video_path: str or io.BytesIO + image_size: . + duration: preknow the duration to speed up by seeking to sampled start. TODO by_pass if unknown. + num_frames: wanted num_frames. + wanted_fps: . + skip_frms_num: ignore the first and the last xx frames, avoiding transitions. + """ + + video = load_video_with_timeout( + video_path, + duration=duration, + num_frames=num_frames, + wanted_fps=wanted_fps, + actual_fps=actual_fps, + skip_frms_num=skip_frms_num, + nb_read_frames=nb_read_frames, + ) + + # --- copy and modify the image process --- + video = video.permute(0, 3, 1, 2) # [T, C, H, W] + + # resize + if image_size is not None: + video = resize_for_rectangle_crop(video, image_size, reshape_mode="center") + + return video + + +def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"): + while True: + r = next(src) + if "mp4" in r: + video_data = r["mp4"] + elif "avi" in r: + video_data = r["avi"] + else: + print("No video data found") + continue + + if txt_key not in r: + txt = "" + else: + txt = r[txt_key] + + if isinstance(txt, bytes): + txt = txt.decode("utf-8") + else: + txt = str(txt) + + duration = r.get("duration", None) + if duration is not None: + duration = float(duration) + else: + continue + + actual_fps = r.get("fps", None) + if actual_fps is not None: + actual_fps = float(actual_fps) + else: + continue + + required_frames = num_frames / fps * actual_fps + 2 * skip_frms_num + required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps + + if duration is not None and duration < required_duration: + continue + + try: + frames = process_video( + io.BytesIO(video_data), + num_frames=num_frames, + wanted_fps=fps, + image_size=image_size, + duration=duration, + actual_fps=actual_fps, + skip_frms_num=skip_frms_num, + ) + frames = (frames - 127.5) / 127.5 + except Exception as e: + print(e) + continue + + item = { + "mp4": frames, + "txt": txt, + "num_frames": num_frames, + "fps": fps, + } + + yield item + + +class VideoDataset(MetaDistributedWebDataset): + def __init__( + self, + path, + image_size, + num_frames, + fps, + skip_frms_num=0.0, + nshards=sys.maxsize, + seed=1, 
+ meta_names=None, + shuffle_buffer=1000, + include_dirs=None, + txt_key="caption", + **kwargs, + ): + if seed == -1: + seed = random.randint(0, 1000000) + if meta_names is None: + meta_names = [] + + if path.startswith(";"): + path, include_dirs = path.split(";", 1) + super().__init__( + path, + partial( + process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num + ), + seed, + meta_names=meta_names, + shuffle_buffer=shuffle_buffer, + nshards=nshards, + include_dirs=include_dirs, + ) + + @classmethod + def create_dataset_function(cls, path, args, **kwargs): + return cls(path, **kwargs) + + +class SFTDataset(Dataset): + def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3): + """ + skip_frms_num: ignore the first and the last xx frames, avoiding transitions. + """ + super(SFTDataset, self).__init__() + + self.video_size = video_size + self.fps = fps + self.max_num_frames = max_num_frames + self.skip_frms_num = skip_frms_num + + self.video_paths = [] + self.captions = [] + + for root, dirnames, filenames in os.walk(data_dir): + for filename in filenames: + if filename.endswith(".mp4"): + video_path = os.path.join(root, filename) + self.video_paths.append(video_path) + + caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels") + if os.path.exists(caption_path): + caption = open(caption_path, "r").read().splitlines()[0] + else: + caption = "" + self.captions.append(caption) + + def __getitem__(self, index): + decord.bridge.set_bridge("torch") + + video_path = self.video_paths[index] + vr = VideoReader(uri=video_path, height=-1, width=-1) + actual_fps = vr.get_avg_fps() + ori_vlen = len(vr) + + if ori_vlen / actual_fps * self.fps > self.max_num_frames: + num_frames = self.max_num_frames + start = int(self.skip_frms_num) + end = int(start + num_frames / self.fps * actual_fps) + end_safty = min(int(start + num_frames / self.fps * actual_fps), int(ori_vlen)) + indices = np.arange(start, end, (end - start) // num_frames).astype(int) + temp_frms = vr.get_batch(np.arange(start, end_safty)) + assert temp_frms is not None + tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms + tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] + else: + if ori_vlen > self.max_num_frames: + num_frames = self.max_num_frames + start = int(self.skip_frms_num) + end = int(ori_vlen - self.skip_frms_num) + indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int) + temp_frms = vr.get_batch(np.arange(start, end)) + assert temp_frms is not None + tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms + tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] + else: + + def nearest_smaller_4k_plus_1(n): + remainder = n % 4 + if remainder == 0: + return n - 3 + else: + return n - remainder + 1 + + start = int(self.skip_frms_num) + end = int(ori_vlen - self.skip_frms_num) + num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1 + end = int(start + num_frames) + temp_frms = vr.get_batch(np.arange(start, end)) + assert temp_frms is not None + tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms + + tensor_frms = pad_last_frame( + tensor_frms, self.max_num_frames + ) # the len of indices may be less than num_frames, due to round error + tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W] + 
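+        # Center-crop/resize to the configured video_size, then rescale uint8 pixels from [0, 255] to [-1, 1].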
tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center") + tensor_frms = (tensor_frms - 127.5) / 127.5 + + item = { + "mp4": tensor_frms, + "txt": self.captions[index], + "num_frames": num_frames, + "fps": self.fps, + } + return item + + def __len__(self): + return len(self.video_paths) + + @classmethod + def create_dataset_function(cls, path, args, **kwargs): + return cls(data_dir=path, **kwargs) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/diffusion_video.py b/PyTorch/contrib/cv/video/CogVideoX/sat/diffusion_video.py new file mode 100644 index 0000000000000000000000000000000000000000..888c54ca7f9aa2f80bc03de6015530d24489b1be --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/diffusion_video.py @@ -0,0 +1,353 @@ +import random + +import math +from typing import Any, Dict, List, Tuple, Union +from omegaconf import ListConfig +import torch.nn.functional as F + +from sat.helpers import print_rank0 +import torch +from torch import nn + +from sgm.modules import UNCONDITIONAL_CONFIG +from sgm.modules.autoencoding.temporal_ae import VideoDecoder +from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from sgm.util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, +) +# import gc +from sat import mpu + + +class SATVideoDiffusionEngine(nn.Module): + def __init__(self, args, **kwargs): + super().__init__() + + model_config = args.model_config + # model args preprocess + log_keys = model_config.get("log_keys", None) + input_key = model_config.get("input_key", "mp4") + network_config = model_config.get("network_config", None) + network_wrapper = model_config.get("network_wrapper", None) + denoiser_config = model_config.get("denoiser_config", None) + sampler_config = model_config.get("sampler_config", None) + conditioner_config = model_config.get("conditioner_config", None) + first_stage_config = model_config.get("first_stage_config", None) + loss_fn_config = model_config.get("loss_fn_config", None) + scale_factor = model_config.get("scale_factor", 1.0) + latent_input = model_config.get("latent_input", False) + disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False) + no_cond_log = model_config.get("disable_first_stage_autocast", False) + not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"]) + compile_model = model_config.get("compile_model", False) + en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None) + lr_scale = model_config.get("lr_scale", None) + lora_train = model_config.get("lora_train", False) + self.use_pd = model_config.get("use_pd", False) # progressive distillation + + self.log_keys = log_keys + self.input_key = input_key + self.not_trainable_prefixes = not_trainable_prefixes + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + self.lr_scale = lr_scale + self.lora_train = lora_train + self.noised_image_input = model_config.get("noised_image_input", False) + self.noised_image_all_concat = model_config.get("noised_image_all_concat", False) + self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0) + if args.fp16: + dtype = torch.float16 + dtype_str = "fp16" + elif args.bf16: + dtype = torch.bfloat16 + dtype_str = "bf16" + else: + dtype = torch.float32 + dtype_str = "fp32" + self.dtype = dtype + self.dtype_str = dtype_str + + network_config["params"]["dtype"] = dtype_str + model = instantiate_from_config(network_config) + 
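+        # instantiate_from_config builds the network from the OmegaConf entry, which
+        # is expected to carry a "target" import path plus optional "params", e.g.
+        # (illustrative values only, not taken from a shipped config):
+        #   network_config:
+        #     target: dit_video_concat.DiffusionTransformer
+        #     params: {num_layers: ..., hidden_size: ..., patch_size: 2}
+        # The wrapper below (OPENAIUNETWRAPPER unless network_wrapper overrides it)
+        # wraps the transformer for the denoiser and is constructed with
+        # compile_model and the dtype selected above.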
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model, dtype=dtype + ) + + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None + self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG)) + + self._init_first_stage(first_stage_config) + + self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None + + self.latent_input = latent_input + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.no_cond_log = no_cond_log + self.device = args.device + + def disable_untrainable_params(self): + total_trainable = 0 + for n, p in self.named_parameters(): + if p.requires_grad == False: + continue + flag = False + for prefix in self.not_trainable_prefixes: + if n.startswith(prefix) or prefix == "all": + flag = True + break + + lora_prefix = ["matrix_A", "matrix_B"] + for prefix in lora_prefix: + if prefix in n: + flag = False + break + + if flag: + p.requires_grad_(False) + else: + total_trainable += p.numel() + + print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****") + + def reinit(self, parent_model=None): + # reload the initial params from previous trained modules + # you can also get access to other mixins through parent_model.get_mixin(). + pass + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + def forward(self, x, batch): + loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) + loss_mean = loss.mean() + loss_dict = {"loss": loss_mean} + return loss_mean, loss_dict + + def add_noise_to_first_frame(self, image): + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(self.device) + sigma = torch.exp(sigma).to(image.dtype) + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image = image + image_noise + return image + + def shared_step(self, batch: Dict) -> Any: + x = self.get_input(batch) + if self.lr_scale is not None: + lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False) + lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False) + lr_z = self.encode_first_stage(lr_x, batch) + batch["lr_input"] = lr_z + + x = x.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_input: + image = x[:, :, 0:1] + image = self.add_noise_to_first_frame(image) + image = self.encode_first_stage(image, batch) + + x = self.encode_first_stage(x, batch) + x = x.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_input: + image = image.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_all_concat: + image = image.repeat(1, x.shape[1], 1, 1, 1) + else: + image = torch.concat([image, torch.zeros_like(x[:, 1:])], dim=1) + if random.random() < self.noised_image_dropout: + image = torch.zeros_like(image) + batch["concat_images"] = image + + # gc.collect() + # torch.cuda.empty_cache() + loss, loss_dict = self(x, batch) + return loss, loss_dict + + def get_input(self, batch): + return batch[self.input_key].to(self.dtype) + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + n_rounds = 
math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} + else: + kwargs = {} + out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs) + all_out.append(out) + out = torch.cat(all_out, dim=0) + return out + + @torch.no_grad() + def encode_first_stage(self, x, batch): + frame = x.shape[2] + + if frame > 1 and self.latent_input: + x = x.permute(0, 2, 1, 3, 4).contiguous() + return x * self.scale_factor # already encoded + + n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) + n_rounds = math.ceil(x.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples]) + all_out.append(out) + z = torch.cat(all_out, dim=0) + z = self.scale_factor * z + return z + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + prefix=None, + concat_images=None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device) + if hasattr(self, "seeded_noise"): + randn = self.seeded_noise(randn) + + if prefix is not None: + randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1) + + # broadcast noise + mp_size = mpu.get_model_parallel_world_size() + if mp_size > 1: + global_rank = torch.distributed.get_rank() // mp_size + src = global_rank * mp_size + torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group()) + + scale = None + scale_emb = None + + denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser( + self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs + ) + + samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb) + samples = samples.to(self.dtype) + return samples + + @torch.no_grad() + def log_conditionings(self, batch: Dict, n: int) -> Dict: + """ + Defines heuristics to log different conditionings. + These can be lists of strings (text-to-image), tensors, ints, ... 
+ """ + image_h, image_w = batch[self.input_key].shape[3:] + log = dict() + + for embedder in self.conditioner.embedders: + if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log: + x = batch[embedder.input_key][:n] + if isinstance(x, torch.Tensor): + if x.dim() == 1: + # class-conditional, convert integer to string + x = [str(x[i].item()) for i in range(x.shape[0])] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) + elif x.dim() == 2: + # size and crop cond and the like + x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + elif isinstance(x, (List, ListConfig)): + if isinstance(x[0], str): + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + log[embedder.input_key] = xc + return log + + @torch.no_grad() + def log_video( + self, + batch: Dict, + N: int = 8, + ucg_keys: List[str] = None, + only_log_video_latents=False, + **kwargs, + ) -> Dict: + conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] + if ucg_keys: + assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( + "Each defined ucg key for sampling must be in the provided conditioner input keys," + f"but we have {ucg_keys} vs. {conditioner_input_keys}" + ) + else: + ucg_keys = conditioner_input_keys + log = dict() + + x = self.get_input(batch) + + c, uc = self.conditioner.get_unconditional_conditioning( + batch, + force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], + ) + + sampling_kwargs = {} + + N = min(x.shape[0], N) + x = x.to(self.device)[:N] + if not self.latent_input: + log["inputs"] = x.to(torch.float32) + x = x.permute(0, 2, 1, 3, 4).contiguous() + z = self.encode_first_stage(x, batch) + if not only_log_video_latents: + log["reconstructions"] = self.decode_first_stage(z).to(torch.float32) + log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous() + z = z.permute(0, 2, 1, 3, 4).contiguous() + + log.update(self.log_conditionings(batch, N)) + + for k in c: + if isinstance(c[k], torch.Tensor): + c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) + + if self.noised_image_input: + image = x[:, :, 0:1] + image = self.add_noise_to_first_frame(image) + image = self.encode_first_stage(image, batch) + image = image.permute(0, 2, 1, 3, 4).contiguous() + image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1) + c["concat"] = image + uc["concat"] = image + samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w + samples = samples.permute(0, 2, 1, 3, 4).contiguous() + if only_log_video_latents: + latents = 1.0 / self.scale_factor * samples + log["latents"] = latents + else: + samples = self.decode_first_stage(samples).to(torch.float32) + samples = samples.permute(0, 2, 1, 3, 4).contiguous() + log["samples"] = samples + else: + samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w + samples = samples.permute(0, 2, 1, 3, 4).contiguous() + if only_log_video_latents: + latents = 1.0 / self.scale_factor * samples + log["latents"] = latents + else: + samples = self.decode_first_stage(samples).to(torch.float32) + samples = samples.permute(0, 2, 1, 3, 4).contiguous() + log["samples"] = samples + return log diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/dit_video_concat.py 
b/PyTorch/contrib/cv/video/CogVideoX/sat/dit_video_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..d36f51267c04b246d66684602bfce1bbacf0862d --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/dit_video_concat.py @@ -0,0 +1,808 @@ +from functools import partial +from einops import rearrange, repeat +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F + +from sat.model.base_model import BaseModel, non_conflict +from sat.model.mixins import BaseMixin +from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default +from sat.mpu.layers import ColumnParallelLinear +from sgm.util import instantiate_from_config + +from sgm.modules.diffusionmodules.openaimodel import Timestep +from sgm.modules.diffusionmodules.util import ( + linear, + timestep_embedding, +) +from sat.ops.layernorm import LayerNorm, RMSNorm + + +class ImagePatchEmbeddingMixin(BaseMixin): + def __init__( + self, + in_channels, + hidden_size, + patch_size, + bias=True, + text_hidden_size=None, + ): + super().__init__() + self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias) + if text_hidden_size is not None: + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + else: + self.text_proj = None + + def word_embedding_forward(self, input_ids, **kwargs): + # now is 3d patch + images = kwargs["images"] # (b,t,c,h,w) + B, T = images.shape[:2] + emb = images.view(-1, *images.shape[2:]) + emb = self.proj(emb) # ((b t),d,h/2,w/2) + emb = emb.view(B, T, *emb.shape[1:]) + emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d) + emb = rearrange(emb, "b t n d -> b (t n) d") + + if self.text_proj is not None: + text_emb = self.text_proj(kwargs["encoder_outputs"]) + emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d) + + emb = emb.contiguous() + return emb # (b,n_t+t*n_i,d) + + def reinit(self, parent_model=None): + w = self.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.proj.bias, 0) + del self.transformer.word_embeddings + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_height, + grid_width, + t_size, + cls_token=False, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, +): + """ + grid_size: int of the grid height and width + t_size: int of the temporal size + return: + pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + assert embed_dim % 4 == 0 + embed_dim_spatial = embed_dim // 4 * 3 + embed_dim_temporal = embed_dim // 4 + + # spatial + grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation + grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_height, grid_width]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # temporal + grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # concate: [T, H, W] order + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4] + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed = 
np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) + # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] + + return pos_embed # [T, H*W, D] + + +def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_height, dtype=np.float32) + grid_w = np.arange(grid_width, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_height, grid_width]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class Basic2DPositionEmbeddingMixin(BaseMixin): + def __init__(self, height, width, compressed_num_frames, hidden_size, text_length=0): + super().__init__() + self.height = height + self.width = width + self.spatial_length = height * width + self.pos_embedding = nn.Parameter( + torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False + ) + + def position_embedding_forward(self, position_ids, **kwargs): + return self.pos_embedding + + def reinit(self, parent_model=None): + del self.transformer.position_embeddings + pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width) + self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + +class Basic3DPositionEmbeddingMixin(BaseMixin): + def __init__( + self, + height, + width, + compressed_num_frames, + hidden_size, + text_length=0, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, + ): + super().__init__() + self.height = height + self.width = width + self.text_length = text_length + self.compressed_num_frames = compressed_num_frames + self.spatial_length = height * width + self.num_patches = height * width * compressed_num_frames + self.pos_embedding = nn.Parameter( + torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False + ) + self.height_interpolation = height_interpolation + self.width_interpolation = width_interpolation + self.time_interpolation = time_interpolation + + def position_embedding_forward(self, position_ids, **kwargs): + if kwargs["images"].shape[1] == 1: + return self.pos_embedding[:, : self.text_length + 
self.spatial_length] + + return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]] + + def reinit(self, parent_model=None): + del self.transformer.position_embeddings + pos_embed = get_3d_sincos_pos_embed( + self.pos_embedding.shape[-1], + self.height, + self.width, + self.compressed_num_frames, + height_interpolation=self.height_interpolation, + width_interpolation=self.width_interpolation, + time_interpolation=self.time_interpolation, + ) + pos_embed = torch.from_numpy(pos_embed).float() + pos_embed = rearrange(pos_embed, "t n d -> (t n) d") + self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed) + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2).contiguous() + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class Rotary3DPositionEmbeddingMixin(BaseMixin): + def __init__( + self, + height, + width, + compressed_num_frames, + hidden_size, + hidden_size_head, + text_length, + theta=10000, + rot_v=False, + learnable_pos_embed=False, + ): + super().__init__() + self.rot_v = rot_v + + dim_t = hidden_size_head // 4 + dim_h = hidden_size_head // 8 * 3 + dim_w = hidden_size_head // 8 * 3 + + freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) + freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) + freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) + + grid_t = torch.arange(compressed_num_frames, dtype=torch.float32) + grid_h = torch.arange(height, dtype=torch.float32) + grid_w = torch.arange(width, dtype=torch.float32) + + freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) + freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) + freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) + + freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + freqs_w = repeat(freqs_w, "... n -> ... 
(n r)", r=2) + + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + freqs = rearrange(freqs, "t h w d -> (t h w) d") + + freqs = freqs.contiguous() + freqs_sin = freqs.sin() + freqs_cos = freqs.cos() + self.register_buffer("freqs_sin", freqs_sin) + self.register_buffer("freqs_cos", freqs_cos) + + self.text_length = text_length + if learnable_pos_embed: + num_patches = height * width * compressed_num_frames + text_length + self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True) + else: + self.pos_embedding = None + + def rotary(self, t, **kwargs): + seq_len = t.shape[2] + freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0) + freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0) + + return t * freqs_cos + rotate_half(t) * freqs_sin + + def position_embedding_forward(self, position_ids, **kwargs): + if self.pos_embedding is not None: + return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]] + else: + return None + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + **kwargs, + ): + attention_fn_default = HOOKS_DEFAULT["attention_fn"] + + query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :]) + key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :]) + if self.rot_v: + value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :]) + + return attention_fn_default( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs): + """ + x: (N, T/2 * S, patch_size**3 * C) + imgs: (N, T, H, W, C) + """ + if rope_position_ids is not None: + assert NotImplementedError + # do pix2struct unpatchify + L = x.shape[1] + x = x.reshape(shape=(x.shape[0], L, p, p, c)) + x = torch.einsum("nlpqc->ncplq", x) + imgs = x.reshape(shape=(x.shape[0], c, p, L * p)) + else: + b = x.shape[0] + imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p) + + return imgs + + +class FinalLayerMixin(BaseMixin): + def __init__( + self, + hidden_size, + time_embed_dim, + patch_size, + out_channels, + latent_width, + latent_height, + elementwise_affine, + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + self.out_channels = out_channels + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) + + self.spatial_length = latent_width * latent_height // patch_size**2 + self.latent_width = latent_width + self.latent_height = latent_height + + def final_forward(self, logits, **kwargs): + x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d) + shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + return unpatchify( + x, + c=self.out_channels, + p=self.patch_size, + 
w=self.latent_width // self.patch_size, + h=self.latent_height // self.patch_size, + rope_position_ids=kwargs.get("rope_position_ids", None), + **kwargs, + ) + + def reinit(self, parent_model=None): + nn.init.xavier_uniform_(self.linear.weight) + nn.init.constant_(self.linear.bias, 0) + + +class SwiGLUMixin(BaseMixin): + def __init__(self, num_layers, in_features, hidden_features, bias=False): + super().__init__() + self.w2 = nn.ModuleList( + [ + ColumnParallelLinear( + in_features, + hidden_features, + gather_output=False, + bias=bias, + module=self, + name="dense_h_to_4h_gate", + ) + for i in range(num_layers) + ] + ) + + def mlp_forward(self, hidden_states, **kw_args): + x = hidden_states + origin = self.transformer.layers[kw_args["layer_id"]].mlp + x1 = origin.dense_h_to_4h(x) + x2 = self.w2[kw_args["layer_id"]](x) + hidden = origin.activation_func(x2) * x1 + x = origin.dense_4h_to_h(hidden) + return x + + +class AdaLNMixin(BaseMixin): + def __init__( + self, + width, + height, + hidden_size, + num_layers, + time_embed_dim, + compressed_num_frames, + qk_ln=True, + hidden_size_head=None, + elementwise_affine=True, + ): + super().__init__() + self.num_layers = num_layers + self.width = width + self.height = height + self.compressed_num_frames = compressed_num_frames + + self.adaLN_modulations = nn.ModuleList( + [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)] + ) + + self.qk_ln = qk_ln + if qk_ln: + self.query_layernorm_list = nn.ModuleList( + [ + LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine) + for _ in range(num_layers) + ] + ) + self.key_layernorm_list = nn.ModuleList( + [ + LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine) + for _ in range(num_layers) + ] + ) + + def layer_forward( + self, + hidden_states, + mask, + *args, + **kwargs, + ): + text_length = kwargs["text_length"] + # hidden_states (b,(n_t+t*n_i),d) + text_hidden_states = hidden_states[:, :text_length] # (b,n,d) + img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d) + + layer = self.transformer.layers[kwargs["layer_id"]] + adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]] + + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + text_shift_msa, + text_scale_msa, + text_gate_msa, + text_shift_mlp, + text_scale_mlp, + text_gate_mlp, + ) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1) + gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = ( + gate_msa.unsqueeze(1), + gate_mlp.unsqueeze(1), + text_gate_msa.unsqueeze(1), + text_gate_mlp.unsqueeze(1), + ) + + # self full attention (b,(t n),d) + img_attention_input = layer.input_layernorm(img_hidden_states) + text_attention_input = layer.input_layernorm(text_hidden_states) + img_attention_input = modulate(img_attention_input, shift_msa, scale_msa) + text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa) + + attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d) + attention_output = layer.attention(attention_input, mask, **kwargs) + text_attention_output = attention_output[:, :text_length] # (b,n,d) + img_attention_output = attention_output[:, text_length:] # (b,(t n),d) + if self.transformer.layernorm_order == "sandwich": + text_attention_output = layer.third_layernorm(text_attention_output) + img_attention_output = layer.third_layernorm(img_attention_output) + img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d) + 
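+        # Gated residual pattern, applied to the vision tokens above and the text
+        # tokens below: h = h + gate * sublayer(modulate(LayerNorm(h), shift, scale)),
+        # with modulate(x, shift, scale) = x * (1 + scale) + shift. The 12-way chunk
+        # of the adaLN output provides separate shift/scale/gate triples for the two
+        # token streams in both the attention and MLP sub-blocks.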
text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d) + + # mlp (b,(t n),d) + img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d) + text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d) + img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp) + text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp) + mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d + mlp_output = layer.mlp(mlp_input, **kwargs) + img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d) + text_mlp_output = mlp_output[:, :text_length] # language (b,n,d) + if self.transformer.layernorm_order == "sandwich": + text_mlp_output = layer.fourth_layernorm(text_mlp_output) + img_mlp_output = layer.fourth_layernorm(img_mlp_output) + + img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d) + text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d) + + hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d) + return hidden_states + + def reinit(self, parent_model=None): + for layer in self.adaLN_modulations: + nn.init.constant_(layer[-1].weight, 0) + nn.init.constant_(layer[-1].bias, 0) + + @non_conflict + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + old_impl=attention_fn_default, + **kwargs, + ): + if self.qk_ln: + query_layernorm = self.query_layernorm_list[kwargs["layer_id"]] + key_layernorm = self.key_layernorm_list[kwargs["layer_id"]] + query_layer = query_layernorm(query_layer) + key_layer = key_layernorm(key_layer) + + return old_impl( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + + +str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +class DiffusionTransformer(BaseModel): + def __init__( + self, + transformer_args, + num_frames, + time_compressed_rate, + latent_width, + latent_height, + patch_size, + in_channels, + out_channels, + hidden_size, + num_layers, + num_attention_heads, + elementwise_affine, + time_embed_dim=None, + num_classes=None, + modules={}, + input_time="adaln", + adm_in_channels=None, + parallel_output=True, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, + use_SwiGLU=False, + use_RMSNorm=False, + zero_init_y_embed=False, + **kwargs, + ): + self.latent_width = latent_width + self.latent_height = latent_height + self.patch_size = patch_size + self.num_frames = num_frames + self.time_compressed_rate = time_compressed_rate + self.spatial_length = latent_width * latent_height // patch_size**2 + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_size = hidden_size + self.model_channels = hidden_size + self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size + self.num_classes = num_classes + self.adm_in_channels = adm_in_channels + self.input_time = input_time + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.is_decoder = transformer_args.is_decoder + self.elementwise_affine = elementwise_affine + self.height_interpolation = height_interpolation + self.width_interpolation = 
width_interpolation + self.time_interpolation = time_interpolation + self.inner_hidden_size = hidden_size * 4 + self.zero_init_y_embed = zero_init_y_embed + try: + self.dtype = str_to_dtype[kwargs.pop("dtype")] + except: + self.dtype = torch.float32 + + if use_SwiGLU: + kwargs["activation_func"] = F.silu + elif "activation_func" not in kwargs: + approx_gelu = nn.GELU(approximate="tanh") + kwargs["activation_func"] = approx_gelu + + if use_RMSNorm: + kwargs["layernorm"] = RMSNorm + else: + kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6) + + transformer_args.num_layers = num_layers + transformer_args.hidden_size = hidden_size + transformer_args.num_attention_heads = num_attention_heads + transformer_args.parallel_output = parallel_output + super().__init__(args=transformer_args, transformer=None, **kwargs) + + module_configs = modules + self._build_modules(module_configs) + + if use_SwiGLU: + self.add_mixin( + "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True + ) + + def _build_modules(self, module_configs): + model_channels = self.hidden_size + # time_embed_dim = model_channels * 4 + time_embed_dim = self.time_embed_dim + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + elif self.num_classes == "sequential": + assert self.adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(self.adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + if self.zero_init_y_embed: + nn.init.constant_(self.label_emb[0][2].weight, 0) + nn.init.constant_(self.label_emb[0][2].bias, 0) + else: + raise ValueError() + + pos_embed_config = module_configs["pos_embed_config"] + self.add_mixin( + "pos_embed", + instantiate_from_config( + pos_embed_config, + height=self.latent_height // self.patch_size, + width=self.latent_width // self.patch_size, + compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, + hidden_size=self.hidden_size, + ), + reinit=True, + ) + + patch_embed_config = module_configs["patch_embed_config"] + self.add_mixin( + "patch_embed", + instantiate_from_config( + patch_embed_config, + patch_size=self.patch_size, + hidden_size=self.hidden_size, + in_channels=self.in_channels, + ), + reinit=True, + ) + if self.input_time == "adaln": + adaln_layer_config = module_configs["adaln_layer_config"] + self.add_mixin( + "adaln_layer", + instantiate_from_config( + adaln_layer_config, + height=self.latent_height // self.patch_size, + width=self.latent_width // self.patch_size, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, + hidden_size_head=self.hidden_size // self.num_attention_heads, + time_embed_dim=self.time_embed_dim, + elementwise_affine=self.elementwise_affine, + ), + ) + else: + raise NotImplementedError + + final_layer_config = 
module_configs["final_layer_config"] + self.add_mixin( + "final_layer", + instantiate_from_config( + final_layer_config, + hidden_size=self.hidden_size, + patch_size=self.patch_size, + out_channels=self.out_channels, + time_embed_dim=self.time_embed_dim, + latent_width=self.latent_width, + latent_height=self.latent_height, + elementwise_affine=self.elementwise_affine, + ), + reinit=True, + ) + + if "lora_config" in module_configs: + lora_config = module_configs["lora_config"] + self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) + + return + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + b, t, d, h, w = x.shape + if x.dtype != self.dtype: + x = x.to(self.dtype) + + # This is not use in inference + if "concat_images" in kwargs and kwargs["concat_images"] is not None: + if kwargs["concat_images"].shape[0] != x.shape[0]: + concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1) + else: + concat_images = kwargs["concat_images"] + x = torch.cat([x, concat_images], dim=2) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + # assert y.shape[0] == x.shape[0] + assert x.shape[0] % y.shape[0] == 0 + y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0) + emb = emb + self.label_emb(y) + + kwargs["seq_length"] = t * h * w // (self.patch_size**2) + kwargs["images"] = x + kwargs["emb"] = emb + kwargs["encoder_outputs"] = context + kwargs["text_length"] = context.shape[1] + + kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) + output = super().forward(**kwargs)[0] + return output diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/finetune_multi_gpus.sh b/PyTorch/contrib/cv/video/CogVideoX/sat/finetune_multi_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..9aab4c5782acbefb159b9f416988c244b11c218f --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/finetune_multi_gpus.sh @@ -0,0 +1,8 @@ +#! 
/bin/bash + +run_cmd="PYTORCH_NPU_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 --master_port=25000 train_video.py --base configs/cogvideox_5b_sft.yaml --seed 1234" + +echo ${run_cmd} +eval ${run_cmd} + +echo "DONE on `hostname`" \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/requirements.txt b/PyTorch/contrib/cv/video/CogVideoX/sat/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3467e3ee3dc8a7cafe332996d98d9f00464dc453 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/requirements.txt @@ -0,0 +1,16 @@ +SwissArmyTransformer==0.4.12 +omegaconf==2.3.0 +torch==2.1.0 +torchvision==0.16.0 +pytorch_lightning==2.3.3 +kornia==0.7.3 +beartype==0.18.5 +numpy==2.0.1 +fsspec==2024.5.0 +safetensors==0.4.3 +imageio-ffmpeg==0.5.1 +imageio==2.34.2 +scipy==1.14.0 +decord==0.6.0 +wandb==0.17.5 +deepspeed==0.14.4 \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sample_video.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sample_video.py new file mode 100644 index 0000000000000000000000000000000000000000..49cfcac7361a8cf3536c09f98301b4a29cf886cf --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sample_video.py @@ -0,0 +1,259 @@ +import os +import math +import argparse +from typing import List, Union +from tqdm import tqdm +from omegaconf import ListConfig +import imageio + +import torch +import numpy as np +from einops import rearrange +import torchvision.transforms as TT + + +from sat.model.base_model import get_model +from sat.training.model_io import load_checkpoint +from sat import mpu + +from diffusion_video import SATVideoDiffusionEngine +from arguments import get_args +from torchvision.transforms.functional import center_crop, resize +from torchvision.transforms import InterpolationMode +from PIL import Image + + +def read_from_cli(): + cnt = 0 + try: + while True: + x = input("Please input English text (Ctrl-D quit): ") + yield x.strip(), cnt + cnt += 1 + except EOFError as e: + pass + + +def read_from_file(p, rank=0, world_size=1): + with open(p, "r") as fin: + cnt = -1 + for l in fin: + cnt += 1 + if cnt % world_size != rank: + continue + yield l.strip(), cnt + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"): + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() + batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None): + os.makedirs(save_path, exist_ok=True) + + for i, vid in enumerate(video_batch): + gif_frames = [] + for frame in vid: + frame = rearrange(frame, "c h w -> h w c") + frame = (255.0 * frame).cpu().numpy().astype(np.uint8) + gif_frames.append(frame) + now_save_path = os.path.join(save_path, f"{i:06d}.mp4") + with imageio.get_writer(now_save_path, fps=fps) as writer: + for frame in gif_frames: + writer.append_data(frame) + + 
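+# resize_for_rectangle_crop (below) resizes the clip so that one side matches the
+# target (height, width) given by image_size exactly while the other is at least as
+# large, then crops an image_size[0] x image_size[1] window, either centered
+# ("center") or at a random offset ("random"/"none"); height and width are taken
+# from the last two dimensions of arr.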
+def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + +def sampling_main(args, model_cls): + if isinstance(model_cls, type): + model = get_model(args, model_cls) + else: + model = model_cls + + load_checkpoint(model, args) + model.eval() + + if args.input_type == "cli": + data_iter = read_from_cli() + elif args.input_type == "txt": + rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size() + print("rank and world_size", rank, world_size) + data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size) + else: + raise NotImplementedError + + image_size = [480, 720] + + if args.image2video: + chained_trainsforms = [] + chained_trainsforms.append(TT.ToTensor()) + transform = TT.Compose(chained_trainsforms) + + sample_func = model.sample + T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8 + num_samples = [1] + force_uc_zero_embeddings = ["txt"] + device = model.device + with torch.no_grad(): + for text, cnt in tqdm(data_iter): + if args.image2video: + text, image_path = text.split("@@") + assert os.path.exists(image_path), image_path + image = Image.open(image_path).convert("RGB") + image = transform(image).unsqueeze(0).to("cuda") + image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0) + image = image * 2.0 - 1.0 + image = image.unsqueeze(2).to(torch.bfloat16) + image = model.encode_first_stage(image, None) + image = image.permute(0, 2, 1, 3, 4).contiguous() + pad_shape = (image.shape[0], T - 1, C, H // F, W // F) + image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) + else: + image = None + + value_dict = { + "prompt": text, + "negative_prompt": "", + "num_frames": torch.tensor(T).unsqueeze(0), + } + + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)) + + if args.image2video and image is not None: + c["concat"] = image + uc["concat"] = image + + for index in range(args.batch_size): + # reload model on GPU + model.to(device) + samples_z = sample_func( + c, + uc=uc, + batch_size=1, + shape=(T, C, H // F, W // F), + ) + samples_z = 
samples_z.permute(0, 2, 1, 3, 4).contiguous() + + # Unload the model from GPU to save GPU memory + model.to("cpu") + torch.cuda.empty_cache() + first_stage_model = model.first_stage_model + first_stage_model = first_stage_model.to(device) + + latent = 1.0 / model.scale_factor * samples_z + + # Decode latent serial to save GPU memory + recons = [] + loop_num = (T - 1) // 2 + for i in range(loop_num): + if i == 0: + start_frame, end_frame = 0, 3 + else: + start_frame, end_frame = i * 2 + 1, i * 2 + 3 + if i == loop_num - 1: + clear_fake_cp_cache = True + else: + clear_fake_cp_cache = False + with torch.no_grad(): + recon = first_stage_model.decode( + latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache + ) + + recons.append(recon) + + recon = torch.cat(recons, dim=2).to(torch.float32) + samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() + + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + ) + if mpu.get_model_parallel_rank() == 0: + save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) + + +if __name__ == "__main__": + if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] + os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] + os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] + py_parser = argparse.ArgumentParser(add_help=False) + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + del args.deepspeed_config + args.model_config.first_stage_config.params.cp_size = 1 + args.model_config.network_config.params.transformer_args.model_parallel_size = 1 + args.model_config.network_config.params.transformer_args.checkpoint_activations = False + args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False + + sampling_main(args, model_cls=SATVideoDiffusionEngine) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c4482364f4f054d67e5ec9ef57862976a2c6aa7 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/__init__.py @@ -0,0 +1,4 @@ +from .models import AutoencodingEngine +from .util import get_configs_path, instantiate_from_config + +__version__ = "0.1.0" diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/lr_scheduler.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..b45db6983b731819de0eea23723bf83ea141f685 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/lr_scheduler.py @@ -0,0 +1,110 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + + def __init__( + self, + warm_up_steps, + lr_min, + lr_max, + lr_start, + max_decay_steps, + verbosity_interval=0, + ): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0.0 + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max 
- self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = ( + self.f_min[cycle] + + (self.f_max[cycle] - self.f_min[cycle]) + * (self.cycle_lengths[cycle] - n) + / (self.cycle_lengths[cycle]) + ) + self.last_f = f + return f diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/models/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e72b865963a7d7dfbe526f6b7aba63c5aa00a1e4 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/models/__init__.py @@ -0,0 +1 @@ +from .autoencoder import AutoencodingEngine diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/models/autoencoder.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0b21318f9473e37a7d5a39e63a966887e93f1646 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/models/autoencoder.py @@ -0,0 +1,549 @@ +import logging +import math +import re +import random +from abc import abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np 
+import pytorch_lightning as pl +import torch +import torch.distributed +import torch.nn as nn +from einops import rearrange +from packaging import version + +from ..modules.autoencoding.regularizers import AbstractRegularizer +from ..modules.ema import LitEma +from ..util import ( + default, + get_nested_attribute, + get_obj_from_str, + instantiate_from_config, + initialize_context_parallel, + get_context_parallel_group, + get_context_parallel_group_rank, + is_context_parallel_initialized, +) +from ..modules.cp_enc_dec import _conv_split, _conv_gather + +logpy = logging.getLogger(__name__) + + +class AbstractAutoencoder(pl.LightningModule): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. + """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None + if monitor is not None: + self.monitor = monitor + + if self.use_ema: + self.model_ema = LitEma(self, decay=ema_decay) + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if version.parse(torch.__version__) >= version.parse("2.0.0"): + self.automatic_optimization = False + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + if isinstance(ckpt, str): + ckpt = { + "target": "sgm.modules.checkpoint.CheckpointEngine", + "params": {"ckpt_path": ckpt}, + } + engine = instantiate_from_config(ckpt) + engine(self) + + @abstractmethod + def get_input(self, batch) -> Any: + raise NotImplementedError() + + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + logpy.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + logpy.info(f"{context}: Restored training weights") + + @abstractmethod + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") + + @abstractmethod + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) + + def configure_optimizers(self) -> Any: + raise NotImplementedError() + + +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. 
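+    Training uses Lightning manual optimization: training_step alternates between the
+    autoencoder and (if configured) discriminator optimizers based on batch_idx, and the
+    discriminator is only stepped once global_step reaches disc_start_iter.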
+ """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + additional_decode_keys: Optional[List[str]] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.automatic_optimization = False # pytorch lightning + + self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) + self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) + self.loss: torch.nn.Module = instantiate_from_config(loss_config) + self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config) + self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"}) + self.diff_boost_factor = diff_boost_factor + self.disc_start_iter = disc_start_iter + self.lr_g_factor = lr_g_factor + self.trainable_ae_params = trainable_ae_params + if self.trainable_ae_params is not None: + self.ae_optimizer_args = default( + ae_optimizer_args, + [{} for _ in range(len(self.trainable_ae_params))], + ) + assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) + else: + self.ae_optimizer_args = [{}] # makes type consitent + + self.trainable_disc_params = trainable_disc_params + if self.trainable_disc_params is not None: + self.disc_optimizer_args = default( + disc_optimizer_args, + [{} for _ in range(len(self.trainable_disc_params))], + ) + assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) + else: + self.disc_optimizer_args = [{}] # makes type consitent + + if ckpt_path is not None: + assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" + logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + self.additional_decode_keys = set(default(additional_decode_keys, [])) + + def get_input(self, batch: Dict) -> torch.Tensor: + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 
1 and in channels-first + # format (e.g., bchw instead if bhwc) + return batch[self.input_key] + + def get_autoencoder_params(self) -> list: + params = [] + if hasattr(self.loss, "get_trainable_autoencoder_parameters"): + params += list(self.loss.get_trainable_autoencoder_parameters()) + if hasattr(self.regularization, "get_trainable_parameters"): + params += list(self.regularization.get_trainable_parameters()) + params = params + list(self.encoder.parameters()) + params = params + list(self.decoder.parameters()) + return params + + def get_discriminator_params(self) -> list: + if hasattr(self.loss, "get_trainable_parameters"): + params = list(self.loss.get_trainable_parameters()) # e.g., discriminator + else: + params = [] + return params + + def get_last_layer(self): + return self.decoder.get_last_layer() + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x, **kwargs) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) + return x + + def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log + + def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: + x = self.get_input(batch) + additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + z, xrec, regularization_log = self(x, **additional_decode_kwargs) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": optimizer_idx, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "train", + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + + if optimizer_idx == 0: + # autoencode + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {"train/loss/rec": aeloss.detach()} + + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True, + sync_dist=False, + ) + self.log( + "loss", + aeloss.mean().detach(), + prog_bar=True, + logger=False, + on_epoch=False, + on_step=True, + ) + return aeloss + elif optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + # -> discriminator always needs to return a tuple + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + else: + raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") + + def training_step(self, batch: dict, batch_idx: int): + opts = self.optimizers() + if not isinstance(opts, list): + # Non-adversarial case + opts = [opts] + optimizer_idx = batch_idx % len(opts) + if self.global_step < self.disc_start_iter: + optimizer_idx = 0 + opt = opts[optimizer_idx] + opt.zero_grad() + with opt.toggle_model(): + loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx) + self.manual_backward(loss) + opt.step() + + def validation_step(self, batch: dict, batch_idx: int) -> Dict: 
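+        # Validation runs twice: once with the current weights and once inside
+        # ema_scope() (i.e. with the EMA weights when EMA is enabled); the EMA pass
+        # logs its metrics with an "_ema" postfix and both dicts are merged below.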
+ log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + log_dict.update(log_dict_ema) + return log_dict + + def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict: + x = self.get_input(batch) + + z, xrec, regularization_log = self(x) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": 0, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "val" + postfix, + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()} + full_log_dict = log_dict_ae + + if "optimizer_idx" in extra_info: + extra_info["optimizer_idx"] = 1 + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + full_log_dict.update(log_dict_disc) + self.log( + f"val{postfix}/loss/rec", + log_dict_ae[f"val{postfix}/loss/rec"], + sync_dist=True, + ) + self.log_dict(full_log_dict, sync_dist=True) + return full_log_dict + + def get_param_groups( + self, parameter_names: List[List[str]], optimizer_args: List[dict] + ) -> Tuple[List[Dict[str, Any]], int]: + groups = [] + num_params = 0 + for names, args in zip(parameter_names, optimizer_args): + params = [] + for pattern_ in names: + pattern_params = [] + pattern = re.compile(pattern_) + for p_name, param in self.named_parameters(): + if re.match(pattern, p_name): + pattern_params.append(param) + num_params += param.numel() + if len(pattern_params) == 0: + logpy.warn(f"Did not find parameters for pattern {pattern_}") + params.extend(pattern_params) + groups.append({"params": params, **args}) + return groups, num_params + + def configure_optimizers(self) -> List[torch.optim.Optimizer]: + if self.trainable_ae_params is None: + ae_params = self.get_autoencoder_params() + else: + ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args) + logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") + if self.trainable_disc_params is None: + disc_params = self.get_discriminator_params() + else: + disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args) + logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}") + opt_ae = self.instantiate_optimizer_from_config( + ae_params, + default(self.lr_g_factor, 1.0) * self.learning_rate, + self.optimizer_config, + ) + opts = [opt_ae] + if len(disc_params) > 0: + opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config) + opts.append(opt_disc) + + return opts + + @torch.no_grad() + def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + log = dict() + additional_decode_kwargs = {} + x = self.get_input(batch) + additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)}) + + _, xrec, _ = self(x, **additional_decode_kwargs) + log["inputs"] = x + log["reconstructions"] = xrec + diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x) + diff.clamp_(0, 1.0) + log["diff"] = 2.0 * diff - 1.0 + # diff_boost shows location of small errors, by boosting their + # brightness. 
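+        # i.e. diff_boost = 2 * clamp(diff_boost_factor * diff, 0, 1) - 1, which maps the
+        # amplified error back into the [-1, 1] range used for the logged images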
+ log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 + if hasattr(self.loss, "log_images"): + log.update(self.loss.log_images(x, xrec)) + with self.ema_scope(): + _, xrec_ema, _ = self(x, **additional_decode_kwargs) + log["reconstructions_ema"] = xrec_ema + diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) + diff_ema.clamp_(0, 1.0) + log["diff_ema"] = 2.0 * diff_ema - 1.0 + log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + if additional_log_kwargs: + additional_decode_kwargs.update(additional_log_kwargs) + _, xrec_add, _ = self(x, **additional_decode_kwargs) + log_str = "reconstructions-" + "-".join( + [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs] + ) + log[log_str] = xrec_add + return log + + +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + ckpt_path = kwargs.pop("ckpt_path", None) + ckpt_engine = kwargs.pop("ckpt_engine", None) + super().__init__( + encoder_config={ + "target": "sgm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "sgm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params + + def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) + + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class IdentityFirstStage(AbstractAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_input(self, x: Any) -> Any: + return x + + def encode(self, x: Any, *args, **kwargs) -> Any: + return x + + def decode(self, x: Any, *args, **kwargs) -> Any: + return + + +class VideoAutoencodingEngine(AutoencodingEngine): + def __init__( + self, + ckpt_path: Union[None, str] = None, + ignore_keys: Union[Tuple, list] = (), + image_video_weights=[1, 1], + only_train_decoder=False, + context_parallel_size=0, + **kwargs, + ): + super().__init__(**kwargs) + self.context_parallel_size = context_parallel_size + if ckpt_path is not None: + 
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + return self.log_images(batch, additional_log_kwargs, **kwargs) + + def get_input(self, batch: dict) -> torch.Tensor: + if self.context_parallel_size > 0: + if not is_context_parallel_initialized(): + initialize_context_parallel(self.context_parallel_size) + + batch = batch[self.input_key] + + global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size + torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group()) + + batch = _conv_split(batch, dim=2, kernel_size=1) + return batch + + return batch[self.input_key] + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + self.init_from_ckpt(ckpt) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + del sd[k] + missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) + print("Missing keys: ", missing_keys) + print("Unexpected keys: ", unexpected_keys) + print(f"Restored from {path}") diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0db1d7716a6e48f77b86a4b59c9289d6fb76b50b --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/__init__.py @@ -0,0 +1,6 @@ +from .encoders.modules import GeneralConditioner + +UNCONDITIONAL_CONFIG = { + "target": "sgm.modules.GeneralConditioner", + "params": {"emb_models": []}, +} diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/attention.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..bb241571c6ebf511b57fa95f16bde4d0ee35f9b7 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/attention.py @@ -0,0 +1,572 @@ +import math +from inspect import isfunction +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from packaging import version +from torch import nn + +if version.parse(torch.__version__) >= version.parse("2.0.0"): + SDP_IS_AVAILABLE = True + from torch.backends.cuda import SDPBackend, sdp_kernel + + BACKEND_MAP = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + }, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, + } +else: + from contextlib import nullcontext + + SDP_IS_AVAILABLE = False + sdp_kernel = nullcontext + BACKEND_MAP = {} + print( + f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, " + f"you are using PyTorch {torch.__version__}. You might want to consider upgrading." + ) + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + print("no module 'xformers'. 
Processing without...") + +from .diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = 
heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.backend = backend + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + h = self.heads + + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + n_cp = x.shape[0] // n_times_crossframe_attn_in_self + k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + # old + """ + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + """ + # new + with sdp_kernel(**BACKEND_MAP[self.backend]): + # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default + + del q, k, v + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs): + super().__init__() + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads with a dimension of {dim_head}." 
+ ) + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + # n_cp = x.shape[0]//n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, # ampere + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: + print( + f"Attention mode '{attn_mode}' is not available. Falling back to native attention. " + f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" + ) + attn_mode = "softmax" + elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: + print("We do not support vanilla attention anymore, as it is too expensive. Sorry.") + if not XFORMERS_IS_AVAILABLE: + assert False, "Please install xformers via e.g. 
'pip install xformers==0.0.16'" + else: + print("Falling back to xformers efficient attention.") + attn_mode = "softmax-xformers" + attn_cls = self.ATTENTION_MODES[attn_mode] + if version.parse(torch.__version__) >= version.parse("2.0.0"): + assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) + else: + assert sdp_backend is None + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + backend=sdp_backend, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + kwargs = {"x": x} + + if context is not None: + kwargs.update({"context": context}) + + if additional_tokens is not None: + kwargs.update({"additional_tokens": additional_tokens}) + + if n_times_crossframe_attn_in_self: + kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}) + + # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, + ) + + x + ) + x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerSingleLayerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version + # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context) + x + x = self.ff(self.norm2(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
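+    (internally: b c h w -> b (h w) c, then the attention blocks run on the token sequence).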
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, + ): + super().__init__() + print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") + from omegaconf import ListConfig + + if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): + context_dim = [context_dim] + if exists(context_dim) and isinstance(context_dim, list): + if depth != len(context_dim): + print( + f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " + f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now." + ) + # depth does not match context dims. + assert all( + map(lambda x: x == context_dim[0], context_dim) + ), "need homogenous context_dim to match depth automatically" + context_dim = depth * [context_dim[0]] + elif context_dim is None: + context_dim = [None] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + if i > 0 and len(context) == 1: + i = 0 # use same context for each block + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bb81d91cd91637bef2e04f8b9dcda5af4c7c2a --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/__init__.py @@ -0,0 +1,8 @@ +__all__ = [ + 
"GeneralLPIPSWithDiscriminator", + "LatentLPIPS", +] + +from .discriminator_loss import GeneralLPIPSWithDiscriminator +from .lpips import LatentLPIPS +from .video_loss import VideoAutoencoderLoss diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/discriminator_loss.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/discriminator_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b144a043c7eccfaaef56adc5d2d7896a1849ae --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/discriminator_loss.py @@ -0,0 +1,301 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from einops import rearrange +from matplotlib import colormaps +from matplotlib import pyplot as plt + +from ....util import default, instantiate_from_config +from ..lpips.loss.lpips import LPIPS +from ..lpips.model.model import weights_init +from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss + + +class GeneralLPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start: int, + logvar_init: float = 0.0, + disc_num_layers: int = 3, + disc_in_channels: int = 3, + disc_factor: float = 1.0, + disc_weight: float = 1.0, + perceptual_weight: float = 1.0, + disc_loss: str = "hinge", + scale_input_to_tgt_size: bool = False, + dims: int = 2, + learn_logvar: bool = False, + regularization_weights: Union[None, Dict[str, float]] = None, + additional_log_keys: Optional[List[str]] = None, + discriminator_config: Optional[Dict] = None, + ): + super().__init__() + self.dims = dims + if self.dims > 2: + print( + f"running with dims={dims}. This means that for perceptual loss " + f"calculation, the LPIPS loss will be applied to each frame " + f"independently." 
+ ) + self.scale_input_to_tgt_size = scale_input_to_tgt_size + assert disc_loss in ["hinge", "vanilla"] + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar) + self.learn_logvar = learn_logvar + + discriminator_config = default( + discriminator_config, + { + "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", + "params": { + "input_nc": disc_in_channels, + "n_layers": disc_num_layers, + "use_actnorm": False, + }, + }, + ) + + self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.regularization_weights = default(regularization_weights, {}) + + self.forward_keys = [ + "optimizer_idx", + "global_step", + "last_layer", + "split", + "regularization_log", + ] + + self.additional_log_keys = set(default(additional_log_keys, [])) + self.additional_log_keys.update(set(self.regularization_weights.keys())) + + def get_trainable_parameters(self) -> Iterator[nn.Parameter]: + return self.discriminator.parameters() + + def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]: + if self.learn_logvar: + yield self.logvar + yield from () + + @torch.no_grad() + def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]: + # calc logits of real/fake + logits_real = self.discriminator(inputs.contiguous().detach()) + if len(logits_real.shape) < 4: + # Non patch-discriminator + return dict() + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + # -> (b, 1, h, w) + + # parameters for colormapping + high = max(logits_fake.abs().max(), logits_real.abs().max()).item() + cmap = colormaps["PiYG"] # diverging colormap + + def to_colormap(logits: torch.Tensor) -> torch.Tensor: + """(b, 1, ...) -> (b, 3, ...)""" + logits = (logits + high) / (2 * high) + logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel + # -> (b, 1, ..., 3) + logits = torch.from_numpy(logits_np).to(logits.device) + return rearrange(logits, "b 1 ... 
c -> b c ...") + + logits_real = torch.nn.functional.interpolate( + logits_real, + size=inputs.shape[-2:], + mode="nearest", + antialias=False, + ) + logits_fake = torch.nn.functional.interpolate( + logits_fake, + size=reconstructions.shape[-2:], + mode="nearest", + antialias=False, + ) + + # alpha value of logits for overlay + alpha_real = torch.abs(logits_real) / high + alpha_fake = torch.abs(logits_fake) / high + # -> (b, 1, h, w) in range [0, 0.5] + # alpha value of lines don't really matter, since the values are the same + # for both images and logits anyway + grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4) + grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4) + grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1) + # -> (1, h, w) + # blend logits and images together + + # prepare logits for plotting + logits_real = to_colormap(logits_real) + logits_fake = to_colormap(logits_fake) + # resize logits + # -> (b, 3, h, w) + + # make some grids + # add all logits to one plot + logits_real = torchvision.utils.make_grid(logits_real, nrow=4) + logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4) + # I just love how torchvision calls the number of columns `nrow` + grid_logits = torch.cat((logits_real, logits_fake), dim=1) + # -> (3, h, w) + + grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4) + grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4) + grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1) + # -> (3, h, w) in range [0, 1] + + grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images + + # Create labeled colorbar + dpi = 100 + height = 128 / dpi + width = grid_logits.shape[2] / dpi + fig, ax = plt.subplots(figsize=(width, height), dpi=dpi) + img = ax.imshow(np.array([[-high, high]]), cmap=cmap) + plt.colorbar( + img, + cax=ax, + orientation="horizontal", + fraction=0.9, + aspect=width / height, + pad=0.0, + ) + img.set_visible(False) + fig.tight_layout() + fig.canvas.draw() + # manually convert figure to numpy + cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0 + cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device) + + # Add colorbar to plot + annotated_grid = torch.cat((grid_logits, cbar), dim=1) + blended_grid = torch.cat((grid_blend, cbar), dim=1) + return { + "vis_logits": 2 * annotated_grid[None, ...] - 1, + "vis_logits_blended": 2 * blended_grid[None, ...] 
- 1, + } + + def calculate_adaptive_weight( + self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor + ) -> torch.Tensor: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + *, # added because I changed the order here + regularization_log: Dict[str, torch.Tensor], + optimizer_idx: int, + global_step: int, + last_layer: torch.Tensor, + split: str = "train", + weights: Union[None, float, torch.Tensor] = None, + ) -> Tuple[torch.Tensor, dict]: + if self.scale_input_to_tgt_size: + inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True) + + if self.dims > 2: + inputs, reconstructions = map( + lambda x: rearrange(x, "b c t h w -> (b t) c h w"), + (inputs, reconstructions), + ) + + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices + + from sgm.modules.autoencoding.losses.video_loss import pick_video_frame + + input_frames = pick_video_frame(inputs, frame_indices) + recon_frames = pick_video_frame(reconstructions, frame_indices) + + p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean() + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if global_step >= self.discriminator_iter_start or not self.training: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + if self.training: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + else: + d_weight = torch.tensor(1.0) + else: + d_weight = torch.tensor(0.0) + g_loss = torch.tensor(0.0, requires_grad=True) + + loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss + log = dict() + for k in regularization_log: + if k in self.regularization_weights: + loss = loss + self.regularization_weights[k] * regularization_log[k] + if k in self.additional_log_keys: + log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() + + log.update( + { + f"{split}/loss/total": loss.clone().detach().mean(), + f"{split}/loss/nll": nll_loss.detach().mean(), + f"{split}/loss/rec": rec_loss.detach().mean(), + f"{split}/loss/percep": p_loss.detach().mean(), + f"{split}/loss/rec": rec_loss.detach().mean(), + f"{split}/loss/g": g_loss.detach().mean(), + f"{split}/scalars/logvar": self.logvar.detach(), + f"{split}/scalars/d_weight": d_weight.detach(), + } + ) + + return loss, log + elif optimizer_idx == 1: + # second pass for discriminator update + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + if global_step >= self.discriminator_iter_start or not self.training: + d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) + else: + d_loss = torch.tensor(0.0, requires_grad=True) + + log = { + f"{split}/loss/disc": d_loss.clone().detach().mean(), + f"{split}/logits/real": logits_real.detach().mean(), + 
f"{split}/logits/fake": logits_fake.detach().mean(), + } + return d_loss, log + else: + raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") + + def get_nll_loss( + self, + rec_loss: torch.Tensor, + weights: Optional[Union[float, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + return nll_loss, weighted_nll_loss diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/lpips.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..fed01d64fbc696af237c267ed6b9cb4ed790ab70 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/lpips.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn + +from ....util import default, instantiate_from_config +from ..lpips.loss.lpips import LPIPS + + +class LatentLPIPS(nn.Module): + def __init__( + self, + decoder_config, + perceptual_weight=1.0, + latent_weight=1.0, + scale_input_to_tgt_size=False, + scale_tgt_to_input_size=False, + perceptual_weight_on_inputs=0.0, + ): + super().__init__() + self.scale_input_to_tgt_size = scale_input_to_tgt_size + self.scale_tgt_to_input_size = scale_tgt_to_input_size + self.init_decoder(decoder_config) + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.latent_weight = latent_weight + self.perceptual_weight_on_inputs = perceptual_weight_on_inputs + + def init_decoder(self, config): + self.decoder = instantiate_from_config(config) + if hasattr(self.decoder, "encoder"): + del self.decoder.encoder + + def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): + log = dict() + loss = (latent_inputs - latent_predictions) ** 2 + log[f"{split}/latent_l2_loss"] = loss.mean().detach() + image_reconstructions = None + if self.perceptual_weight > 0.0: + image_reconstructions = self.decoder.decode(latent_predictions) + image_targets = self.decoder.decode(latent_inputs) + perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous()) + loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() + log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() + + if self.perceptual_weight_on_inputs > 0.0: + image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions)) + if self.scale_input_to_tgt_size: + image_inputs = torch.nn.functional.interpolate( + image_inputs, + image_reconstructions.shape[2:], + mode="bicubic", + antialias=True, + ) + elif self.scale_tgt_to_input_size: + image_reconstructions = torch.nn.functional.interpolate( + image_reconstructions, + image_inputs.shape[2:], + mode="bicubic", + antialias=True, + ) + + perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous()) + loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() + log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() + return loss, log diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/video_loss.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/video_loss.py new file 
mode 100644 index 0000000000000000000000000000000000000000..01c4c60b6275eef46042a0c471554bb32f7eff7e --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/losses/video_loss.py @@ -0,0 +1,712 @@ +from typing import Any, Union +from math import log2 +from beartype import beartype + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.autograd import grad as torch_grad +from torch.cuda.amp import autocast + +import torchvision +from torchvision.models import VGG16_Weights +from einops import rearrange, einsum, repeat +from einops.layers.torch import Rearrange +from kornia.filters import filter3d + +from ..magvit2_pytorch import Residual, FeedForward, LinearSpaceAttention +from .lpips import LPIPS + +from sgm.modules.autoencoding.vqvae.movq_enc_3d import CausalConv3d, DownSample3D +from sgm.util import instantiate_from_config + + +def exists(v): + return v is not None + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def leaky_relu(p=0.1): + return nn.LeakyReLU(p) + + +def hinge_discr_loss(fake, real): + return (F.relu(1 + fake) + F.relu(1 - real)).mean() + + +def hinge_gen_loss(fake): + return -fake.mean() + + +@autocast(enabled=False) +@beartype +def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter): + return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach() + + +def pick_video_frame(video, frame_indices): + batch, device = video.shape[0], video.device + video = rearrange(video, "b c f ... -> b f c ...") + batch_indices = torch.arange(batch, device=device) + batch_indices = rearrange(batch_indices, "b -> b 1") + images = video[batch_indices, frame_indices] + images = rearrange(images, "b 1 c ... -> b c ...") + return images + + +def gradient_penalty(images, output): + batch_size = images.shape[0] + + gradients = torch_grad( + outputs=output, + inputs=images, + grad_outputs=torch.ones(output.size(), device=images.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradients = rearrange(gradients, "b ... -> b (...)") + return ((gradients.norm(2, dim=1) - 1) ** 2).mean() + + +# discriminator with anti-aliased downsampling (blurpool Zhang et al.) + + +class Blur(nn.Module): + def __init__(self): + super().__init__() + f = torch.Tensor([1, 2, 1]) + self.register_buffer("f", f) + + def forward(self, x, space_only=False, time_only=False): + assert not (space_only and time_only) + + f = self.f + + if space_only: + f = einsum("i, j -> i j", f, f) + f = rearrange(f, "... -> 1 1 ...") + elif time_only: + f = rearrange(f, "f -> 1 f 1 1") + else: + f = einsum("i, j, k -> i j k", f, f, f) + f = rearrange(f, "... 
-> 1 ...") + + is_images = x.ndim == 4 + + if is_images: + x = rearrange(x, "b c h w -> b c 1 h w") + + out = filter3d(x, f, normalized=True) + + if is_images: + out = rearrange(out, "b c 1 h w -> b c h w") + + return out + + +class DiscriminatorBlock(nn.Module): + def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True): + super().__init__() + self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu(), + ) + + self.maybe_blur = Blur() if antialiased_downsample else None + + self.downsample = ( + nn.Sequential( + Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1) + ) + if downsample + else None + ) + + def forward(self, x): + res = self.conv_res(x) + + x = self.net(x) + + if exists(self.downsample): + if exists(self.maybe_blur): + x = self.maybe_blur(x, space_only=True) + + x = self.downsample(x) + + x = (x + res) * (2**-0.5) + return x + + +class Discriminator(nn.Module): + @beartype + def __init__( + self, + *, + dim, + image_size, + channels=3, + max_dim=512, + attn_heads=8, + attn_dim_head=32, + linear_attn_dim_head=8, + linear_attn_heads=16, + ff_mult=4, + antialiased_downsample=False, + ): + super().__init__() + image_size = pair(image_size) + min_image_resolution = min(image_size) + + num_layers = int(log2(min_image_resolution) - 2) + + blocks = [] + + layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)] + layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims] + layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:])) + + blocks = [] + attn_blocks = [] + + image_resolution = min_image_resolution + + for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out): + num_layer = ind + 1 + is_not_last = ind != (len(layer_dims_in_out) - 1) + + block = DiscriminatorBlock( + in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample + ) + + attn_block = nn.Sequential( + Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)), + Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), + ) + + blocks.append(nn.ModuleList([block, attn_block])) + + image_resolution //= 2 + + self.blocks = nn.ModuleList(blocks) + + dim_last = layer_dims[-1] + + downsample_factor = 2**num_layers + last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size)) + + latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last + + self.to_logits = nn.Sequential( + nn.Conv2d(dim_last, dim_last, 3, padding=1), + leaky_relu(), + Rearrange("b ... 
-> b (...)"), + nn.Linear(latent_dim, 1), + Rearrange("b 1 -> b"), + ) + + def forward(self, x): + for block, attn_block in self.blocks: + x = block(x) + x = attn_block(x) + + return self.to_logits(x) + + +class DiscriminatorBlock3D(nn.Module): + def __init__( + self, + input_channels, + filters, + antialiased_downsample=True, + ): + super().__init__() + self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2) + + self.net = nn.Sequential( + nn.Conv3d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv3d(filters, filters, 3, padding=1), + leaky_relu(), + ) + + self.maybe_blur = Blur() if antialiased_downsample else None + + self.downsample = nn.Sequential( + Rearrange("b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w", p1=2, p2=2, p3=2), + nn.Conv3d(filters * 8, filters, 1), + ) + + def forward(self, x): + res = self.conv_res(x) + + x = self.net(x) + + if exists(self.downsample): + if exists(self.maybe_blur): + x = self.maybe_blur(x, space_only=True) + + x = self.downsample(x) + + x = (x + res) * (2**-0.5) + return x + + +class DiscriminatorBlock3DWithfirstframe(nn.Module): + def __init__( + self, + input_channels, + filters, + antialiased_downsample=True, + pad_mode="first", + ): + super().__init__() + self.downsample_res = DownSample3D( + in_channels=input_channels, + out_channels=filters, + with_conv=True, + compress_time=True, + ) + + self.net = nn.Sequential( + CausalConv3d(input_channels, filters, kernel_size=3, pad_mode=pad_mode), + leaky_relu(), + CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode), + leaky_relu(), + ) + + self.maybe_blur = Blur() if antialiased_downsample else None + + self.downsample = DownSample3D( + in_channels=filters, + out_channels=filters, + with_conv=True, + compress_time=True, + ) + + def forward(self, x): + res = self.downsample_res(x) + + x = self.net(x) + + if exists(self.downsample): + if exists(self.maybe_blur): + x = self.maybe_blur(x, space_only=True) + + x = self.downsample(x) + + x = (x + res) * (2**-0.5) + return x + + +class Discriminator3D(nn.Module): + @beartype + def __init__( + self, + *, + dim, + image_size, + frame_num, + channels=3, + max_dim=512, + linear_attn_dim_head=8, + linear_attn_heads=16, + ff_mult=4, + antialiased_downsample=False, + ): + super().__init__() + image_size = pair(image_size) + min_image_resolution = min(image_size) + + num_layers = int(log2(min_image_resolution) - 2) + temporal_num_layers = int(log2(frame_num)) + self.temporal_num_layers = temporal_num_layers + + layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)] + layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims] + layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:])) + + blocks = [] + + image_resolution = min_image_resolution + frame_resolution = frame_num + + for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out): + num_layer = ind + 1 + is_not_last = ind != (len(layer_dims_in_out) - 1) + + if ind < temporal_num_layers: + block = DiscriminatorBlock3D( + in_chan, + out_chan, + antialiased_downsample=antialiased_downsample, + ) + + blocks.append(block) + + frame_resolution //= 2 + else: + block = DiscriminatorBlock( + in_chan, + out_chan, + downsample=is_not_last, + antialiased_downsample=antialiased_downsample, + ) + attn_block = nn.Sequential( + Residual( + LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head) + ), + Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), + ) + + blocks.append(nn.ModuleList([block, 
attn_block])) + + image_resolution //= 2 + + self.blocks = nn.ModuleList(blocks) + + dim_last = layer_dims[-1] + + downsample_factor = 2**num_layers + last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size)) + + latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last + + self.to_logits = nn.Sequential( + nn.Conv2d(dim_last, dim_last, 3, padding=1), + leaky_relu(), + Rearrange("b ... -> b (...)"), + nn.Linear(latent_dim, 1), + Rearrange("b 1 -> b"), + ) + + def forward(self, x): + for i, layer in enumerate(self.blocks): + if i < self.temporal_num_layers: + x = layer(x) + if i == self.temporal_num_layers - 1: + x = rearrange(x, "b c f h w -> (b f) c h w") + else: + block, attn_block = layer + x = block(x) + x = attn_block(x) + + return self.to_logits(x) + + +class Discriminator3DWithfirstframe(nn.Module): + @beartype + def __init__( + self, + *, + dim, + image_size, + frame_num, + channels=3, + max_dim=512, + linear_attn_dim_head=8, + linear_attn_heads=16, + ff_mult=4, + antialiased_downsample=False, + ): + super().__init__() + image_size = pair(image_size) + min_image_resolution = min(image_size) + + num_layers = int(log2(min_image_resolution) - 2) + temporal_num_layers = int(log2(frame_num)) + self.temporal_num_layers = temporal_num_layers + + layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)] + layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims] + layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:])) + + blocks = [] + + image_resolution = min_image_resolution + frame_resolution = frame_num + + for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out): + num_layer = ind + 1 + is_not_last = ind != (len(layer_dims_in_out) - 1) + + if ind < temporal_num_layers: + block = DiscriminatorBlock3DWithfirstframe( + in_chan, + out_chan, + antialiased_downsample=antialiased_downsample, + ) + + blocks.append(block) + + frame_resolution //= 2 + else: + block = DiscriminatorBlock( + in_chan, + out_chan, + downsample=is_not_last, + antialiased_downsample=antialiased_downsample, + ) + attn_block = nn.Sequential( + Residual( + LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head) + ), + Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), + ) + + blocks.append(nn.ModuleList([block, attn_block])) + + image_resolution //= 2 + + self.blocks = nn.ModuleList(blocks) + + dim_last = layer_dims[-1] + + downsample_factor = 2**num_layers + last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size)) + + latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last + + self.to_logits = nn.Sequential( + nn.Conv2d(dim_last, dim_last, 3, padding=1), + leaky_relu(), + Rearrange("b ... 
-> b (...)"), + nn.Linear(latent_dim, 1), + Rearrange("b 1 -> b"), + ) + + def forward(self, x): + for i, layer in enumerate(self.blocks): + if i < self.temporal_num_layers: + x = layer(x) + if i == self.temporal_num_layers - 1: + x = x.mean(dim=2) + # x = rearrange(x, "b c f h w -> (b f) c h w") + else: + block, attn_block = layer + x = block(x) + x = attn_block(x) + + return self.to_logits(x) + + +class VideoAutoencoderLoss(nn.Module): + def __init__( + self, + disc_start, + perceptual_weight=1, + adversarial_loss_weight=0, + multiscale_adversarial_loss_weight=0, + grad_penalty_loss_weight=0, + quantizer_aux_loss_weight=0, + vgg_weights=VGG16_Weights.DEFAULT, + discr_kwargs=None, + discr_3d_kwargs=None, + ): + super().__init__() + + self.disc_start = disc_start + self.perceptual_weight = perceptual_weight + self.adversarial_loss_weight = adversarial_loss_weight + self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight + self.grad_penalty_loss_weight = grad_penalty_loss_weight + self.quantizer_aux_loss_weight = quantizer_aux_loss_weight + + if self.perceptual_weight > 0: + self.perceptual_model = LPIPS().eval() + # self.vgg = torchvision.models.vgg16(pretrained = True) + # self.vgg.requires_grad_(False) + # if self.adversarial_loss_weight > 0: + # self.discr = Discriminator(**discr_kwargs) + # else: + # self.discr = None + # if self.multiscale_adversarial_loss_weight > 0: + # self.multiscale_discrs = nn.ModuleList([*multiscale_discrs]) + # else: + # self.multiscale_discrs = None + if discr_kwargs is not None: + self.discr = Discriminator(**discr_kwargs) + else: + self.discr = None + if discr_3d_kwargs is not None: + # self.discr_3d = Discriminator3D(**discr_3d_kwargs) + self.discr_3d = instantiate_from_config(discr_3d_kwargs) + else: + self.discr_3d = None + # self.multiscale_discrs = nn.ModuleList([*multiscale_discrs]) + + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + def get_trainable_params(self) -> Any: + params = [] + if self.discr is not None: + params += list(self.discr.parameters()) + if self.discr_3d is not None: + params += list(self.discr_3d.parameters()) + # if self.multiscale_discrs is not None: + # for discr in self.multiscale_discrs: + # params += list(discr.parameters()) + return params + + def get_trainable_parameters(self) -> Any: + return self.get_trainable_params() + + def forward( + self, + inputs, + reconstructions, + optimizer_idx, + global_step, + aux_losses=None, + last_layer=None, + split="train", + ): + batch, channels, frames = inputs.shape[:3] + + if optimizer_idx == 0: + recon_loss = F.mse_loss(inputs, reconstructions) + + if self.perceptual_weight > 0: + frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices + + input_frames = pick_video_frame(inputs, frame_indices) + recon_frames = pick_video_frame(reconstructions, frame_indices) + + perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean() + else: + perceptual_loss = self.zero + + if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0: + gen_loss = self.zero + adaptive_weight = 0 + else: + # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices + # recon_video_frames = pick_video_frame(reconstructions, frame_indices) + + # fake_logits = self.discr(recon_video_frames) + fake_logits = self.discr_3d(reconstructions) + gen_loss = hinge_gen_loss(fake_logits) + + adaptive_weight = 1 + if self.perceptual_weight > 0 and last_layer is not None: + 
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2) + norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2) + adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3) + adaptive_weight.clamp_(max=1e3) + + if torch.isnan(adaptive_weight).any(): + adaptive_weight = 1 + + # multiscale discriminator losses + + # multiscale_gen_losses = [] + # multiscale_gen_adaptive_weights = [] + # if self.multiscale_adversarial_loss_weight > 0: + # if not exists(recon_video_frames): + # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices + # recon_video_frames = pick_video_frame(reconstructions, frame_indices) + # for discr in self.multiscale_discrs: + # fake_logits = recon_video_frames + + # multiscale_gen_loss = hinge_gen_loss(fake_logits) + # multiscale_gen_losses.append(multiscale_gen_loss) + + # multiscale_adaptive_weight = 1. + + # if exists(norm_grad_wrt_perceptual_loss): + # norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_layer).norm(p = 2) + # multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5) + # multiscale_adaptive_weight.clamp_(max = 1e3) + + # multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight) + # weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)) + # else: + # weighted_multiscale_gen_losses = self.zero + + if aux_losses is None: + aux_losses = self.zero + + total_loss = ( + recon_loss + + aux_losses * self.quantizer_aux_loss_weight + + perceptual_loss * self.perceptual_weight + + gen_loss * self.adversarial_loss_weight + ) + # gen_loss * adaptive_weight * self.adversarial_loss_weight + \ + # weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight + + log = { + "{}/total_loss".format(split): total_loss.detach(), + "{}/recon_loss".format(split): recon_loss.detach(), + "{}/perceptual_loss".format(split): perceptual_loss.detach(), + "{}/gen_loss".format(split): gen_loss.detach(), + "{}/aux_losses".format(split): aux_losses.detach(), + # "{}/weighted_multiscale_gen_losses".format(split): weighted_multiscale_gen_losses.detach(), + "{}/adaptive_weight".format(split): adaptive_weight, + # "{}/multiscale_adaptive_weights".format(split): sum(multiscale_gen_adaptive_weights), + } + + return total_loss, log + + if optimizer_idx == 1: + # frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices + + # real = pick_video_frame(inputs, frame_indices) + # fake = pick_video_frame(reconstructions, frame_indices) + + # apply_gradient_penalty = self.grad_penalty_loss_weight > 0 + # if apply_gradient_penalty: + # real = real.requires_grad_() + + # real_logits = self.discr(real) + # fake_logits = self.discr(fake.detach()) + + apply_gradient_penalty = self.grad_penalty_loss_weight > 0 + if apply_gradient_penalty: + inputs = inputs.requires_grad_() + real_logits = self.discr_3d(inputs) + fake_logits = self.discr_3d(reconstructions.detach()) + + discr_loss = hinge_discr_loss(fake_logits, real_logits) + + # # multiscale discriminators + # multiscale_discr_losses = [] + # if self.multiscale_adversarial_loss_weight > 0: + # for discr in self.multiscale_discrs: + # multiscale_real_logits = discr(inputs) + # multiscale_fake_logits = discr(reconstructions.detach()) + + # multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits) + # multiscale_discr_losses.append(multiscale_discr_loss) + 
# else: + # multiscale_discr_losses.append(self.zero) + + # gradient penalty + if apply_gradient_penalty: + # gradient_penalty_loss = gradient_penalty(real, real_logits) + gradient_penalty_loss = gradient_penalty(inputs, real_logits) + else: + gradient_penalty_loss = self.zero + + total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss + # self.grad_penalty_loss_weight * gradient_penalty_loss + \ + # sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight + + log = { + "{}/total_disc_loss".format(split): total_loss.detach(), + "{}/discr_loss".format(split): discr_loss.detach(), + "{}/grad_penalty_loss".format(split): gradient_penalty_loss.detach(), + # "{}/multiscale_discr_loss".format(split): sum(multiscale_discr_losses).detach(), + "{}/logits_real".format(split): real_logits.detach().mean(), + "{}/logits_fake".format(split): fake_logits.detach().mean(), + } + return total_loss, log diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/.gitignore b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a92958a1cd4ffe005e1f5448ab3e6fd9c795a43a --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/.gitignore @@ -0,0 +1 @@ +vgg.pth \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/LICENSE b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..924cfc85b8d63ef538f5676f830a2a8497932108 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/lpips.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..a0249cf74ca8b2c7fb51cb3b51ec61e02107a970 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/loss/lpips.py @@ -0,0 +1,132 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +from collections import namedtuple + +import torch +import torch.nn as nn +from torchvision import models + +from ..util import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = 
nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/LICENSE b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4b356e66b5aa689b339f1a80a9f1b5ba378003bb --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/LICENSE @@ -0,0 +1,58 @@ +Copyright (c) 2017, Jun-Yan Zhu and Taesung Park +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +--------------------------- LICENSE FOR pix2pix -------------------------------- +BSD License + +For pix2pix software +Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +----------------------------- LICENSE FOR DCGAN -------------------------------- +BSD License + +For dcgan.torch software + +Copyright (c) 2015, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/model.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ee13babd77864bb81456a7c9634ba7e9e597983f --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/model/model.py @@ -0,0 +1,89 @@ +import functools + +import torch.nn as nn + +from ..util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + try: + nn.init.normal_(m.weight.data, 0.0, 0.02) + except: + nn.init.normal_(m.conv.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/util.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/util.py new file mode 100644 index 0000000000000000000000000000000000000000..cb4a03624437b1a2498026a2669e57cb66409e6d --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/util.py @@ -0,0 +1,114 @@ +import hashlib +import os + +import requests +import torch +import torch.nn as nn +from tqdm import tqdm + +URL_MAP = {"vgg_lpips": 
"https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) + std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." 
+ ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/vqperceptual.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..1e4944bd6c287fa0c74bf1c5f1cd8289a27c01b6 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/lpips/vqperceptual.py @@ -0,0 +1,16 @@ +import torch +import torch.nn.functional as F + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/magvit2_pytorch.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/magvit2_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..58889526b30e76fd2715d153940f0059a64e3fb4 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/magvit2_pytorch.py @@ -0,0 +1,1762 @@ +import copy +from pathlib import Path +from math import log2, ceil, sqrt +from functools import wraps, partial + +import torch +import torch.nn.functional as F +from torch.cuda.amp import autocast +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +from torch.autograd import grad as torch_grad + +import torchvision +from torchvision.models import VGG16_Weights + +from collections import namedtuple + +# from vector_quantize_pytorch import LFQ, FSQ +from .regularizers.finite_scalar_quantization import FSQ +from .regularizers.lookup_free_quantization import LFQ + +from einops import rearrange, repeat, reduce, pack, unpack +from einops.layers.torch import Rearrange + +from beartype import beartype +from beartype.typing import Union, Tuple, Optional, List + +from magvit2_pytorch.attend import Attend +from magvit2_pytorch.version import __version__ + +from gateloop_transformer import SimpleGateLoopLayer + +from taylor_series_linear_attention import TaylorSeriesLinearAttn + +from kornia.filters import filter3d + +import pickle + +# helper + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +def safe_get_index(it, ind, default=None): + if ind < len(it): + return it[ind] + return default + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def identity(t, *args, **kwargs): + return t + + +def divisible_by(num, den): + return (num % den) == 0 + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def append_dims(t, ndims: int): + return t.reshape(*t.shape, *((1,) * ndims)) + + +def is_odd(n): + return not divisible_by(n, 2) + + +def maybe_del_attr_(o, attr): + if hasattr(o, attr): + delattr(o, attr) + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +# tensor helpers + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +def pad_at_dim(t, 
pad, dim=-1, value=0.0): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = (0, 0) * dims_from_right + return F.pad(t, (*zeros, *pad), value=value) + + +def pick_video_frame(video, frame_indices): + batch, device = video.shape[0], video.device + video = rearrange(video, "b c f ... -> b f c ...") + batch_indices = torch.arange(batch, device=device) + batch_indices = rearrange(batch_indices, "b -> b 1") + images = video[batch_indices, frame_indices] + images = rearrange(images, "b 1 c ... -> b c ...") + return images + + +# gan related + + +def gradient_penalty(images, output): + batch_size = images.shape[0] + + gradients = torch_grad( + outputs=output, + inputs=images, + grad_outputs=torch.ones(output.size(), device=images.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradients = rearrange(gradients, "b ... -> b (...)") + return ((gradients.norm(2, dim=1) - 1) ** 2).mean() + + +def leaky_relu(p=0.1): + return nn.LeakyReLU(p) + + +def hinge_discr_loss(fake, real): + return (F.relu(1 + fake) + F.relu(1 - real)).mean() + + +def hinge_gen_loss(fake): + return -fake.mean() + + +@autocast(enabled=False) +@beartype +def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter): + return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach() + + +# helper decorators + + +def remove_vgg(fn): + @wraps(fn) + def inner(self, *args, **kwargs): + has_vgg = hasattr(self, "vgg") + if has_vgg: + vgg = self.vgg + delattr(self, "vgg") + + out = fn(self, *args, **kwargs) + + if has_vgg: + self.vgg = vgg + + return out + + return inner + + +# helper classes + + +def Sequential(*modules): + modules = [*filter(exists, modules)] + + if len(modules) == 0: + return nn.Identity() + + return nn.Sequential(*modules) + + +class Residual(Module): + @beartype + def __init__(self, fn: Module): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + + +# for a bunch of tensor operations to change tensor to (batch, time, feature dimension) and back + + +class ToTimeSequence(Module): + @beartype + def __init__(self, fn: Module): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + x = rearrange(x, "b c f ... -> b ... f c") + x, ps = pack_one(x, "* n c") + + o = self.fn(x, **kwargs) + + o = unpack_one(o, ps, "* n c") + return rearrange(o, "b ... f c -> b c f ...") + + +class SqueezeExcite(Module): + # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375) + + def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10): + super().__init__() + dim_out = default(dim_out, dim) + + self.to_k = nn.Conv2d(dim, 1, 1) + dim_hidden = max(dim_hidden_min, dim_out // 2) + + self.net = nn.Sequential( + nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid() + ) + + nn.init.zeros_(self.net[-2].weight) + nn.init.constant_(self.net[-2].bias, init_bias) + + def forward(self, x): + orig_input, batch = x, x.shape[0] + is_video = x.ndim == 5 + + if is_video: + x = rearrange(x, "b c f h w -> (b f) c h w") + + context = self.to_k(x) + + context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1) + spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)") + + out = einsum("b i n, b c n -> b c i", context, spatial_flattened_input) + out = rearrange(out, "... -> ... 
1") + gates = self.net(out) + + if is_video: + gates = rearrange(gates, "(b f) c h w -> b c f h w", b=batch) + + return gates * orig_input + + +# token shifting + + +class TokenShift(Module): + @beartype + def __init__(self, fn: Module): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + x, x_shift = x.chunk(2, dim=1) + x_shift = pad_at_dim(x_shift, (1, -1), dim=2) # shift time dimension + x = torch.cat((x, x_shift), dim=1) + return self.fn(x, **kwargs) + + +# rmsnorm + + +class RMSNorm(Module): + def __init__(self, dim, channel_first=False, images=False, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class AdaptiveRMSNorm(Module): + def __init__(self, dim, *, dim_cond, channel_first=False, images=False, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.dim_cond = dim_cond + self.channel_first = channel_first + self.scale = dim**0.5 + + self.to_gamma = nn.Linear(dim_cond, dim) + self.to_bias = nn.Linear(dim_cond, dim) if bias else None + + nn.init.zeros_(self.to_gamma.weight) + nn.init.ones_(self.to_gamma.bias) + + if bias: + nn.init.zeros_(self.to_bias.weight) + nn.init.zeros_(self.to_bias.bias) + + @beartype + def forward(self, x: Tensor, *, cond: Tensor): + batch = x.shape[0] + assert cond.shape == (batch, self.dim_cond) + + gamma = self.to_gamma(cond) + + bias = 0.0 + if exists(self.to_bias): + bias = self.to_bias(cond) + + if self.channel_first: + gamma = append_dims(gamma, x.ndim - 2) + + if exists(self.to_bias): + bias = append_dims(bias, x.ndim - 2) + + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * gamma + bias + + +# attention + + +class Attention(Module): + @beartype + def __init__( + self, + *, + dim, + dim_cond: Optional[int] = None, + causal=False, + dim_head=32, + heads=8, + flash=False, + dropout=0.0, + num_memory_kv=4, + ): + super().__init__() + dim_inner = dim_head * heads + + self.need_cond = exists(dim_cond) + + if self.need_cond: + self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond) + else: + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads) + ) + + assert num_memory_kv > 0 + self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head)) + + self.attend = Attend(causal=causal, dropout=dropout, flash=flash) + + self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)) + + @beartype + def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None): + maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict() + + x = self.norm(x, **maybe_cond_kwargs) + + q, k, v = self.to_qkv(x) + + mk, mv = map(lambda t: repeat(t, "h n d -> b h n d", b=q.shape[0]), self.mem_kv) + k = torch.cat((mk, k), dim=-2) + v = torch.cat((mv, v), dim=-2) + + out = self.attend(q, k, v, mask=mask) + return self.to_out(out) + + +class LinearAttention(Module): + """ + using the specific linear attention proposed in 
https://arxiv.org/abs/2106.09681 + """ + + @beartype + def __init__(self, *, dim, dim_cond: Optional[int] = None, dim_head=8, heads=8, dropout=0.0): + super().__init__() + dim_inner = dim_head * heads + + self.need_cond = exists(dim_cond) + + if self.need_cond: + self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond) + else: + self.norm = RMSNorm(dim) + + self.attn = TaylorSeriesLinearAttn(dim=dim, dim_head=dim_head, heads=heads) + + def forward(self, x, cond: Optional[Tensor] = None): + maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict() + + x = self.norm(x, **maybe_cond_kwargs) + + return self.attn(x) + + +class LinearSpaceAttention(LinearAttention): + def forward(self, x, *args, **kwargs): + x = rearrange(x, "b c ... h w -> b ... h w c") + x, batch_ps = pack_one(x, "* h w c") + x, seq_ps = pack_one(x, "b * c") + + x = super().forward(x, *args, **kwargs) + + x = unpack_one(x, seq_ps, "b * c") + x = unpack_one(x, batch_ps, "* h w c") + return rearrange(x, "b ... h w c -> b c ... h w") + + +class SpaceAttention(Attention): + def forward(self, x, *args, **kwargs): + x = rearrange(x, "b c t h w -> b t h w c") + x, batch_ps = pack_one(x, "* h w c") + x, seq_ps = pack_one(x, "b * c") + + x = super().forward(x, *args, **kwargs) + + x = unpack_one(x, seq_ps, "b * c") + x = unpack_one(x, batch_ps, "* h w c") + return rearrange(x, "b t h w c -> b c t h w") + + +class TimeAttention(Attention): + def forward(self, x, *args, **kwargs): + x = rearrange(x, "b c t h w -> b h w t c") + x, batch_ps = pack_one(x, "* t c") + + x = super().forward(x, *args, **kwargs) + + x = unpack_one(x, batch_ps, "* t c") + return rearrange(x, "b h w t c -> b c t h w") + + +class GEGLU(Module): + def forward(self, x): + x, gate = x.chunk(2, dim=1) + return F.gelu(gate) * x + + +class FeedForward(Module): + @beartype + def __init__(self, dim, *, dim_cond: Optional[int] = None, mult=4, images=False): + super().__init__() + conv_klass = nn.Conv2d if images else nn.Conv3d + + rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond) + + maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images) + + dim_inner = int(dim * mult * 2 / 3) + + self.norm = maybe_adaptive_norm_klass(dim) + + self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1)) + + @beartype + def forward(self, x: Tensor, *, cond: Optional[Tensor] = None): + maybe_cond_kwargs = dict(cond=cond) if exists(cond) else dict() + + x = self.norm(x, **maybe_cond_kwargs) + return self.net(x) + + +# discriminator with anti-aliased downsampling (blurpool Zhang et al.) + + +class Blur(Module): + def __init__(self): + super().__init__() + f = torch.Tensor([1, 2, 1]) + self.register_buffer("f", f) + + def forward(self, x, space_only=False, time_only=False): + assert not (space_only and time_only) + + f = self.f + + if space_only: + f = einsum("i, j -> i j", f, f) + f = rearrange(f, "... -> 1 1 ...") + elif time_only: + f = rearrange(f, "f -> 1 f 1 1") + else: + f = einsum("i, j, k -> i j k", f, f, f) + f = rearrange(f, "... 
-> 1 ...") + + is_images = x.ndim == 4 + + if is_images: + x = rearrange(x, "b c h w -> b c 1 h w") + + out = filter3d(x, f, normalized=True) + + if is_images: + out = rearrange(out, "b c 1 h w -> b c h w") + + return out + + +class DiscriminatorBlock(Module): + def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True): + super().__init__() + self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1)) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu(), + ) + + self.maybe_blur = Blur() if antialiased_downsample else None + + self.downsample = ( + nn.Sequential( + Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1) + ) + if downsample + else None + ) + + def forward(self, x): + res = self.conv_res(x) + + x = self.net(x) + + if exists(self.downsample): + if exists(self.maybe_blur): + x = self.maybe_blur(x, space_only=True) + + x = self.downsample(x) + + x = (x + res) * (2**-0.5) + return x + + +class Discriminator(Module): + @beartype + def __init__( + self, + *, + dim, + image_size, + channels=3, + max_dim=512, + attn_heads=8, + attn_dim_head=32, + linear_attn_dim_head=8, + linear_attn_heads=16, + ff_mult=4, + antialiased_downsample=False, + ): + super().__init__() + image_size = pair(image_size) + min_image_resolution = min(image_size) + + num_layers = int(log2(min_image_resolution) - 2) + + blocks = [] + + layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)] + layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims] + layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:])) + + blocks = [] + attn_blocks = [] + + image_resolution = min_image_resolution + + for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out): + num_layer = ind + 1 + is_not_last = ind != (len(layer_dims_in_out) - 1) + + block = DiscriminatorBlock( + in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample + ) + + attn_block = Sequential( + Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)), + Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), + ) + + blocks.append(ModuleList([block, attn_block])) + + image_resolution //= 2 + + self.blocks = ModuleList(blocks) + + dim_last = layer_dims[-1] + + downsample_factor = 2**num_layers + last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size)) + + latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last + + self.to_logits = Sequential( + nn.Conv2d(dim_last, dim_last, 3, padding=1), + leaky_relu(), + Rearrange("b ... -> b (...)"), + nn.Linear(latent_dim, 1), + Rearrange("b 1 -> b"), + ) + + def forward(self, x): + for block, attn_block in self.blocks: + x = block(x) + x = attn_block(x) + + return self.to_logits(x) + + +# modulatable conv from Karras et al. 
Stylegan2 +# for conditioning on latents + + +class Conv3DMod(Module): + @beartype + def __init__( + self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros" + ): + super().__init__() + dim_out = default(dim_out, dim) + + self.eps = eps + + assert is_odd(spatial_kernel) and is_odd(time_kernel) + + self.spatial_kernel = spatial_kernel + self.time_kernel = time_kernel + + time_padding = (time_kernel - 1, 0) if causal else ((time_kernel // 2,) * 2) + + self.pad_mode = pad_mode + self.padding = (*((spatial_kernel // 2,) * 4), *time_padding) + self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel))) + + self.demod = demod + + nn.init.kaiming_normal_(self.weights, a=0, mode="fan_in", nonlinearity="selu") + + @beartype + def forward(self, fmap, cond: Tensor): + """ + notation + + b - batch + n - convs + o - output + i - input + k - kernel + """ + + b = fmap.shape[0] + + # prepare weights for modulation + + weights = self.weights + + # do the modulation, demodulation, as done in stylegan2 + + cond = rearrange(cond, "b i -> b 1 i 1 1 1") + + weights = weights * (cond + 1) + + if self.demod: + inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt() + weights = weights * inv_norm + + fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w") + + weights = rearrange(weights, "b o ... -> (b o) ...") + + fmap = F.pad(fmap, self.padding, mode=self.pad_mode) + fmap = F.conv3d(fmap, weights, groups=b) + + return rearrange(fmap, "1 (b o) ... -> b o ...", b=b) + + +# strided conv downsamples + + +class SpatialDownsample2x(Module): + def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False): + super().__init__() + dim_out = default(dim_out, dim) + self.maybe_blur = Blur() if antialias else identity + self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2) + + def forward(self, x): + x = self.maybe_blur(x, space_only=True) + + x = rearrange(x, "b c t h w -> b t c h w") + x, ps = pack_one(x, "* c h w") + + out = self.conv(x) + + out = unpack_one(out, ps, "* c h w") + out = rearrange(out, "b t c h w -> b c t h w") + return out + + +class TimeDownsample2x(Module): + def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False): + super().__init__() + dim_out = default(dim_out, dim) + self.maybe_blur = Blur() if antialias else identity + self.time_causal_padding = (kernel_size - 1, 0) + self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2) + + def forward(self, x): + x = self.maybe_blur(x, time_only=True) + + x = rearrange(x, "b c t h w -> b h w c t") + x, ps = pack_one(x, "* c t") + + x = F.pad(x, self.time_causal_padding) + out = self.conv(x) + + out = unpack_one(out, ps, "* c t") + out = rearrange(out, "b h w c t -> b c t h w") + return out + + +# depth to space upsamples + + +class SpatialUpsample2x(Module): + def __init__(self, dim, dim_out=None): + super().__init__() + dim_out = default(dim_out, dim) + conv = nn.Conv2d(dim, dim_out * 4, 1) + + self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2)) + + self.init_conv_(conv) + + def init_conv_(self, conv): + o, i, h, w = conv.weight.shape + conv_weight = torch.empty(o // 4, i, h, w) + nn.init.kaiming_uniform_(conv_weight) + conv_weight = repeat(conv_weight, "o ... 
-> (o 4) ...") + + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def forward(self, x): + x = rearrange(x, "b c t h w -> b t c h w") + x, ps = pack_one(x, "* c h w") + + out = self.net(x) + + out = unpack_one(out, ps, "* c h w") + out = rearrange(out, "b t c h w -> b c t h w") + return out + + +class TimeUpsample2x(Module): + def __init__(self, dim, dim_out=None): + super().__init__() + dim_out = default(dim_out, dim) + conv = nn.Conv1d(dim, dim_out * 2, 1) + + self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p) t -> b c (t p)", p=2)) + + self.init_conv_(conv) + + def init_conv_(self, conv): + o, i, t = conv.weight.shape + conv_weight = torch.empty(o // 2, i, t) + nn.init.kaiming_uniform_(conv_weight) + conv_weight = repeat(conv_weight, "o ... -> (o 2) ...") + + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def forward(self, x): + x = rearrange(x, "b c t h w -> b h w c t") + x, ps = pack_one(x, "* c t") + + out = self.net(x) + + out = unpack_one(out, ps, "* c t") + out = rearrange(out, "b h w c t -> b c t h w") + return out + + +# autoencoder - only best variant here offered, with causal conv 3d + + +def SameConv2d(dim_in, dim_out, kernel_size): + kernel_size = cast_tuple(kernel_size, 2) + padding = [k // 2 for k in kernel_size] + return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding) + + +class CausalConv3d(Module): + @beartype + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant" + + x = F.pad(x, self.time_causal_padding, mode=pad_mode) + return self.conv(x) + + +@beartype +def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"): + net = Sequential( + CausalConv3d(dim, dim, kernel_size, pad_mode=pad_mode), + nn.ELU(), + nn.Conv3d(dim, dim, 1), + nn.ELU(), + SqueezeExcite(dim), + ) + + return Residual(net) + + +@beartype +class ResidualUnitMod(Module): + def __init__( + self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + assert height_kernel_size == width_kernel_size + + self.to_cond = nn.Linear(dim_cond, dim) + + self.conv = Conv3DMod( + dim=dim, + spatial_kernel=height_kernel_size, + time_kernel=time_kernel_size, + causal=True, + demod=demod, + pad_mode=pad_mode, + ) + + self.conv_out = nn.Conv3d(dim, dim, 1) + + @beartype + def forward( + self, + x, + cond: Tensor, + ): + res = x + cond = self.to_cond(cond) + + x = self.conv(x, cond=cond) + x = F.elu(x) + x = self.conv_out(x) + x = 
F.elu(x) + return x + res + + +class CausalConvTranspose3d(Module): + def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + self.upsample_factor = time_stride + + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + stride = (time_stride, 1, 1) + padding = (0, height_pad, width_pad) + + self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs) + + def forward(self, x): + assert x.ndim == 5 + t = x.shape[2] + + out = self.conv(x) + + out = out[..., : (t * self.upsample_factor), :, :] + return out + + +# video tokenizer class + +LossBreakdown = namedtuple( + "LossBreakdown", + [ + "recon_loss", + "lfq_aux_loss", + "quantizer_loss_breakdown", + "perceptual_loss", + "adversarial_gen_loss", + "adaptive_adversarial_weight", + "multiscale_gen_losses", + "multiscale_gen_adaptive_weights", + ], +) + +DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"]) + + +class VideoTokenizer(Module): + @beartype + def __init__( + self, + *, + image_size, + layers: Tuple[Union[str, Tuple[str, int]], ...] = ("residual", "residual", "residual"), + residual_conv_kernel_size=3, + num_codebooks=1, + codebook_size: Optional[int] = None, + channels=3, + init_dim=64, + max_dim=float("inf"), + dim_cond=None, + dim_cond_expansion_factor=4.0, + input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7), + output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3), + pad_mode: str = "constant", + lfq_entropy_loss_weight=0.1, + lfq_commitment_loss_weight=1.0, + lfq_diversity_gamma=2.5, + quantizer_aux_loss_weight=1.0, + lfq_activation=nn.Identity(), + use_fsq=False, + fsq_levels: Optional[List[int]] = None, + attn_dim_head=32, + attn_heads=8, + attn_dropout=0.0, + linear_attn_dim_head=8, + linear_attn_heads=16, + vgg: Optional[Module] = None, + vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT, + perceptual_loss_weight=1e-1, + discr_kwargs: Optional[dict] = None, + multiscale_discrs: Tuple[Module, ...] 
= tuple(), + use_gan=True, + adversarial_loss_weight=1.0, + grad_penalty_loss_weight=10.0, + multiscale_adversarial_loss_weight=1.0, + flash_attn=True, + separate_first_frame_encoding=False, + ): + super().__init__() + + # for autosaving the config + + _locals = locals() + _locals.pop("self", None) + _locals.pop("__class__", None) + self._configs = pickle.dumps(_locals) + + # image size + + self.channels = channels + self.image_size = image_size + + # initial encoder + + self.conv_in = CausalConv3d(channels, init_dim, input_conv_kernel_size, pad_mode=pad_mode) + + # whether to encode the first frame separately or not + + self.conv_in_first_frame = nn.Identity() + self.conv_out_first_frame = nn.Identity() + + if separate_first_frame_encoding: + self.conv_in_first_frame = SameConv2d(channels, init_dim, input_conv_kernel_size[-2:]) + self.conv_out_first_frame = SameConv2d(init_dim, channels, output_conv_kernel_size[-2:]) + + self.separate_first_frame_encoding = separate_first_frame_encoding + + # encoder and decoder layers + + self.encoder_layers = ModuleList([]) + self.decoder_layers = ModuleList([]) + + self.conv_out = CausalConv3d(init_dim, channels, output_conv_kernel_size, pad_mode=pad_mode) + + dim = init_dim + dim_out = dim + + layer_fmap_size = image_size + time_downsample_factor = 1 + has_cond_across_layers = [] + + for layer_def in layers: + layer_type, *layer_params = cast_tuple(layer_def) + + has_cond = False + + if layer_type == "residual": + encoder_layer = ResidualUnit(dim, residual_conv_kernel_size) + decoder_layer = ResidualUnit(dim, residual_conv_kernel_size) + + elif layer_type == "consecutive_residual": + (num_consecutive,) = layer_params + encoder_layer = Sequential( + *[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)] + ) + decoder_layer = Sequential( + *[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)] + ) + + elif layer_type == "cond_residual": + assert exists( + dim_cond + ), "dim_cond must be passed into VideoTokenizer, if tokenizer is to be conditioned" + + has_cond = True + + encoder_layer = ResidualUnitMod( + dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor) + ) + decoder_layer = ResidualUnitMod( + dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor) + ) + dim_out = dim + + elif layer_type == "compress_space": + dim_out = safe_get_index(layer_params, 0) + dim_out = default(dim_out, dim * 2) + dim_out = min(dim_out, max_dim) + + encoder_layer = SpatialDownsample2x(dim, dim_out) + decoder_layer = SpatialUpsample2x(dim_out, dim) + + assert layer_fmap_size > 1 + layer_fmap_size //= 2 + + elif layer_type == "compress_time": + dim_out = safe_get_index(layer_params, 0) + dim_out = default(dim_out, dim * 2) + dim_out = min(dim_out, max_dim) + + encoder_layer = TimeDownsample2x(dim, dim_out) + decoder_layer = TimeUpsample2x(dim_out, dim) + + time_downsample_factor *= 2 + + elif layer_type == "attend_space": + attn_kwargs = dict( + dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn + ) + + encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + + decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + + elif layer_type == "linear_attend_space": + linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads) + + encoder_layer = Sequential( + Residual(LinearSpaceAttention(**linear_attn_kwargs)), 
Residual(FeedForward(dim)) + ) + + decoder_layer = Sequential( + Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim)) + ) + + elif layer_type == "gateloop_time": + gateloop_kwargs = dict(use_heinsen=False) + + encoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim))) + decoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim))) + + elif layer_type == "attend_time": + attn_kwargs = dict( + dim=dim, + dim_head=attn_dim_head, + heads=attn_heads, + dropout=attn_dropout, + causal=True, + flash=flash_attn, + ) + + encoder_layer = Sequential( + Residual(TokenShift(TimeAttention(**attn_kwargs))), + Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))), + ) + + decoder_layer = Sequential( + Residual(TokenShift(TimeAttention(**attn_kwargs))), + Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))), + ) + + elif layer_type == "cond_attend_space": + has_cond = True + + attn_kwargs = dict( + dim=dim, + dim_cond=dim_cond, + dim_head=attn_dim_head, + heads=attn_heads, + dropout=attn_dropout, + flash=flash_attn, + ) + + encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + + decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + + elif layer_type == "cond_linear_attend_space": + has_cond = True + + attn_kwargs = dict( + dim=dim, + dim_cond=dim_cond, + dim_head=attn_dim_head, + heads=attn_heads, + dropout=attn_dropout, + flash=flash_attn, + ) + + encoder_layer = Sequential( + Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond)) + ) + + decoder_layer = Sequential( + Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond)) + ) + + elif layer_type == "cond_attend_time": + has_cond = True + + attn_kwargs = dict( + dim=dim, + dim_cond=dim_cond, + dim_head=attn_dim_head, + heads=attn_heads, + dropout=attn_dropout, + causal=True, + flash=flash_attn, + ) + + encoder_layer = Sequential( + Residual(TokenShift(TimeAttention(**attn_kwargs))), + Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))), + ) + + decoder_layer = Sequential( + Residual(TokenShift(TimeAttention(**attn_kwargs))), + Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))), + ) + + else: + raise ValueError(f"unknown layer type {layer_type}") + + self.encoder_layers.append(encoder_layer) + self.decoder_layers.insert(0, decoder_layer) + + dim = dim_out + has_cond_across_layers.append(has_cond) + + # add a final norm just before quantization layer + + self.encoder_layers.append( + Sequential( + Rearrange("b c ... -> b ... c"), + nn.LayerNorm(dim), + Rearrange("b ... 
c -> b c ..."), + ) + ) + + self.time_downsample_factor = time_downsample_factor + self.time_padding = time_downsample_factor - 1 + + self.fmap_size = layer_fmap_size + + # use a MLP stem for conditioning, if needed + + self.has_cond_across_layers = has_cond_across_layers + self.has_cond = any(has_cond_across_layers) + + self.encoder_cond_in = nn.Identity() + self.decoder_cond_in = nn.Identity() + + if has_cond: + self.dim_cond = dim_cond + + self.encoder_cond_in = Sequential( + nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU() + ) + + self.decoder_cond_in = Sequential( + nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU() + ) + + # quantizer related + + self.use_fsq = use_fsq + + if not use_fsq: + assert exists(codebook_size) and not exists( + fsq_levels + ), "if use_fsq is set to False, `codebook_size` must be set (and not `fsq_levels`)" + + # lookup free quantizer(s) - multiple codebooks is possible + # each codebook will get its own entropy regularization + + self.quantizers = LFQ( + dim=dim, + codebook_size=codebook_size, + num_codebooks=num_codebooks, + entropy_loss_weight=lfq_entropy_loss_weight, + commitment_loss_weight=lfq_commitment_loss_weight, + diversity_gamma=lfq_diversity_gamma, + ) + + else: + assert ( + not exists(codebook_size) and exists(fsq_levels) + ), "if use_fsq is set to True, `fsq_levels` must be set (and not `codebook_size`). the effective codebook size is the cumulative product of all the FSQ levels" + + self.quantizers = FSQ(fsq_levels, dim=dim, num_codebooks=num_codebooks) + + self.quantizer_aux_loss_weight = quantizer_aux_loss_weight + + # dummy loss + + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + # perceptual loss related + + use_vgg = channels in {1, 3, 4} and perceptual_loss_weight > 0.0 + + self.vgg = None + self.perceptual_loss_weight = perceptual_loss_weight + + if use_vgg: + if not exists(vgg): + vgg = torchvision.models.vgg16(weights=vgg_weights) + + vgg.classifier = Sequential(*vgg.classifier[:-2]) + + self.vgg = vgg + + self.use_vgg = use_vgg + + # main flag for whether to use GAN at all + + self.use_gan = use_gan + + # discriminator + + discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512)) + + self.discr = Discriminator(**discr_kwargs) + + self.adversarial_loss_weight = adversarial_loss_weight + self.grad_penalty_loss_weight = grad_penalty_loss_weight + + self.has_gan = use_gan and adversarial_loss_weight > 0.0 + + # multi-scale discriminators + + self.has_multiscale_gan = use_gan and multiscale_adversarial_loss_weight > 0.0 + + self.multiscale_discrs = ModuleList([*multiscale_discrs]) + + self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight + + self.has_multiscale_discrs = ( + use_gan and multiscale_adversarial_loss_weight > 0.0 and len(multiscale_discrs) > 0 + ) + + @property + def device(self): + return self.zero.device + + @classmethod + def init_and_load_from(cls, path, strict=True): + path = Path(path) + assert path.exists() + pkg = torch.load(str(path), map_location="cpu") + + assert "config" in pkg, "model configs were not found in this saved checkpoint" + + config = pickle.loads(pkg["config"]) + tokenizer = cls(**config) + tokenizer.load(path, strict=strict) + return tokenizer + + def parameters(self): + return [ + *self.conv_in.parameters(), + *self.conv_in_first_frame.parameters(), + *self.conv_out_first_frame.parameters(), + *self.conv_out.parameters(), + *self.encoder_layers.parameters(), + 
*self.decoder_layers.parameters(), + *self.encoder_cond_in.parameters(), + *self.decoder_cond_in.parameters(), + *self.quantizers.parameters(), + ] + + def discr_parameters(self): + return self.discr.parameters() + + def copy_for_eval(self): + device = self.device + vae_copy = copy.deepcopy(self.cpu()) + + maybe_del_attr_(vae_copy, "discr") + maybe_del_attr_(vae_copy, "vgg") + maybe_del_attr_(vae_copy, "multiscale_discrs") + + vae_copy.eval() + return vae_copy.to(device) + + @remove_vgg + def state_dict(self, *args, **kwargs): + return super().state_dict(*args, **kwargs) + + @remove_vgg + def load_state_dict(self, *args, **kwargs): + return super().load_state_dict(*args, **kwargs) + + def save(self, path, overwrite=True): + path = Path(path) + assert overwrite or not path.exists(), f"{str(path)} already exists" + + pkg = dict(model_state_dict=self.state_dict(), version=__version__, config=self._configs) + + torch.save(pkg, str(path)) + + def load(self, path, strict=True): + path = Path(path) + assert path.exists() + + pkg = torch.load(str(path)) + state_dict = pkg.get("model_state_dict") + version = pkg.get("version") + + assert exists(state_dict) + + if exists(version): + print(f"loading checkpointed tokenizer from version {version}") + + self.load_state_dict(state_dict, strict=strict) + + @beartype + def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True): + encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame + + # whether to pad video or not + + if video_contains_first_frame: + video_len = video.shape[2] + + video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2) + video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])] + + # conditioning, if needed + + assert (not self.has_cond) or exists( + cond + ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified" + + if exists(cond): + assert cond.shape == (video.shape[0], self.dim_cond) + + cond = self.encoder_cond_in(cond) + cond_kwargs = dict(cond=cond) + + # initial conv + # taking into account whether to encode first frame separately + + if encode_first_frame_separately: + pad, first_frame, video = unpack(video, video_packed_shape, "b c * h w") + first_frame = self.conv_in_first_frame(first_frame) + + video = self.conv_in(video) + + if encode_first_frame_separately: + video, _ = pack([first_frame, video], "b c * h w") + video = pad_at_dim(video, (self.time_padding, 0), dim=2) + + # encoder layers + + for fn, has_cond in zip(self.encoder_layers, self.has_cond_across_layers): + layer_kwargs = dict() + + if has_cond: + layer_kwargs = cond_kwargs + + video = fn(video, **layer_kwargs) + + maybe_quantize = identity if not quantize else self.quantizers + + return maybe_quantize(video) + + @beartype + def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True): + assert codes.dtype in (torch.long, torch.int32) + + if codes.ndim == 2: + video_code_len = codes.shape[-1] + assert divisible_by( + video_code_len, self.fmap_size**2 + ), f"flattened video ids must have a length ({video_code_len}) that is divisible by the fmap size ({self.fmap_size}) squared ({self.fmap_size ** 2})" + + codes = rearrange(codes, "b (f h w) -> b f h w", h=self.fmap_size, w=self.fmap_size) + + quantized = self.quantizers.indices_to_codes(codes) + + return self.decode(quantized, cond=cond, 
video_contains_first_frame=video_contains_first_frame) + + @beartype + def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True): + decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame + + batch = quantized.shape[0] + + # conditioning, if needed + + assert (not self.has_cond) or exists( + cond + ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified" + + if exists(cond): + assert cond.shape == (batch, self.dim_cond) + + cond = self.decoder_cond_in(cond) + cond_kwargs = dict(cond=cond) + + # decoder layers + + x = quantized + + for fn, has_cond in zip(self.decoder_layers, reversed(self.has_cond_across_layers)): + layer_kwargs = dict() + + if has_cond: + layer_kwargs = cond_kwargs + + x = fn(x, **layer_kwargs) + + # to pixels + + if decode_first_frame_separately: + left_pad, xff, x = ( + x[:, :, : self.time_padding], + x[:, :, self.time_padding], + x[:, :, (self.time_padding + 1) :], + ) + + out = self.conv_out(x) + outff = self.conv_out_first_frame(xff) + + video, _ = pack([outff, out], "b c * h w") + + else: + video = self.conv_out(x) + + # if video were padded, remove padding + + if video_contains_first_frame: + video = video[:, :, self.time_padding :] + + return video + + @torch.no_grad() + def tokenize(self, video): + self.eval() + return self.forward(video, return_codes=True) + + @beartype + def forward( + self, + video_or_images: Tensor, + cond: Optional[Tensor] = None, + return_loss=False, + return_codes=False, + return_recon=False, + return_discr_loss=False, + return_recon_loss_only=False, + apply_gradient_penalty=True, + video_contains_first_frame=True, + adversarial_loss_weight=None, + multiscale_adversarial_loss_weight=None, + ): + adversarial_loss_weight = default(adversarial_loss_weight, self.adversarial_loss_weight) + multiscale_adversarial_loss_weight = default( + multiscale_adversarial_loss_weight, self.multiscale_adversarial_loss_weight + ) + + assert (return_loss + return_codes + return_discr_loss) <= 1 + assert video_or_images.ndim in {4, 5} + + assert video_or_images.shape[-2:] == (self.image_size, self.image_size) + + # accept images for image pretraining (curriculum learning from images to video) + + is_image = video_or_images.ndim == 4 + + if is_image: + video = rearrange(video_or_images, "b c ... 
-> b c 1 ...") + video_contains_first_frame = True + else: + video = video_or_images + + batch, channels, frames = video.shape[:3] + + assert divisible_by( + frames - int(video_contains_first_frame), self.time_downsample_factor + ), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}" + + # encoder + + x = self.encode(video, cond=cond, video_contains_first_frame=video_contains_first_frame) + + # lookup free quantization + + if self.use_fsq: + quantized, codes = self.quantizers(x) + + aux_losses = self.zero + quantizer_loss_breakdown = None + else: + (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True) + + if return_codes and not return_recon: + return codes + + # decoder + + recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame) + + if return_codes: + return codes, recon_video + + # reconstruction loss + + if not (return_loss or return_discr_loss or return_recon_loss_only): + return recon_video + + recon_loss = F.mse_loss(video, recon_video) + + # for validation, only return recon loss + + if return_recon_loss_only: + return recon_loss, recon_video + + # gan discriminator loss + + if return_discr_loss: + assert self.has_gan + assert exists(self.discr) + + # pick a random frame for image discriminator + + frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices + + real = pick_video_frame(video, frame_indices) + + if apply_gradient_penalty: + real = real.requires_grad_() + + fake = pick_video_frame(recon_video, frame_indices) + + real_logits = self.discr(real) + fake_logits = self.discr(fake.detach()) + + discr_loss = hinge_discr_loss(fake_logits, real_logits) + + # multiscale discriminators + + multiscale_discr_losses = [] + + if self.has_multiscale_discrs: + for discr in self.multiscale_discrs: + multiscale_real_logits = discr(video) + multiscale_fake_logits = discr(recon_video.detach()) + + multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits) + + multiscale_discr_losses.append(multiscale_discr_loss) + else: + multiscale_discr_losses.append(self.zero) + + # gradient penalty + + if apply_gradient_penalty: + gradient_penalty_loss = gradient_penalty(real, real_logits) + else: + gradient_penalty_loss = self.zero + + # total loss + + total_loss = ( + discr_loss + + gradient_penalty_loss * self.grad_penalty_loss_weight + + sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight + ) + + discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss) + + return total_loss, discr_loss_breakdown + + # perceptual loss + + if self.use_vgg: + frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices + + input_vgg_input = pick_video_frame(video, frame_indices) + recon_vgg_input = pick_video_frame(recon_video, frame_indices) + + if channels == 1: + input_vgg_input = repeat(input_vgg_input, "b 1 h w -> b c h w", c=3) + recon_vgg_input = repeat(recon_vgg_input, "b 1 h w -> b c h w", c=3) + + elif channels == 4: + input_vgg_input = input_vgg_input[:, :3] + recon_vgg_input = recon_vgg_input[:, :3] + + input_vgg_feats = self.vgg(input_vgg_input) + recon_vgg_feats = self.vgg(recon_vgg_input) + + perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats) + else: + perceptual_loss = self.zero + + # get gradient with respect to perceptual loss for last decoder layer + 
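# note: together with the GAN-loss gradient norm computed below, this implements the + # VQGAN-style adaptive weight, adaptive_weight = ||grad_w L_perceptual|| / ||grad_w L_gan||, + # taken w.r.t. the last decoder conv weight and clamped, so the adversarial term is rescaled + # to roughly match the perceptual loss before entering the total loss. +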
# needed for adaptive weighting + + last_dec_layer = self.conv_out.conv.weight + + norm_grad_wrt_perceptual_loss = None + + if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs): + norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2) + + # per-frame image discriminator + + recon_video_frames = None + + if self.has_gan: + frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices + recon_video_frames = pick_video_frame(recon_video, frame_indices) + + fake_logits = self.discr(recon_video_frames) + gen_loss = hinge_gen_loss(fake_logits) + + adaptive_weight = 1.0 + + if exists(norm_grad_wrt_perceptual_loss): + norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2) + adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3) + adaptive_weight.clamp_(max=1e3) + + if torch.isnan(adaptive_weight).any(): + adaptive_weight = 1.0 + else: + gen_loss = self.zero + adaptive_weight = 0.0 + + # multiscale discriminator losses + + multiscale_gen_losses = [] + multiscale_gen_adaptive_weights = [] + + if self.has_multiscale_gan and self.has_multiscale_discrs: + if not exists(recon_video_frames): + recon_video_frames = pick_video_frame(recon_video, frame_indices) + + for discr in self.multiscale_discrs: + fake_logits = discr(recon_video_frames) + multiscale_gen_loss = hinge_gen_loss(fake_logits) + + multiscale_gen_losses.append(multiscale_gen_loss) + + multiscale_adaptive_weight = 1.0 + + if exists(norm_grad_wrt_perceptual_loss): + norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2) + multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5) + multiscale_adaptive_weight.clamp_(max=1e3) + + multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight) + + # calculate total loss + + total_loss = ( + recon_loss + + aux_losses * self.quantizer_aux_loss_weight + + perceptual_loss * self.perceptual_loss_weight + + gen_loss * adaptive_weight * adversarial_loss_weight + ) + + if self.has_multiscale_discrs: + weighted_multiscale_gen_losses = sum( + loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights) + ) + + total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight + + # loss breakdown + + loss_breakdown = LossBreakdown( + recon_loss, + aux_losses, + quantizer_loss_breakdown, + perceptual_loss, + gen_loss, + adaptive_weight, + multiscale_gen_losses, + multiscale_gen_adaptive_weights, + ) + + return total_loss, loss_breakdown + + +# main class + + +class MagViT2(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6065fb209b6cb6fb4e0cb601c895c2a35e0044e9 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/__init__.py @@ -0,0 +1,30 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ....modules.distributions.distributions import DiagonalGaussianDistribution +from .base import AbstractRegularizer + + +class DiagonalGaussianRegularizer(AbstractRegularizer): + def __init__(self, sample: bool = True): + 
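# sample=True draws z from the predicted diagonal Gaussian posterior via the + # reparameterization trick; sample=False returns the posterior mode (its mean). + # forward() also reports the batch-mean KL to a standard normal as "kl_loss" in the log dict. +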
super().__init__() + self.sample = sample + + def get_trainable_parameters(self) -> Any: + yield from () + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + log = dict() + posterior = DiagonalGaussianDistribution(z) + if self.sample: + z = posterior.sample() + else: + z = posterior.mode() + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + log["kl_loss"] = kl_loss + return z, log diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/base.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7f455be98a6f1b5d8647b423de6c3aaeb24d3e23 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/base.py @@ -0,0 +1,36 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +class AbstractRegularizer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + raise NotImplementedError() + + @abstractmethod + def get_trainable_parameters(self) -> Any: + raise NotImplementedError() + + +class IdentityRegularizer(AbstractRegularizer): + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, dict() + + def get_trainable_parameters(self) -> Any: + yield from () + + +def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..5a20dd63ef18259ac4242438f2c1a393e5ef938d --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py @@ -0,0 +1,180 @@ +""" +Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 +Code adapted from Jax version in Appendix A.1 +""" + +from typing import List, Optional + +import torch +import torch.nn as nn +from torch.nn import Module +from torch import Tensor, int32 +from torch.cuda.amp import autocast + +from einops import rearrange, pack, unpack + +# helper functions + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# tensor helpers + + +def round_ste(z: Tensor) -> Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +# main class + + +class FSQ(Module): + def __init__( + self, + levels: List[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + ): 
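+ # each of the len(levels) latent dimensions is rounded to one of levels[i] values, so the + # implicit codebook has prod(levels) entries and no learned embedding table is needed; + # e.g. levels=[8, 5, 5, 5] gives an effective codebook of 8*5*5*5 = 1000 codes. + # _basis (registered below) holds the mixed-radix place values used to convert between + # per-dimension digits and flat integer indices in codes_to_indices / indices_to_codes.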
+ super().__init__() + _levels = torch.tensor(levels, dtype=int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) + + def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: Tensor) -> Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: Tensor) -> Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis).sum(dim=-1).to(int32) + + def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor: + """Inverse of `codes_to_indices`.""" + + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes + + @autocast(enabled=False) + def forward(self, z: Tensor) -> Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension + c - number of codebook dim + """ + + is_img_or_video = z.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + z = rearrange(z, "b d ... -> b ... 
d") + z, ps = pack_one(z, "b * d") + + assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if is_img_or_video: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + + indices = unpack_one(indices, ps, "b * c") + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 1 -> ...") + + return out, indices diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..beca8884d284bd62bca9e6b4bfd137b07674362e --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py @@ -0,0 +1,309 @@ +""" +Lookup Free Quantization +Proposed in https://arxiv.org/abs/2310.05737 + +In the simplest setup, each dimension is quantized into {-1, 1}. +An entropy penalty is used to encourage utilization. +""" + +from math import log2, ceil +from collections import namedtuple + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from torch.nn import Module +from torch.cuda.amp import autocast + +from einops import rearrange, reduce, pack, unpack + +# constants + +Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"]) + +LossBreakdown = namedtuple("LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"]) + +# helper functions + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg() if callable(arg) else arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# entropy + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) + + +# class + + +class LFQ(Module): + def __init__( + self, + *, + dim=None, + codebook_size=None, + entropy_loss_weight=0.1, + commitment_loss_weight=0.25, + diversity_gamma=1.0, + straight_through_activation=nn.Identity(), + num_codebooks=1, + keep_num_codebooks_dim=None, + codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer + frac_per_sample_entropy=1.0, # make less than 1. 
to only use a random fraction of the probs for per sample entropy + ): + super().__init__() + + # some assert validations + + assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ" + assert ( + not exists(codebook_size) or log2(codebook_size).is_integer() + ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" + + codebook_size = default(codebook_size, lambda: 2**dim) + codebook_dim = int(log2(codebook_size)) + + codebook_dims = codebook_dim * num_codebooks + dim = default(dim, codebook_dims) + + has_projections = dim != codebook_dims + self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity() + self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.dim = dim + self.codebook_dim = codebook_dim + self.num_codebooks = num_codebooks + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + # straight through activation + + self.activation = straight_through_activation + + # entropy aux loss related weights + + assert 0 < frac_per_sample_entropy <= 1.0 + self.frac_per_sample_entropy = frac_per_sample_entropy + + self.diversity_gamma = diversity_gamma + self.entropy_loss_weight = entropy_loss_weight + + # codebook scale + + self.codebook_scale = codebook_scale + + # commitment loss + + self.commitment_loss_weight = commitment_loss_weight + + # for no auxiliary loss, during inference + + self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1)) + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + # codes + + all_codes = torch.arange(codebook_size) + bits = ((all_codes[..., None].int() & self.mask) != 0).float() + codebook = self.bits_to_codes(bits) + + self.register_buffer("codebook", codebook, persistent=False) + + def bits_to_codes(self, bits): + return bits * self.codebook_scale * 2 - self.codebook_scale + + @property + def dtype(self): + return self.codebook.dtype + + def indices_to_codes(self, indices, project_out=True): + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... -> ... 1") + + # indices to codes, which are bits of either -1 or 1 + + bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype) + + codes = self.bits_to_codes(bits) + + codes = rearrange(codes, "... c d -> ... (c d)") + + # whether to project codes out to original dimensions + # if the input feature dimensions were not log2(codebook size) + + if project_out: + codes = self.project_out(codes) + + # rearrange codes back to original shape + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes + + @autocast(enabled=False) + def forward( + self, + x, + inv_temperature=100.0, + return_loss_breakdown=False, + mask=None, + ): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + + x = x.float() + + is_img_or_video = x.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + x = rearrange(x, "b d ... -> b ... 
d") + x, ps = pack_one(x, "b * d") + + assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}" + + x = self.project_in(x) + + # split out number of codebooks + + x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) + + # quantize by eq 3. + + original_input = x + + codebook_value = torch.ones_like(x) * self.codebook_scale + quantized = torch.where(x > 0, codebook_value, -codebook_value) + + # use straight-through gradients (optionally with custom activation fn) if training + + if self.training: + x = self.activation(x) + x = x + (quantized - x).detach() + else: + x = quantized + + # calculate indices + + indices = reduce((x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") + + # entropy aux loss + + if self.training: + # the same as euclidean distance up to a constant + distance = -2 * einsum("... i d, j d -> ... i j", original_input, self.codebook) + + prob = (-distance * inv_temperature).softmax(dim=-1) + + # account for mask + + if exists(mask): + prob = prob[mask] + else: + prob = rearrange(prob, "b n ... -> (b n) ...") + + # whether to only use a fraction of probs, for reducing memory + + if self.frac_per_sample_entropy < 1.0: + num_tokens = prob.shape[0] + num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy) + rand_mask = torch.randn(num_tokens).argsort(dim=-1) < num_sampled_tokens + per_sample_probs = prob[rand_mask] + else: + per_sample_probs = prob + + # calculate per sample entropy + + per_sample_entropy = entropy(per_sample_probs).mean() + + # distribution over all available tokens in the batch + + avg_prob = reduce(per_sample_probs, "... c d -> c d", "mean") + codebook_entropy = entropy(avg_prob).mean() + + # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions + # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch + + entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy + else: + # if not training, just return dummy 0 + entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero + + # commit loss + + if self.training: + commit_loss = F.mse_loss(original_input, quantized.detach(), reduction="none") + + if exists(mask): + commit_loss = commit_loss[mask] + + commit_loss = commit_loss.mean() + else: + commit_loss = self.zero + + # merge back codebook dim + + x = rearrange(x, "b n c d -> b n (c d)") + + # project out to feature dimension if needed + + x = self.project_out(x) + + # reconstitute image or video dimensions + + if is_img_or_video: + x = unpack_one(x, ps, "b * d") + x = rearrange(x, "b ... d -> b d ...") + + indices = unpack_one(indices, ps, "b * c") + + # whether to remove single codebook dim + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 
1 -> ...") + + # complete aux loss + + aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight + + ret = Return(x, indices, aux_loss) + + if not return_loss_breakdown: + return ret + + return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/quantize.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..583f488c25e2283352176f7443a3233b3d4f926f --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/regularizers/quantize.py @@ -0,0 +1,453 @@ +import logging +from abc import abstractmethod +from typing import Dict, Iterator, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import einsum + +from .base import AbstractRegularizer, measure_perplexity + +logpy = logging.getLogger(__name__) + + +class AbstractQuantizer(AbstractRegularizer): + def __init__(self): + super().__init__() + # Define these in your init + # shape (N,) + self.used: Optional[torch.Tensor] + self.re_embed: int + self.unknown_index: Union[Literal["random"], int] + + def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: + assert self.used is not None, "You need to define used indices for remap" + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: + assert self.used is not None, "You need to define used indices for remap" + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + @abstractmethod + def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: + raise NotImplementedError() + + def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: + yield from self.parameters() + + +class GumbelQuantizer(AbstractQuantizer): + """ + credit to @karpathy: + https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__( + self, + num_hiddens: int, + embedding_dim: int, + n_embed: int, + straight_through: bool = True, + kl_weight: float = 5e-4, + temp_init: float = 1.0, + remap: Optional[str] = None, + unknown_index: str = "random", + loss_key: str = "loss/vq", + ) -> None: + super().__init__() + + self.loss_key = loss_key + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_embed + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + + def forward( + self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False + ) -> Tuple[torch.Tensor, Dict]: + # force hard = True when we are in eval mode, as we must quantize. + # actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + out_dict = {} + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:, self.used, ...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:, self.used, ...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + out_dict[self.loss_key] = diff + + ind = soft_one_hot.argmax(dim=1) + out_dict["indices"] = ind + if self.remap is not None: + ind = self.remap_to_used(ind) + + if return_logits: + out_dict["logits"] = logits + + return z_q, out_dict + + def get_codebook_entry(self, indices, shape): + # TODO: shape not yet optional + b, h, w, c = shape + assert b * h * w == indices.shape[0] + indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer(AbstractQuantizer): + """ + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. 
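+ Each latent vector is assigned to its nearest codebook embedding; gradients are passed + through with the straight-through estimator, z_q = z + (z_q - z).detach(), and the encoder + is pulled toward the codebook by the beta-weighted commitment term.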
+ Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, + beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + def __init__( + self, + n_e: int, + e_dim: int, + beta: float = 0.25, + remap: Optional[str] = None, + unknown_index: str = "random", + sane_index_shape: bool = False, + log_perplexity: bool = False, + embedding_weight_norm: bool = False, + loss_key: str = "loss/vq", + ): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.loss_key = loss_key + + if not embedding_weight_norm: + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + else: + self.embedding = torch.nn.utils.weight_norm(nn.Embedding(self.n_e, self.e_dim), dim=1) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_e + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + + self.sane_index_shape = sane_index_shape + self.log_perplexity = log_perplexity + + def forward( + self, + z: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict]: + do_reshape = z.ndim == 4 + if do_reshape: + # # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + + else: + assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" + z = z.contiguous() + + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + loss_dict = {} + if self.log_perplexity: + perplexity, cluster_usage = measure_perplexity(min_encoding_indices.detach(), self.n_e) + loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) + + # compute loss for embedding + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + loss_dict[self.loss_key] = loss + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + if do_reshape: + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + if do_reshape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + else: + min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]) + + loss_dict["min_encoding_indices"] = min_encoding_indices + + return z_q, loss_dict + + def 
get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + assert shape is not None, "Need to give shape for remap" + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(AbstractQuantizer): + def __init__( + self, + n_embed: int, + embedding_dim: int, + beta: float, + decay: float = 0.99, + eps: float = 1e-5, + remap: Optional[str] = None, + unknown_index: str = "random", + loss_key: str = "loss/vq", + ): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.loss_key = loss_key + + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_embed + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + z = rearrange(z, "b c h w -> b h w c") + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + z_flattened.pow(2).sum(dim=1, keepdim=True) + + self.embedding.weight.pow(2).sum(dim=1) + - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) + ) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + # EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + # EMA embedding average + embed_sum = encodings.transpose(0, 1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + # normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, "b h w c -> b c h w") + + out_dict = { + self.loss_key: loss, + "encodings": encodings, + "encoding_indices": encoding_indices, + "perplexity": perplexity, + } + + return z_q, out_dict + + +class VectorQuantizerWithInputProjection(VectorQuantizer): + def __init__( + self, + input_dim: int, + n_codes: int, + codebook_dim: int, + beta: float = 1.0, + output_dim: Optional[int] = None, + **kwargs, + ): + super().__init__(n_codes, codebook_dim, beta, **kwargs) + self.proj_in = nn.Linear(input_dim, codebook_dim) + self.output_dim = output_dim + if output_dim is not None: + self.proj_out = nn.Linear(codebook_dim, output_dim) + else: + self.proj_out = nn.Identity() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + rearr = False + in_shape = z.shape + + if z.ndim > 3: + rearr = self.output_dim is not None + z = rearrange(z, "b c ... -> b (...) 
c") + z = self.proj_in(z) + z_q, loss_dict = super().forward(z) + + z_q = self.proj_out(z_q) + if rearr: + if len(in_shape) == 4: + z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) + elif len(in_shape) == 5: + z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]) + else: + raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.") + + return z_q, loss_dict diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/temporal_ae.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/temporal_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..a45ef9d62efc30a0abd6d6e730254d6439ff419b --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/temporal_ae.py @@ -0,0 +1,331 @@ +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +from sgm.modules.diffusionmodules.model import ( + XFORMERS_IS_AVAILABLE, + AttnBlock, + Decoder, + MemoryEfficientAttnBlock, + ResnetBlock, +) +from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding +from sgm.modules.video_attention import VideoTransformerBlock +from sgm.util import partialclass + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + if timesteps is None: + timesteps = self.timesteps + + b, c, h, w = x.shape + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps, skip_video=False): + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", 
t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class VideoBlock(AttnBlock): + def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax", + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_video=False): + if skip_video: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): + def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax-xformers", + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_time_block=False): + if skip_time_block: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = 
self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + ], f"attn_type {attn_type} not supported for spatio-temporal attention" + print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels") + if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": + print( + f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " + f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" + ) + attn_type = "vanilla" + + if attn_type == "vanilla": + assert attn_kwargs is None + return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return partialclass( + MemoryEfficientVideoBlock, + in_channels, + alpha=alpha, + merge_strategy=merge_strategy, + ) + else: + return NotImplementedError() + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight + + def _make_attn(self) -> Callable: + if self.time_mode not in ["conv-only", "only-last-conv"]: + return partialclass( + make_time_attn, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_attn() + + def _make_conv(self) -> Callable: + if self.time_mode != "attn-only": + return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + else: + return Conv2DWrapper + + def _make_resblock(self) -> Callable: + if self.time_mode not in ["attn-only", "only-last-conv"]: + return partialclass( + VideoResBlock, + video_kernel_size=self.video_kernel_size, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_resblock() diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py new file mode 
100644 index 0000000000000000000000000000000000000000..0f9a46988703904f1e3a0b5f8f28f33cce4537bd --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py @@ -0,0 +1,495 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from .movq_enc_3d import CausalConv3d, Upsample3D, DownSample3D + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +class SpatialNorm3D(nn.Module): + def __init__( + self, + f_channels, + zq_channels, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=False, + pad_mode="constant", + **norm_layer_params, + ): + super().__init__() + self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + self.add_conv = add_conv + if self.add_conv: + self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode) + + def forward(self, f, zq): + if zq.shape[2] > 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] + zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest") + zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest") + zq = torch.cat([zq_first, zq_rest], dim=2) + else: + zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") + if self.add_conv: + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +def Normalize3D(in_channels, zq_ch, add_conv): + return SpatialNorm3D( + in_channels, + zq_ch, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, + affine=True, + ) + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + add_conv=False, + pad_mode="constant", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv) + 
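# Normalize3D wraps SpatialNorm3D: the GroupNorm output is modulated per position by + # conv_y(zq) and conv_b(zq) computed from the quantized latent zq (MoVQ-style conditional + # normalization). CausalConv3d is imported from movq_enc_3d (not shown in this hunk); by its + # name and usage it is assumed to pad only on the temporal-past side, so frame t does not + # depend on later frames. +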
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + else: + self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb, zq): + h = x + h = self.norm1(h, zq) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] + + h = self.norm2(h, zq) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock2D(nn.Module): + def __init__(self, in_channels, zq_ch=None, add_conv=False): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, zq): + h_ = x + h_ = self.norm(h_, zq) + + t = h_.shape[2] + h_ = rearrange(h_, "b c t h w -> (b t) c h w") + + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t) + + return x + h_ + + +class MOVQDecoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock3D( 
+ in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + + self.mid.block_2 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv) + self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode) + + def forward(self, z, use_cp=False): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + t = z.shape[2] + # z to block_in + + zq = z + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, zq) + # h = self.mid.attn_1(h, zq) + h = self.mid.block_2(h, temb, zq) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.conv.weight + + +class NewDecoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + post_quant_conv=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + if post_quant_conv: + self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode) + else: + self.post_quant_conv = None + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + # self.conv_in 
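# --- Illustrative sketch, not from the original patch -----------------------
# CausalConv3d (imported from movq_enc_3d.py) replaces the plain Conv3d kept
# in the commented-out lines around this point: height and width are padded
# symmetrically, but time is padded only on the "past" side, so the output at
# frame t never depends on frames after t. With pad_mode="first" (the default
# these decoders pass down), the temporal padding repeats the first frame,
# roughly as in the "first" branch of CausalConv3d.forward:
#
#     # kernel_size=3, stride=1, dilation=1  ->  time_pad = 2, h/w pad = 1
#     pad_x = torch.cat([x[:, :, :1]] * time_pad, dim=2)    # repeat frame 0
#     x = torch.cat([pad_x, x], dim=2)                      # pad the past only
#     x = F.pad(x, (1, 1, 1, 1), mode="constant", value=0)  # zero-pad H and W
#     out = conv3d(x)                                       # same T, H, W out
# ----------------------------------------------------------------------------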
= torch.nn.Conv3d(z_channels, + # block_in, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + # remove attention block + # self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv) + self.mid.block_2 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv) + # self.conv_out = torch.nn.Conv3d(block_in, + # out_ch, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + t = z.shape[2] + # z to block_in + + zq = z + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, zq) + # h = self.mid.attn_1(h, zq) + h = self.mid.block_2(h, temb, zq) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.conv.weight diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py new file mode 100644 index 0000000000000000000000000000000000000000..1b9a663b9c0bfdc364d839285c2cb314661a6c4c --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py @@ -0,0 +1,535 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from beartype import beartype +from beartype.typing import Union, Tuple, Optional, List +from einops import rearrange + +from .movq_enc_3d import CausalConv3d, Upsample3D, 
DownSample3D + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +class SpatialNorm3D(nn.Module): + def __init__( + self, + f_channels, + zq_channels, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=False, + pad_mode="constant", + **norm_layer_params, + ): + super().__init__() + self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + self.add_conv = add_conv + if self.add_conv: + # self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1) + self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode) + # self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + # self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode) + + def forward(self, f, zq): + if zq.shape[2] > 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] + zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest") + zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest") + zq = torch.cat([zq_first, zq_rest], dim=2) + else: + zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") + if self.add_conv: + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +def Normalize3D(in_channels, zq_ch, add_conv): + return SpatialNorm3D( + in_channels, + zq_ch, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, + affine=True, + ) + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + add_conv=False, + pad_mode="constant", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv) + # self.conv1 = torch.nn.Conv3d(in_channels, + # out_channels, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv1 = CausalConv3d(in_channels, 
out_channels, kernel_size=3, pad_mode=pad_mode) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv) + self.dropout = torch.nn.Dropout(dropout) + # self.conv2 = torch.nn.Conv3d(out_channels, + # out_channels, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + # self.conv_shortcut = torch.nn.Conv3d(in_channels, + # out_channels, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + else: + self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode) + + def forward(self, x, temb, zq): + h = x + h = self.norm1(h, zq) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] + + h = self.norm2(h, zq) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock2D(nn.Module): + def __init__(self, in_channels, zq_ch=None, add_conv=False): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, zq): + h_ = x + h_ = self.norm(h_, zq) + + t = h_.shape[2] + h_ = rearrange(h_, "b c t h w -> (b t) c h w") + + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t) + + return x + h_ + + +class MOVQDecoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + + # compute in_ch_mult, block_in 
and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + # self.conv_in = torch.nn.Conv3d(z_channels, + # block_in, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + # remove attention block + # self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv) + self.mid.block_2 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv) + # self.conv_out = torch.nn.Conv3d(block_in, + # out_ch, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode) + + def forward(self, z, use_cp=False): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + t = z.shape[2] + # z to block_in + + zq = z + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, zq) + # h = self.mid.attn_1(h, zq) + h = self.mid.block_2(h, temb, zq) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.conv.weight + + +class NewDecoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + post_quant_conv=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + 
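# --- Illustrative sketch, not from the original patch -----------------------
# How the constructor arguments below fix the decoder geometry, using
# hypothetical values ch=128, ch_mult=(1, 2, 2, 4), resolution=256,
# temporal_compress_times=4:
#
#     num_resolutions         = len(ch_mult)           # 4
#     block_in at lowest res  = ch * ch_mult[-1]        # 128 * 4 = 512
#     curr_res at lowest res  = resolution // 2 ** 3    # 256 // 8 = 32
#     temporal_compress_level = int(np.log2(4))         # 2
#
# Every level except level 0 gets an Upsample3D, so the spatial side goes
# 32 -> 64 -> 128 -> 256, while only levels with
# i_level >= num_resolutions - temporal_compress_level use compress_time=True,
# i.e. two of the three upsamples also double the frame count (4x in time).
# ----------------------------------------------------------------------------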
self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + if post_quant_conv: + self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode) + else: + self.post_quant_conv = None + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + # self.conv_in = torch.nn.Conv3d(z_channels, + # block_in, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + # remove attention block + # self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv) + self.mid.block_2 = ResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + pad_mode=pad_mode, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv) + # self.conv_out = torch.nn.Conv3d(block_in, + # out_ch, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + t = z.shape[2] + # z to block_in + + zq = z + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, zq) + # h = self.mid.attn_1(h, zq) + h = self.mid.block_2(h, temb, zq) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq) + h 
= nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.conv.weight diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..99b358df877c6f59f9ad62c1bb168339042f1900 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py @@ -0,0 +1,413 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from beartype import beartype +from beartype.typing import Union, Tuple, Optional, List +from einops import rearrange + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +class CausalConv3d(nn.Module): + @beartype + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + if self.pad_mode == "constant": + causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_3d, mode="constant", value=0) + elif self.pad_mode == "first": + pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) + x = torch.cat([pad_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + elif self.pad_mode == "reflect": + # reflect padding + reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) + if reflect_x.shape[2] < self.time_pad: + reflect_x = torch.cat( + [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - 
reflect_x.shape[2]) + [reflect_x], dim=2 + ) + x = torch.cat([reflect_x, x], dim=2) + causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + x = F.pad(x, causal_padding_2d, mode="constant", value=0) + else: + raise ValueError("Invalid pad mode") + return self.conv(x) + + +def Normalize3D(in_channels): # same for 3D and 2D + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample3D(nn.Module): + def __init__(self, in_channels, with_conv, compress_time=False): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1: + # split first frame + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") + x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + else: + x = x.squeeze(2) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = x[:, :, None, :, :] + else: + # only interpolate 2D + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.with_conv: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class DownSample3D(nn.Module): + def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None): + super().__init__() + self.with_conv = with_conv + if out_channels is None: + out_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + h, w = x.shape[-2:] + x = rearrange(x, "b c t h w -> (b h w) c t") + + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] + + if x_rest.shape[-1] > 0: + x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + else: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class ResnetBlock3D(nn.Module): + def __init__( + self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant" + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize3D(in_channels) + # self.conv1 = torch.nn.Conv3d(in_channels, + # out_channels, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + if 
temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize3D(out_channels) + self.dropout = torch.nn.Dropout(dropout) + # self.conv2 = torch.nn.Conv3d(out_channels, + # out_channels, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + # self.conv_shortcut = torch.nn.Conv3d(in_channels, + # out_channels, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + else: + self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock2D(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize3D(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + + t = h_.shape[2] + h_ = rearrange(h_, "b c t h w -> (b t) c h w") + + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + + # # original version, nan in fp16 + # w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + # w_ = w_ * (int(c)**(-0.5)) + # # implement c**-0.5 on q + q = q * (int(c) ** (-0.5)) + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t) + + return x + h_ + + +class Encoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + pad_mode="first", + temporal_compress_times=4, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + # downsampling + # self.conv_in = torch.nn.Conv3d(in_channels, + # self.ch, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_in 
= CausalConv3d(in_channels, self.ch, kernel_size=3, pad_mode=pad_mode) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + pad_mode=pad_mode, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock2D(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level < self.temporal_compress_level: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) + else: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock3D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode + ) + # remove attention block + # self.mid.attn_1 = AttnBlock2D(block_in) + self.mid.block_2 = ResnetBlock3D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode + ) + + # end + self.norm_out = Normalize3D(block_in) + # self.conv_out = torch.nn.Conv3d(block_in, + # 2*z_channels if double_z else z_channels, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_out = CausalConv3d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, pad_mode=pad_mode + ) + + def forward(self, x, use_cp=False): + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + # h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_modules.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2773b0f2a67a6b4a68579c38962a6852e789e209 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/movq_modules.py @@ -0,0 +1,368 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +class SpatialNorm(nn.Module): + def __init__( + self, + f_channels, + zq_channels, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=False, + **norm_layer_params, + ): + super().__init__() + self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + self.add_conv = add_conv + if self.add_conv: + self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f, zq): + f_size = f.shape[-2:] + zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest") + if self.add_conv: + zq = self.conv(zq) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +def Normalize(in_channels, zq_ch, add_conv): + return SpatialNorm( + in_channels, + zq_ch, + norm_layer=nn.GroupNorm, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, + affine=True, + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + add_conv=False, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = 
torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb, zq): + h = x + h = self.norm1(h, zq) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h, zq) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, zq_ch=None, add_conv=False): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, zq): + h_ = x + h_ = self.norm(h_, zq) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class MOVQDecoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + ) + self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in 
range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z, zq): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, zq) + h = self.mid.attn_1(h, zq) + h = self.mid.block_2(h, temb, zq) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def forward_with_features_output(self, z, zq): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + output_features = {} + + # z to block_in + h = self.conv_in(z) + output_features["conv_in"] = h + + # middle + h = self.mid.block_1(h, temb, zq) + output_features["mid_block_1"] = h + h = self.mid.attn_1(h, zq) + output_features["mid_attn_1"] = h + h = self.mid.block_2(h, temb, zq) + output_features["mid_block_2"] = h + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq) + output_features[f"up_{i_level}_block_{i_block}"] = h + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + output_features[f"up_{i_level}_attn_{i_block}"] = h + if i_level != 0: + h = self.up[i_level].upsample(h) + output_features[f"up_{i_level}_upsample"] = h + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq) + output_features["norm_out"] = h + h = nonlinearity(h) + output_features["nonlinearity"] = h + h = self.conv_out(h) + output_features["conv_out"] = h + + return h, output_features diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/quantize.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..54ea128fc3279c04048b2c89b7d33b551ea4db58 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/quantize.py @@ -0,0 +1,241 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. 
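# --- Illustrative sketch, not from the original patch -----------------------
# Concretely, writing sg[.] for stop-gradient (".detach()" in forward below):
#
#     legacy=True  (default): loss = mean((sg[z_q] - z)**2) + beta * mean((z_q - sg[z])**2)
#     legacy=False (fixed):   loss = beta * mean((sg[z_q] - z)**2) + mean((z_q - sg[z])**2)
#
# (sg[z_q] - z)**2 only updates the encoder (the commitment term) and
# (z_q - sg[z])**2 only updates the codebook entries; the standard VQ-VAE loss
# weights the commitment term by beta, which legacy=False restores.
# ----------------------------------------------------------------------------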
for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits == False, "Only for interface compatible with Gumbel" + assert return_logits == False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, 
(perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__( + self, + num_hiddens, + embedding_dim, + n_embed, + straight_through=True, + kl_weight=5e-4, + temp_init=1.0, + use_vqinterface=True, + remap=None, + unknown_index="random", + ): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:, self.used, ...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:, self.used, ...] 
= soft_one_hot + soft_one_hot = full_zeros + z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b * h * w == indices.shape[0] + indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) + return z_q diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e42154f1c9bcacdd8b2cd80452ba9104b0352fcc --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py @@ -0,0 +1,402 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def 
forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + + # # original version, nan in fp16 + # w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + # w_ = w_ * (int(c)**(-0.5)) + # # implement c**-0.5 on q + q = q * (int(c) ** (-0.5)) + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def forward_with_features_output(self, x): + # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], 
x.shape[3], self.resolution) + + # timestep embedding + temb = None + output_features = {} + + # downsampling + hs = [self.conv_in(x)] + output_features["conv_in"] = hs[-1] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + output_features["down{}_block{}".format(i_level, i_block)] = h + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + output_features["down{}_attn{}".format(i_level, i_block)] = h + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + output_features["down{}_downsample".format(i_level)] = hs[-1] + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + output_features["mid_block_1"] = h + h = self.mid.attn_1(h) + output_features["mid_attn_1"] = h + h = self.mid.block_2(h, temb) + output_features["mid_block_2"] = h + + # end + h = self.norm_out(h) + output_features["norm_out"] = h + h = nonlinearity(h) + output_features["nonlinearity"] = h + h = self.conv_out(h) + output_features["conv_out"] = h + + return h, output_features + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + 
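+ # note: self.up was built by prepending (insert(0, up)), so iterating
+ # reversed(range(self.num_resolutions)) walks from the lowest resolution
+ # back up to the full output resolution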
for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/cp_enc_dec.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/cp_enc_dec.py new file mode 100644 index 0000000000000000000000000000000000000000..931e657bd34edd16a0c19ab26a2ac02627a622a4 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/cp_enc_dec.py @@ -0,0 +1,181 @@ +import math +import torch +import torch.distributed +import torch.nn as nn +from ..util import ( + get_context_parallel_group, + get_context_parallel_rank, + get_context_parallel_world_size, +) + +_USE_CP = True + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def exists(v): + return v is not None + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def leaky_relu(p=0.1): + return nn.LeakyReLU(p) + + +def _split(input_, dim): + cp_world_size = get_context_parallel_world_size() + + if cp_world_size == 1: + return input_ + + cp_rank = get_context_parallel_rank() + + # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) + + inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + dim_size = input_.size()[dim] // cp_world_size + + input_list = torch.split(input_, dim_size, dim=dim) + output = input_list[cp_rank] + + if cp_rank == 0: + output = torch.cat([inpu_first_frame_, output], dim=dim) + output = output.contiguous() + + # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) + + return output + + +def _gather(input_, dim): + cp_world_size = get_context_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + + # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) + + input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + + tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in 
range(cp_world_size - 1) + ] + + if cp_rank == 0: + input_ = torch.cat([input_first_frame_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + output = torch.cat(tensor_list, dim=dim).contiguous() + + # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape) + + return output + + +def _conv_split(input_, dim, kernel_size): + cp_world_size = get_context_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) + + cp_rank = get_context_parallel_rank() + + dim_size = (input_.size()[dim] - kernel_size) // cp_world_size + + if cp_rank == 0: + output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) + else: + output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose( + dim, 0 + ) + output = output.contiguous() + + # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) + + return output + + +def _conv_gather(input_, dim, kernel_size): + cp_world_size = get_context_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + + # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) + + input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() + else: + input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous() + + tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] + if cp_rank == 0: + input_ = torch.cat([input_first_kernel_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat already creates a contiguous tensor. 
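+ # rank 0's gathered slot also carries the leading `kernel_size` frames that were
+ # kept out of the split, so concatenating along `dim` restores the pre-split layout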
+ output = torch.cat(tensor_list, dim=dim).contiguous() + + # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) + + return output diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fccebf954f5760fa559b17755e743c41daa1a824 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/__init__.py @@ -0,0 +1,6 @@ +from .denoiser import Denoiser +from .discretizer import Discretization +from .model import Decoder, Encoder, Model +from .openaimodel import UNetModel +from .sampling import BaseDiffusionSampler +from .wrappers import OpenAIWrapper diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc01e36f86117183ba8f6c5ee74f4c4cd579aed --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser.py @@ -0,0 +1,72 @@ +from typing import Dict, Union + +import torch +import torch.nn as nn + +from ...util import append_dims, instantiate_from_config + + +class Denoiser(nn.Module): + def __init__(self, weighting_config, scaling_config): + super().__init__() + + self.weighting = instantiate_from_config(weighting_config) + self.scaling = instantiate_from_config(scaling_config) + + def possibly_quantize_sigma(self, sigma): + return sigma + + def possibly_quantize_c_noise(self, c_noise): + return c_noise + + def w(self, sigma): + return self.weighting(sigma) + + def forward( + self, + network: nn.Module, + input: torch.Tensor, + sigma: torch.Tensor, + cond: Dict, + **additional_model_inputs, + ) -> torch.Tensor: + sigma = self.possibly_quantize_sigma(sigma) + sigma_shape = sigma.shape + sigma = append_dims(sigma, input.ndim) + c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs) + c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) + return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip + + +class DiscreteDenoiser(Denoiser): + def __init__( + self, + weighting_config, + scaling_config, + num_idx, + discretization_config, + do_append_zero=False, + quantize_c_noise=True, + flip=True, + ): + super().__init__(weighting_config, scaling_config) + sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + self.sigmas = sigmas + # self.register_buffer("sigmas", sigmas) + self.quantize_c_noise = quantize_c_noise + + def sigma_to_idx(self, sigma): + dists = sigma - self.sigmas.to(sigma.device)[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def idx_to_sigma(self, idx): + return self.sigmas.to(idx.device)[idx] + + def possibly_quantize_sigma(self, sigma): + return self.idx_to_sigma(self.sigma_to_idx(sigma)) + + def possibly_quantize_c_noise(self, c_noise): + if self.quantize_c_noise: + return self.sigma_to_idx(c_noise) + else: + return c_noise diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser_scaling.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb9643014435d908dbbd30c30c02c632373846b --- /dev/null +++ 
b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from typing import Any, Tuple + +import torch + + +class DenoiserScaling(ABC): + @abstractmethod + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + pass + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class EpsScaling: + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = torch.ones_like(sigma, device=sigma.device) + c_out = -sigma + c_in = 1 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScaling: + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScalingWithEDMcNoise(DenoiserScaling): + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class VideoScaling: # similar to VScaling + def __call__( + self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = alphas_cumprod_sqrt + c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5) + c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device) + c_noise = additional_model_inputs["idx"].clone() + return c_skip, c_out, c_in, c_noise diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser_weighting.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/denoiser_weighting.py @@ -0,0 +1,24 @@ +import torch + + +class UnitWeighting: + def __call__(self, sigma): + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting: + def __init__(self, sigma_data=0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma): + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting: + def __call__(self, sigma): + return sigma**-2.0 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/discretizer.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/discretizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a86b7d8dfb06aafef388ba28b369c611148ca300 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/discretizer.py 
@@ -0,0 +1,126 @@ +from abc import abstractmethod +from functools import partial + +import numpy as np +import torch + +from ...modules.diffusionmodules.util import make_beta_schedule +from ...util import append_zero + + +def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray: + return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] + + +class Discretization: + def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False): + if return_idx: + sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx) + else: + sigmas = self.get_sigmas(n, device=device, return_idx=return_idx) + sigmas = append_zero(sigmas) if do_append_zero else sigmas + if return_idx: + return sigmas if not flip else torch.flip(sigmas, (0,)), idx + else: + return sigmas if not flip else torch.flip(sigmas, (0,)) + + @abstractmethod + def get_sigmas(self, n, device): + pass + + +class EDMDiscretization(Discretization): + def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def get_sigmas(self, n, device="cpu"): + ramp = torch.linspace(0, 1, n, device=device) + min_inv_rho = self.sigma_min ** (1 / self.rho) + max_inv_rho = self.sigma_max ** (1 / self.rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho + return sigmas + + +class LegacyDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + ): + super().__init__() + self.num_timesteps = num_timesteps + betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + def get_sigmas(self, n, device="cpu"): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029 + + +class ZeroSNRDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) + keep_start=False, + post_shift=False, + ): + super().__init__() + if keep_start and not post_shift: + linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start) + self.num_timesteps = num_timesteps + betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + # SNR shift + if not post_shift: + self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod) + + self.post_shift = post_shift + self.shift_scale = shift_scale + + def get_sigmas(self, n, device="cpu", return_idx=False): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + 
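+ # requesting more steps than the training schedule provides is not supported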
raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + alphas_cumprod = to_torch(alphas_cumprod) + alphas_cumprod_sqrt = alphas_cumprod.sqrt() + alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone() + alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone() + + alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T + alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) + + if self.post_shift: + alphas_cumprod_sqrt = ( + alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2) + ) ** 0.5 + + if return_idx: + return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps + else: + return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/guiders.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce657c39258681fdadbec874ae2a7b6e26b8294 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/guiders.py @@ -0,0 +1,87 @@ +import logging +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Union +from functools import partial +import math + +import torch +from einops import rearrange, repeat + +from ...util import append_dims, default, instantiate_from_config + + +class Guider(ABC): + @abstractmethod + def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: + pass + + def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]: + pass + + +class VanillaCFG: + """ + implements parallelized CFG + """ + + def __init__(self, scale, dyn_thresh_config=None): + self.scale = scale + scale_schedule = lambda scale, sigma: scale # independent of step + self.scale_schedule = partial(scale_schedule, scale) + self.dyn_thresh = instantiate_from_config( + default( + dyn_thresh_config, + {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, + ) + ) + + def __call__(self, x, sigma, scale=None): + x_u, x_c = x.chunk(2) + scale_value = default(scale, self.scale_schedule(sigma)) + x_pred = self.dyn_thresh(x_u, x_c, scale_value) + return x_pred + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class DynamicCFG(VanillaCFG): + def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): + super().__init__(scale, dyn_thresh_config) + scale_schedule = ( + lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 + ) + self.scale_schedule = partial(scale_schedule, scale) + self.dyn_thresh = instantiate_from_config( + default( + dyn_thresh_config, + {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, + ) + ) + + def __call__(self, x, sigma, step_index, scale=None): + x_u, x_c = x.chunk(2) + scale_value = self.scale_schedule(sigma, step_index.item()) + x_pred = self.dyn_thresh(x_u, x_c, scale_value) + return x_pred + + +class IdentityGuider: + def __call__(self, x, sigma): + return x + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + c_out[k] = c[k] + + return x, s, c_out diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/lora.py 
b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..7ccd72a19f615162832d7c5d1c215dd7921833c7 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/lora.py @@ -0,0 +1,362 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import nn + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + self.out_features = out_features + self.in_features = in_features + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRAConv2dLayer(nn.Module): + def __init__( + self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + ): + super().__init__() + + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + # according to the official kohya_ss trainer kernel_size are always fixed for the up layer + # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 + self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) + + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRACompatibleConv(nn.Conv2d): + """ + A convolutional layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + self.scale = scale + + def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale=1.0): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) + fusion = fusion.reshape((w_orig.shape)) + fused_weight = w_orig + (lora_scale * fusion) + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.data.dtype, fused_weight.data.device + + self.w_up = self.w_up.to(device=device).float() + self.w_down = self.w_down.to(device).float() + + fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) + fusion = fusion.reshape((fused_weight.shape)) + unfused_weight = fused_weight.float() - (self._lora_scale * fusion) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states, scale: float = None): + if scale is None: + scale = self.scale + if self.lora_layer is None: + # make sure to the functional Conv2D function as otherwise torch.compile's graph will break + # see: https://github.com/huggingface/diffusers/pull/4315 + return F.conv2d( + hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + else: + return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + + +class LoRACompatibleLinear(nn.Linear): + """ + A Linear layer that can be used with LoRA. 
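+ When a `lora_layer` is attached, `forward` returns the base linear output plus
+ `scale * lora_layer(hidden_states)`; `_fuse_lora` folds that low-rank update into
+ the base weight so the extra matmul can be skipped afterwards.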
+ """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + self.scale = scale + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale=1.0): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states, scale: float = None): + if scale is None: + scale = self.scale + if self.lora_layer is None: + out = super().forward(hidden_states) + return out + else: + out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + return out + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoRACompatibleLinear, + LoRACompatibleConv, + LoRALinearLayer, + LoRAConv2dLayer, + ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class) + else: + # this, incase you want to naively iterate over all modules. 
+ ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + flag = False + while path: + try: + parent = parent.get_submodule(path.pop(0)) + except: + flag = True + break + if flag: + continue + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]): + continue + # Otherwise, yield it + yield parent, name, module + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = None, + rank: int = 4, + scale: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + lora_layer = LoRALinearLayer( + in_features=_child_module.in_features, + out_features=_child_module.out_features, + rank=rank, + ) + _tmp = ( + LoRACompatibleLinear( + _child_module.in_features, + _child_module.out_features, + lora_layer=lora_layer, + scale=scale, + ) + .to(weight.dtype) + .to(weight.device) + ) + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + lora_layer = LoRAConv2dLayer( + in_features=_child_module.in_channels, + out_features=_child_module.out_channels, + rank=rank, + kernel_size=_child_module.kernel_size, + stride=_child_module.stride, + padding=_child_module.padding, + ) + _tmp = ( + LoRACompatibleConv( + _child_module.in_channels, + _child_module.out_channels, + kernel_size=_child_module.kernel_size, + stride=_child_module.stride, + padding=_child_module.padding, + lora_layer=lora_layer, + scale=scale, + ) + .to(weight.dtype) + .to(weight.device) + ) + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + else: + continue + + _module._modules[name] = _tmp + # print('injecting lora layer to', _module, name) + + return + + +def update_lora_scale( + model: nn.Module, + target_module: Set[str] = None, + scale: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv] + ): + _child_module.scale = scale + + return diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/loss.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..66916c1b44a1152615d074aae83dc7623d8e725e --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/loss.py @@ -0,0 +1,120 @@ +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from omegaconf import ListConfig +from ...util import append_dims, instantiate_from_config +from ...modules.autoencoding.lpips.loss.lpips import LPIPS +from sat import mpu + + +class StandardDiffusionLoss(nn.Module): + def __init__( + self, + sigma_sampler_config, + type="l2", + offset_noise_level=0.0, + batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None, + ): + super().__init__() + 
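+ # `type` selects the reconstruction distance used in get_loss:
+ # pixel-wise "l2" / "l1" or the perceptual "lpips" metric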
+ assert type in ["l2", "l1", "lpips"] + + self.sigma_sampler = instantiate_from_config(sigma_sampler_config) + + self.type = type + self.offset_noise_level = offset_noise_level + + if type == "lpips": + self.lpips = LPIPS().eval() + + if not batch2model_keys: + batch2model_keys = [] + + if isinstance(batch2model_keys, str): + batch2model_keys = [batch2model_keys] + + self.batch2model_keys = set(batch2model_keys) + + def __call__(self, network, denoiser, conditioner, input, batch): + cond = conditioner(batch) + additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} + + sigmas = self.sigma_sampler(input.shape[0]).to(input.device) + noise = torch.randn_like(input) + if self.offset_noise_level > 0.0: + noise = ( + noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level + ) + noise = noise.to(input.dtype) + noised_input = input.float() + noise * append_dims(sigmas, input.ndim) + model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs) + w = append_dims(denoiser.w(sigmas), input.ndim) + return self.get_loss(model_output, input, w) + + def get_loss(self, model_output, target, w): + if self.type == "l2": + return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1) + elif self.type == "l1": + return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1) + elif self.type == "lpips": + loss = self.lpips(model_output, target).reshape(-1) + return loss + + +class VideoDiffusionLoss(StandardDiffusionLoss): + def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs): + self.fixed_frames = fixed_frames + self.block_scale = block_scale + self.block_size = block_size + self.min_snr_value = min_snr_value + super().__init__(**kwargs) + + def __call__(self, network, denoiser, conditioner, input, batch): + cond = conditioner(batch) + additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} + + alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True) + alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device) + idx = idx.to(input.device) + + noise = torch.randn_like(input) + + # broadcast noise + mp_size = mpu.get_model_parallel_world_size() + global_rank = torch.distributed.get_rank() // mp_size + src = global_rank * mp_size + torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group()) + torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group()) + torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group()) + + additional_model_inputs["idx"] = idx + + if self.offset_noise_level > 0.0: + noise = ( + noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level + ) + + noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims( + (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim + ) + + if "concat_images" in batch.keys(): + cond["concat"] = batch["concat_images"] + + # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx']) + model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs) + w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred + + if self.min_snr_value is not None: + w = min(w, self.min_snr_value) + return self.get_loss(model_output, input, w) + + def get_loss(self, model_output, target, w): + 
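+ # same distances as StandardDiffusionLoss.get_loss; here `w` is the
+ # v-prediction weight 1 / (1 - alphas_cumprod) computed in __call__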
if self.type == "l2": + return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1) + elif self.type == "l1": + return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1) + elif self.type == "lpips": + loss = self.lpips(model_output, target).reshape(-1) + return loss diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/model.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..466f01ac967bcc6d240d1eda06b308b46ac07bce --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/model.py @@ -0,0 +1,683 @@ +# pytorch_diffusion + derived encoder decoder +import math +from typing import Any, Callable, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from packaging import version + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + print("no module 'xformers'. Processing without...") + +from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = 
torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + # compute attention + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.attention_op: Optional[Any] = None + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, 
v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) + return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None, **unused_kwargs): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + if version.parse(torch.__version__) < version.parse("2.0.0") and attn_type != "none": + assert XFORMERS_IS_AVAILABLE, ( + f"We do not support vanilla attention in {torch.__version__} anymore, " + f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" + ) + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + 
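+ # bottleneck at the lowest resolution: ResnetBlock -> attention -> ResnetBlock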
self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in 
range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + make_attn_cls = self._make_attn() + make_resblock_cls = self._make_resblock() + make_conv_cls = self._make_conv() + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) + self.mid.block_2 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in 
reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + make_resblock_cls( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn_cls(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def _make_attn(self) -> Callable: + return make_attn + + def _make_resblock(self) -> Callable: + return ResnetBlock + + def _make_conv(self) -> Callable: + return torch.nn.Conv2d + + def get_last_layer(self, **kwargs): + return self.conv_out.weight + + def forward(self, z, **kwargs): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, **kwargs) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, **kwargs) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + return h diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/openaimodel.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0b83cb3fdee7e342e03f6e5ce36ba26ef1cbe4 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1248 @@ +import os +import math +from abc import abstractmethod +from functools import partial +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ...modules.attention import SpatialTransformer +from ...modules.diffusionmodules.util import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from ...modules.diffusionmodules.lora import inject_trainable_lora_extended, update_lora_scale +from ...modules.video_attention import SpatialVideoTransformer +from ...util import default, exists + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + 
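+        # The forward pass below flattens the feature map to NC(HW), prepends the
+        # spatial mean as an extra token, adds the learned positional embedding,
+        # applies single QKV attention, and returns the first (pooled) token, so the
+        # module maps (N, C, H, W) -> (N, output_dim or embed_dim).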
self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + context: Optional[th.Tensor] = None, + image_only_indicator: Optional[th.Tensor] = None, + time_context: Optional[int] = None, + num_video_frames: Optional[int] = None, + ): + from ...modules.diffusionmodules.video_model import VideoResBlock + + for layer in self: + module = layer + + if isinstance(module, TimestepBlock) and not isinstance(module, VideoResBlock): + x = layer(x, emb) + elif isinstance(module, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(module, SpatialVideoTransformer): + x = layer( + x, + context, + time_context, + num_video_frames, + image_only_indicator, + ) + elif isinstance(module, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.third_up = third_up + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + t_factor = 1 if not self.third_up else 2 + x = F.interpolate( + x, + (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest", + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
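+    :param third_down: only consulted when dims is 3; if True the temporal axis is
+        downsampled as well (stride (2, 2, 2) instead of (1, 2, 2)).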
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) + if use_conv: + print(f"Building a Downsample layer with {dims} dims.") + print( + f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " + f"kernel-size: 3, stride: {stride}, padding: {padding}" + ) + if dims == 3: + print(f" --> Downsampling third axis (time): {third_down}") + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, Iterable): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.skip_t_emb = skip_t_emb + self.emb_out_channels = 2 * self.out_channels if use_scale_shift_norm else self.out_channels + if self.skip_t_emb: + print(f"Skipping timestep embedding in {self.__class__.__name__}") + assert not self.use_scale_shift_norm + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + self.emb_out_channels, + ), + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd( + dims, + self.out_channels, + self.out_channels, + kernel_size, + padding=padding, + ) + ), + ) + + if self.out_channels == channels: 
+ self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.skip_t_emb: + emb_out = th.zeros_like(h) + else: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, **kwargs): + # TODO add crossframe attention and use mixed checkpoint + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. 
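+    # With T = num_spatial = H * W tokens, the two matmuls are (T x C) @ (C x T) for
+    # the attention weights and (T x T) @ (T x C) for the weighted values; each costs
+    # roughly T^2 * C multiply-accumulates per sample, giving the 2 * b * T^2 * c
+    # estimate below (the head count cancels because per-head channels shrink as the
+    # number of heads grows).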
+ matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +str_to_dtype = {"fp32": th.float32, "fp16": th.float16, "bf16": th.bfloat16} + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. 
+ :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + spatial_transformer_attn_type="softmax", + adm_in_channels=None, + use_fairscale_checkpoint=False, + offload_to_cpu=False, + transformer_depth_middle=None, + dtype="fp32", + lora_init=False, + lora_rank=4, + lora_scale=1.0, + lora_weight_path=None, + ): + super().__init__() + from omegaconf.listconfig import ListConfig + + self.dtype = str_to_dtype[dtype] + + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + elif isinstance(transformer_depth, ListConfig): + transformer_depth = list(transformer_depth) + transformer_depth_middle = default(transformer_depth_middle, transformer_depth[-1]) + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + if use_fp16: + print("WARNING: use_fp16 was dropped and has no effect anymore.") + # self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint) + + self.use_fairscale_checkpoint = False + checkpoint_wrapper_fn = ( + partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) + if self.use_fairscale_checkpoint + else lambda x: x + ) + + time_embed_dim = model_channels * 4 + self.time_embed = checkpoint_wrapper_fn( + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = checkpoint_wrapper_fn( + nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + 
disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ), + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + 
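+                    # At the last block of every level except the outermost one, the
+                    # layer appended below upsamples (a ResBlock with up=True when
+                    # resblock_updown, otherwise a plain Upsample) and ds is halved,
+                    # mirroring the ds *= 2 bookkeeping on the downsampling path so
+                    # attention_resolutions stays consistent.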
layers.append( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + ) + if self.predict_codebook_ids: + self.id_predictor = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + ) + + if lora_init: + self._init_lora(lora_rank, lora_scale, lora_weight_path) + + def _init_lora(self, rank, scale, ckpt_dir=None): + inject_trainable_lora_extended(self, target_replace_module=None, rank=rank, scale=scale) + + if ckpt_dir is not None: + with open(os.path.join(ckpt_dir, "latest")) as latest_file: + latest = latest_file.read().strip() + ckpt_path = os.path.join(ckpt_dir, latest, "mp_rank_00_model_states.pt") + print(f"loading lora from {ckpt_path}") + sd = th.load(ckpt_path)["module"] + sd = { + key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model") + } + self.load_state_dict(sd, strict=False) + + def _update_scale(self, scale): + update_lora_scale(self, scale) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + assert False, "not supported anymore. what the f*** are you doing?" + else: + return self.out(h) + + +class NoTimeUNetModel(UNetModel): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + timesteps = th.zeros_like(timesteps) + return super().forward(x, timesteps, context, y, **kwargs) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. 
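+    :param pool: pooling head used to turn the feature map into the final [N x K]
+        output; one of "adaptive", "attention", "spatial", or "spatial_v2"
+        (see the constructor below).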
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), + ) + elif pool == 
"spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + + +if __name__ == "__main__": + + class Dummy(nn.Module): + def __init__(self, in_channels=3, model_channels=64): + super().__init__() + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(2, in_channels, model_channels, 3, padding=1))] + ) + + model = UNetModel( + use_checkpoint=True, + image_size=64, + in_channels=4, + out_channels=4, + model_channels=128, + attention_resolutions=[4, 2], + num_res_blocks=2, + channel_mult=[1, 2, 4], + num_head_channels=64, + use_spatial_transformer=False, + use_linear_in_transformer=True, + transformer_depth=1, + legacy=False, + ).cuda() + x = th.randn(11, 4, 64, 64).cuda() + t = th.randint(low=0, high=10, size=(11,), device="cuda") + o = model(x, t) + print("done.") diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sampling.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f18302b2b36c607a15bdbf27b489ca5a447712 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sampling.py @@ -0,0 +1,763 @@ +""" +Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py +""" + +from typing import Dict, Union + +import torch +from omegaconf import ListConfig, OmegaConf +from tqdm import tqdm + +from ...modules.diffusionmodules.sampling_utils import ( + get_ancestral_step, + linear_multistep_coeff, + to_d, + to_neg_log_sigma, + to_sigma, +) +from ...util import append_dims, default, instantiate_from_config +from ...util import SeededNoise + +from .guiders import DynamicCFG + +DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} + + +class BaseDiffusionSampler: + def __init__( + self, + discretization_config: Union[Dict, ListConfig, OmegaConf], + num_steps: Union[int, None] = None, + guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, + verbose: bool = False, + device: str = "cuda", + ): + self.num_steps = num_steps + self.discretization = 
instantiate_from_config(discretization_config) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) + self.verbose = verbose + self.device = device + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device) + uc = default(uc, cond) + + x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) + num_sigmas = len(sigmas) + + s_in = x.new_ones([x.shape[0]]).float() + + return x, s_in, sigmas, num_sigmas, cond, uc + + def denoise(self, x, denoiser, sigma, cond, uc): + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) + denoised = self.guider(denoised, sigma) + return denoised + + def get_sigma_gen(self, num_sigmas): + sigma_generator = range(num_sigmas - 1) + if self.verbose: + print("#" * 30, " Sampling setting ", "#" * 30) + print(f"Sampler: {self.__class__.__name__}") + print(f"Discretization: {self.discretization.__class__.__name__}") + print(f"Guider: {self.guider.__class__.__name__}") + sigma_generator = tqdm( + sigma_generator, + total=num_sigmas, + desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", + ) + return sigma_generator + + +class SingleStepDiffusionSampler(BaseDiffusionSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): + raise NotImplementedError + + def euler_step(self, x, d, dt): + return x + dt * d + + +class EDMSampler(SingleStepDiffusionSampler): + def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.s_churn = s_churn + self.s_tmin = s_tmin + self.s_tmax = s_tmax + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + + denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class DDIMSampler(SingleStepDiffusionSampler): + def __init__(self, s_noise=0.1, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + d = to_d(x, sigma, denoised) + dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim) + + euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) + + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + x = 
self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + self.s_noise, + ) + + return x + + +class AncestralSampler(SingleStepDiffusionSampler): + def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta = eta + self.s_noise = s_noise + self.noise_sampler = lambda x: torch.randn_like(x) + + def ancestral_euler_step(self, x, denoised, sigma, sigma_down): + d = to_d(x, sigma, denoised) + dt = append_dims(sigma_down - sigma, x.ndim) + + return self.euler_step(x, d, dt) + + def ancestral_step(self, x, sigma, next_sigma, sigma_up): + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, + x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), + x, + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) + + return x + + +class LinearMultistepSampler(BaseDiffusionSampler): + def __init__( + self, + order=4, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.order = order + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + ds = [] + sigmas_cpu = sigmas.detach().cpu().numpy() + for i in self.get_sigma_gen(num_sigmas): + sigma = s_in * sigmas[i] + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) + denoised = self.guider(denoised, sigma) + d = to_d(x, sigma, denoised) + ds.append(d) + if len(ds) > self.order: + ds.pop(0) + cur_order = min(i + 1, self.order) + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + + return x + + +class EulerEDMSampler(EDMSampler): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + return euler_step + + +class HeunEDMSampler(EDMSampler): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + if torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 + return euler_step + else: + denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) + d_new = to_d(euler_step, next_sigma, denoised) + d_prime = (d + d_new) / 2.0 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) + return x + + +class EulerAncestralSampler(AncestralSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + + return x + + +class DPMPP2SAncestralSampler(AncestralSampler): + def get_variables(self, sigma, sigma_down): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] + h = t_next - t + s = t + 0.5 * h + return h, s, t, t_next + + def get_mult(self, h, s, t, t_next): + mult1 = to_sigma(s) / to_sigma(t) + mult2 = (-0.5 * h).expm1() + mult3 = to_sigma(t_next) / to_sigma(t) + mult4 = (-h).expm1() + + return mult1, mult2, mult3, mult4 + + def sampler_step(self, sigma, 
next_sigma, denoiser, x, cond, uc=None, **kwargs): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + + if torch.sum(sigma_down) < 1e-14: + # Save a network evaluation if all noise levels are 0 + x = x_euler + else: + h, s, t, t_next = self.get_variables(sigma, sigma_down) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] + + x2 = mult[0] * x - mult[1] * denoised + denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) + x_dpmpp2s = mult[2] * x - mult[3] * denoised2 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) + + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + return x + + +class DPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) + mult2 = (-h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + # apply correction if noise level is not 0 and not first step + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x + + +class SDEDPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() + mult2 = (-2 * h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): 
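+        # Sketch of the update below, assuming to_neg_log_sigma(s) == -log(s) and
+        # to_sigma(t) == exp(-t): with h = log(sigma / next_sigma) the multipliers from
+        # get_mult reduce to mult1 = (next_sigma / sigma) ** 2 and
+        # mult2 = (next_sigma / sigma) ** 2 - 1, so the first-order step is
+        #   x <- (next_sigma / sigma) ** 2 * x
+        #        + (1 - (next_sigma / sigma) ** 2) * denoised
+        #        + next_sigma * (1 - (next_sigma / sigma) ** 2) ** 0.5 * noise,
+        # i.e. the DPM-Solver++(2M) SDE update; old_denoised supplies the second-order
+        # correction via mult3 / mult4 when a previous step is available.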
+ denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] + mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + + x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + # apply correction if noise level is not 0 and not first step + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x + + +class SdeditEDMSampler(EulerEDMSampler): + def __init__(self, edit_ratio=0.5, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.edit_ratio = edit_ratio + + def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None): + randn_unit = randn.clone() + randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps) + + if num_steps is None: + num_steps = self.num_steps + if edit_ratio is None: + edit_ratio = self.edit_ratio + x = None + + for i in self.get_sigma_gen(num_sigmas): + if i / num_steps < edit_ratio: + continue + if x is None: + x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape)) + + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class VideoDDIMSampler(BaseDiffusionSampler): + def __init__(self, fixed_frames=0, sdedit=False, **kwargs): + super().__init__(**kwargs) + self.fixed_frames = fixed_frames + self.sdedit = sdedit + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + alpha_cumprod_sqrt, timesteps = self.discretization( + self.num_steps if num_steps is None else num_steps, + device=self.device, + return_idx=True, + do_append_zero=False, + ) + alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) + timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]) + + uc = default(uc, cond) + + num_sigmas = len(alpha_cumprod_sqrt) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps + + def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None): + additional_model_inputs = {} + + if isinstance(scale, torch.Tensor) == False and scale == 1: + additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep + if scale_emb is not None: + additional_model_inputs["scale_emb"] = scale_emb + denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) 
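+            # The scale == 1 branch above skips classifier-free guidance with a single
+            # denoiser call; the else branch below doubles `idx` to match the
+            # guider-prepared conditional/unconditional batch, denoises once, and lets
+            # the guider (optionally DynamicCFG, which also gets the step index)
+            # combine the two halves.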
+ else: + additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) + denoised = denoiser( + *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs + ).to(torch.float32) + if isinstance(self.guider, DynamicCFG): + denoised = self.guider( + denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale + ) + else: + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale) + return denoised + + def sampler_step( + self, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + denoised = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + + a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t + + x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + + return x + + +class VPSDEDPMPP2MSampler(VideoDDIMSampler): + def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): + alpha_cumprod = alpha_cumprod_sqrt**2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + h = lamb_next - lamb + + if previous_alpha_cumprod_sqrt is not None: + previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): + mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt + + if previous_alpha_cumprod_sqrt is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + denoised = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + if idx == 1: + return denoised, denoised + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + + x_standard = mult[0] * x - mult[1] * denoised + 
mult_noise * torch.randn_like(x) + if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + x = x_advanced + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + if self.fixed_frames > 0: + prefix_frames = x[:, : self.fixed_frames] + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + if self.fixed_frames > 0: + if self.sdedit: + rd = torch.randn_like(prefix_frames) + noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( + s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) + ) + x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1) + else: + x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + + if self.fixed_frames > 0: + x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + + return x + + +class VPODEDPMPP2MSampler(VideoDDIMSampler): + def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): + alpha_cumprod = alpha_cumprod_sqrt**2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + h = lamb_next - lamb + + if previous_alpha_cumprod_sqrt is not None: + previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): + mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + mult2 = (-h).expm1() * next_alpha_cumprod_sqrt + + if previous_alpha_cumprod_sqrt is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + ): + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) + if idx == 1: + return denoised, denoised + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised 
+ else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + x = x_advanced + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + ) + + return x diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sampling_utils.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sampling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb1fa829659394f673e19e6144d9ab3a1faf5a1 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sampling_utils.py @@ -0,0 +1,155 @@ +import torch +from scipy import integrate + +from ...util import append_dims +from einops import rearrange + + +class NoDynamicThresholding: + def __call__(self, uncond, cond, scale): + scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale + return uncond + scale * (cond - uncond) + + +class StaticThresholding: + def __call__(self, uncond, cond, scale): + result = uncond + scale * (cond - uncond) + result = torch.clamp(result, min=-1.0, max=1.0) + return result + + +def dynamic_threshold(x, p=0.95): + N, T, C, H, W = x.shape + x = rearrange(x, "n t c h w -> n c (t h w)") + l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True) + s = torch.maximum(-l, r) + threshold_mask = (s > 1).expand(-1, -1, H * W * T) + if threshold_mask.any(): + x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x) + x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W) + return x + + +def dynamic_thresholding2(x0): + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) + x0 = torch.clamp(x0, -s, s) # / s + return x0.to(origin_dtype) + + +def latent_dynamic_thresholding(x0): + p = 0.9995 + origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0), p, dim=2) + s = append_dims(s, x0.dim()) + x0 = torch.clamp(x0, -s, s) / s + return x0.to(origin_dtype) + + +def dynamic_thresholding3(x0): + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 
+ origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) + x0 = torch.clamp(x0, -s, s) # / s + return x0.to(origin_dtype) + + +class DynamicThresholding: + def __call__(self, uncond, cond, scale): + mean = uncond.mean() + std = uncond.std() + result = uncond + scale * (cond - uncond) + result_mean, result_std = result.mean(), result.std() + result = (result - result_mean) / result_std * std + # result = dynamic_thresholding3(result) + return result + + +class DynamicThresholdingV1: + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def __call__(self, uncond, cond, scale): + result = uncond + scale * (cond - uncond) + unscaled_result = result / self.scale_factor + B, T, C, H, W = unscaled_result.shape + flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)") + means = flattened.mean(dim=2).unsqueeze(2) + recentered = flattened - means + magnitudes = recentered.abs().max() + normalized = recentered / magnitudes + thresholded = latent_dynamic_thresholding(normalized) + denormalized = thresholded * magnitudes + uncentered = denormalized + means + unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W) + scaled_result = unflattened * self.scale_factor + return scaled_result + + +class DynamicThresholdingV2: + def __call__(self, uncond, cond, scale): + B, T, C, H, W = uncond.shape + diff = cond - uncond + mim_target = uncond + diff * 4.0 + cfg_target = uncond + diff * 8.0 + + mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)") + cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)") + mim_means = mim_flattened.mean(dim=2).unsqueeze(2) + cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2) + mim_centered = mim_flattened - mim_means + cfg_centered = cfg_flattened - cfg_means + + mim_scaleref = mim_centered.std(dim=2).unsqueeze(2) + cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2) + + cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref + + result = cfg_renormalized + cfg_means + unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W) + + return unflattened + + +def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): + if order - 1 > i: + raise ValueError(f"Order {order} too high for step {i}") + + def fn(tau): + prod = 1.0 + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + + return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + if not eta: + return sigma_to, 0.0 + sigma_up = torch.minimum( + sigma_to, + eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + return sigma_down, sigma_up + + +def to_d(x, sigma, denoised): + return (x - denoised) / append_dims(sigma, x.ndim) + + +def to_neg_log_sigma(sigma): + return sigma.log().neg() + + +def to_sigma(neg_log_sigma): + return neg_log_sigma.neg().exp() diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sigma_sampling.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sigma_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..770de4254e54d594e7a46663ea58d4f2f660187e --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/sigma_sampling.py @@ -0,0 +1,80 @@ +import torch +import 
torch.distributed + +from sat import mpu + +from ...util import default, instantiate_from_config + + +class EDMSampling: + def __init__(self, p_mean=-1.2, p_std=1.2): + self.p_mean = p_mean + self.p_std = p_std + + def __call__(self, n_samples, rand=None): + log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) + return log_sigma.exp() + + +class DiscreteSampling: + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): + self.num_idx = num_idx + self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + world_size = mpu.get_data_parallel_world_size() + self.uniform_sampling = uniform_sampling + if self.uniform_sampling: + i = 1 + while True: + if world_size % i != 0 or num_idx % (world_size // i) != 0: + i += 1 + else: + self.group_num = world_size // i + break + + assert self.group_num > 0 + assert world_size % self.group_num == 0 + self.group_width = world_size // self.group_num # the number of rank in one group + self.sigma_interval = self.num_idx // self.group_num + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None, return_idx=False): + if self.uniform_sampling: + rank = mpu.get_data_parallel_rank() + group_index = rank // self.group_width + idx = default( + rand, + torch.randint( + group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,) + ), + ) + else: + idx = default( + rand, + torch.randint(0, self.num_idx, (n_samples,)), + ) + if return_idx: + return self.idx_to_sigma(idx), idx + else: + return self.idx_to_sigma(idx) + + +class PartialDiscreteSampling: + def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): + self.total_num_idx = total_num_idx + self.partial_num_idx = partial_num_idx + self.sigmas = instantiate_from_config(discretization_config)( + total_num_idx, do_append_zero=do_append_zero, flip=flip + ) + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None): + idx = default( + rand, + # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)), + torch.randint(0, self.partial_num_idx, (n_samples,)), + ) + return self.idx_to_sigma(idx) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/util.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..abf72a758fbf8a6f08145100223fea074fa64015 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/util.py @@ -0,0 +1,328 @@ +""" +adopted from +https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +and +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +and +https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py + +thanks! 
+""" + +import math +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +def make_beta_schedule( + schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, +): + if schedule == "linear": + betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 + return betas.numpy() + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def mixed_checkpoint(func, inputs: dict, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function + borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that + it also works with non-tensor inputs + :param func: the function to evaluate. + :param inputs: the argument dictionary to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] + tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)] + non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)] + non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)] + args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) + return MixedCheckpointFunction.apply( + func, + len(tensor_inputs), + len(non_tensor_inputs), + tensor_keys, + non_tensor_keys, + *args, + ) + else: + return func(**inputs) + + +class MixedCheckpointFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + run_function, + length_tensors, + length_non_tensors, + tensor_keys, + non_tensor_keys, + *args, + ): + ctx.end_tensors = length_tensors + ctx.end_non_tensors = length_tensors + length_non_tensors + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors + + ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))} + ctx.input_non_tensors = { + key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])) + } + ctx.run_function = run_function + ctx.input_params = list(args[ctx.end_non_tensors :]) + + with torch.no_grad(): + output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} + ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors} + + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors} + # shallow_copies.update(additional_args) + output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) + input_grads = torch.autograd.grad( + output_tensors, + list(ctx.input_tensors.values()) + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return ( + (None, None, None, None, None) + + input_grads[: ctx.end_tensors] + + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + + input_grads[ctx.end_tensors :] + ) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding.to(dtype) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + if self.merge_strategy == "fixed": + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 
1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + else: + raise NotImplementedError + return alpha + + def forward( + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + return x diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/wrappers.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b78ffd502fc67238752fbd5de7ee7c6661ce05 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/diffusionmodules/wrappers.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from packaging import version + +OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" + + +class IdentityWrapper(nn.Module): + def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32): + super().__init__() + compile = ( + torch.compile + if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model + else lambda x: x + ) + self.diffusion_model = compile(diffusion_model) + self.dtype = dtype + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class OpenAIWrapper(IdentityWrapper): + def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: + for key in c: + c[key] = c[key].to(self.dtype) + + if x.dim() == 4: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + elif x.dim() == 5: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2) + else: + raise ValueError("Input tensor must be 4D or 5D") + + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, + ) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/distributions/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/distributions/distributions.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..0338a861f90e4fbd1f8f4ba8712dde1316b4fa58 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/distributions/distributions.py @@ -0,0 +1,94 @@ +import numpy as np +import torch + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + # x = self.mean + self.std * 
torch.randn(self.mean.shape).to( + # device=self.parameters.device + # ) + x = self.mean + self.std * torch.randn_like(self.mean).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/ema.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1f7606c2c9b68ebd2302215a9e08f9f31ed8ab --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - 
m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/encoders/__init__.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/encoders/modules.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a16fcc90040a6ea61d68ef9db3f5bc75beb8da --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/encoders/modules.py @@ -0,0 +1,281 @@ +import math +from contextlib import nullcontext +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import kornia +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import ListConfig +from torch.utils.checkpoint import checkpoint +from transformers import ( + T5EncoderModel, + T5Tokenizer, +) + +from ...util import ( + append_dims, + autocast, + count_params, + default, + disabled_train, + expand_dims_like, + instantiate_from_config, +) + + +class AbstractEmbModel(nn.Module): + def __init__(self): + super().__init__() + self._is_trainable = None + self._ucg_rate = None + self._input_key = None + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def ucg_rate(self) -> Union[float, torch.Tensor]: + return self._ucg_rate + + @property + def input_key(self) -> str: + return self._input_key + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + @ucg_rate.setter + def ucg_rate(self, value: Union[float, torch.Tensor]): + self._ucg_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @ucg_rate.deleter + def ucg_rate(self): + del self._ucg_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + +class GeneralConditioner(nn.Module): + OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} + KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} + + def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]): + super().__init__() + embedders = [] + for n, embconfig 
in enumerate(emb_models): + embedder = instantiate_from_config(embconfig) + assert isinstance( + embedder, AbstractEmbModel + ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" + embedder.is_trainable = embconfig.get("is_trainable", False) + embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) + if not embedder.is_trainable: + embedder.train = disabled_train + for param in embedder.parameters(): + param.requires_grad = False + embedder.eval() + print( + f"Initialized embedder #{n}: {embedder.__class__.__name__} " + f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" + ) + + if "input_key" in embconfig: + embedder.input_key = embconfig["input_key"] + elif "input_keys" in embconfig: + embedder.input_keys = embconfig["input_keys"] + else: + raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") + + embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) + if embedder.legacy_ucg_val is not None: + embedder.ucg_prng = np.random.RandomState() + + embedders.append(embedder) + self.embedders = nn.ModuleList(embedders) + + if len(cor_embs) > 0: + assert len(cor_p) == 2 ** len(cor_embs) + self.cor_embs = cor_embs + self.cor_p = cor_p + + def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: + assert embedder.legacy_ucg_val is not None + p = embedder.ucg_rate + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if embedder.ucg_prng.choice(2, p=[1 - p, p]): + batch[embedder.input_key][i] = val + return batch + + def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict: + assert embedder.legacy_ucg_val is not None + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if cond_or_not[i]: + batch[embedder.input_key][i] = val + return batch + + def get_single_embedding( + self, + embedder, + batch, + output, + cond_or_not: Optional[np.ndarray] = None, + force_zero_embeddings: Optional[List] = None, + ): + embedding_context = nullcontext if embedder.is_trainable else torch.no_grad + with embedding_context(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + if embedder.legacy_ucg_val is not None: + if cond_or_not is None: + batch = self.possibly_get_ucg_val(embedder, batch) + else: + batch = self.surely_get_ucg_val(embedder, batch, cond_or_not) + emb_out = embedder(batch[embedder.input_key]) + elif hasattr(embedder, "input_keys"): + emb_out = embedder(*[batch[k] for k in embedder.input_keys]) + assert isinstance( + emb_out, (torch.Tensor, list, tuple) + ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" + if not isinstance(emb_out, (list, tuple)): + emb_out = [emb_out] + for emb in emb_out: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: + if cond_or_not is None: + emb = ( + expand_dims_like( + torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), + emb, + ) + * emb + ) + else: + emb = ( + expand_dims_like( + torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device), + emb, + ) + * emb + ) + if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings: + emb = torch.zeros_like(emb) + if out_key in output: + output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key]) + else: + output[out_key] = emb + return output + + def forward(self, batch: Dict, force_zero_embeddings: 
Optional[List] = None) -> Dict: + output = dict() + if force_zero_embeddings is None: + force_zero_embeddings = [] + + if len(self.cor_embs) > 0: + batch_size = len(batch[list(batch.keys())[0]]) + rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p) + for emb_idx in self.cor_embs: + cond_or_not = rand_idx % 2 + rand_idx //= 2 + output = self.get_single_embedding( + self.embedders[emb_idx], + batch, + output=output, + cond_or_not=cond_or_not, + force_zero_embeddings=force_zero_embeddings, + ) + + for i, embedder in enumerate(self.embedders): + if i in self.cor_embs: + continue + output = self.get_single_embedding( + embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings + ) + return output + + def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + ucg_rates = list() + for embedder in self.embedders: + ucg_rates.append(embedder.ucg_rate) + embedder.ucg_rate = 0.0 + cor_embs = self.cor_embs + cor_p = self.cor_p + self.cor_embs = [] + self.cor_p = [] + + c = self(batch_c) + uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) + + for embedder, rate in zip(self.embedders, ucg_rates): + embedder.ucg_rate = rate + self.cor_embs = cor_embs + self.cor_p = cor_p + + return c, uc + + +class FrozenT5Embedder(AbstractEmbModel): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, + model_dir="google/t5-v1_1-xxl", + device="cuda", + max_length=77, + freeze=True, + cache_dir=None, + ): + super().__init__() + if model_dir is not "google/t5-v1_1-xxl": + self.tokenizer = T5Tokenizer.from_pretrained(model_dir) + self.transformer = T5EncoderModel.from_pretrained(model_dir) + else: + self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir) + self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + # @autocast + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/video_attention.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/video_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9f968d72e7d1ef5f11a68c289e76a8a1c9817312 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/modules/video_attention.py @@ -0,0 +1,293 @@ +import torch + +from ..modules.attention import * +from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding + + +class TimeMixSequential(nn.Sequential): + def forward(self, x, context=None, timesteps=None): + for layer in self: + x = layer(x, context, timesteps) + + return x + + +class VideoTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, + "softmax-xformers": MemoryEfficientCrossAttention, + } + + def __init__( + self, + dim, + n_heads, + d_head, + 
dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + timesteps=None, + ff_in=False, + inner_dim=None, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + switch_temporal_ca_to_sa=False, + ): + super().__init__() + + attn_cls = self.ATTENTION_MODES[attn_mode] + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + assert int(n_heads * d_head) == inner_dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff) + + self.timesteps = timesteps + self.disable_self_attn = disable_self_attn + if self.disable_self_attn: + self.attn1 = attn_cls( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim, + dropout=dropout, + ) # is a cross-attention + else: + self.attn1 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + self.norm2 = nn.LayerNorm(inner_dim) + if switch_temporal_ca_to_sa: + self.attn2 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + else: + self.attn2 = attn_cls( + query_dim=inner_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + + self.norm1 = nn.LayerNorm(inner_dim) + self.norm3 = nn.LayerNorm(inner_dim) + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa + + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor: + if self.checkpoint: + return checkpoint(self._forward, x, context, timesteps) + else: + return self._forward(x, context, timesteps=timesteps) + + def _forward(self, x, context=None, timesteps=None): + assert self.timesteps or timesteps + assert not (self.timesteps and timesteps) or self.timesteps == timesteps + timesteps = self.timesteps or timesteps + B, S, C = x.shape + x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + + if self.disable_self_attn: + x = self.attn1(self.norm1(x), context=context) + x + else: + x = self.attn1(self.norm1(x)) + x + + if self.attn2 is not None: + if self.switch_temporal_ca_to_sa: + x = self.attn2(self.norm2(x)) + x + else: + x = self.attn2(self.norm2(x), context=context) + x + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps) + return x + + def get_last_layer(self): + return self.ff.net[-1].weight + + +str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + 
disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype="fp32", + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + attn_type=attn_mode, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + VideoTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy) + self.dtype = str_to_dtype[dtype] + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}" + + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... 
-> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding( + num_frames, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + dtype=self.dtype, + ) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)): + x = block( + x, + context=spatial_context, + ) + + x_mix = x + x_mix = x_mix + emb + + x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator, + ) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/util.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b93a04930b62d5cf5d9361b3153883cdebfdc28a --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/util.py @@ -0,0 +1,383 @@ +import functools +import importlib +import os +from functools import partial +from inspect import isfunction + +import fsspec +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file as load_safetensors +import torch.distributed + +_CONTEXT_PARALLEL_GROUP = None +_CONTEXT_PARALLEL_SIZE = None + + +def is_context_parallel_initialized(): + if _CONTEXT_PARALLEL_GROUP is None: + return False + else: + return True + + +def set_context_parallel_group(size, group): + global _CONTEXT_PARALLEL_GROUP + global _CONTEXT_PARALLEL_SIZE + _CONTEXT_PARALLEL_GROUP = group + _CONTEXT_PARALLEL_SIZE = size + + +def initialize_context_parallel(context_parallel_size): + global _CONTEXT_PARALLEL_GROUP + global _CONTEXT_PARALLEL_SIZE + + assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized" + _CONTEXT_PARALLEL_SIZE = context_parallel_size + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + for i in range(0, world_size, context_parallel_size): + ranks = range(i, i + context_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _CONTEXT_PARALLEL_GROUP = group + break + + +def get_context_parallel_group(): + assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized" + + return _CONTEXT_PARALLEL_GROUP + + +def get_context_parallel_world_size(): + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" + + return _CONTEXT_PARALLEL_SIZE + + +def get_context_parallel_rank(): + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" + + rank = torch.distributed.get_rank() + cp_rank = rank % _CONTEXT_PARALLEL_SIZE + return cp_rank + + +def get_context_parallel_group_rank(): + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" + + rank = torch.distributed.get_rank() + cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE + + return cp_group_rank + + +class 
SafeConv3d(torch.nn.Conv3d): + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + if memory_count > 2: + # print(f"WARNING: Conv3d with {memory_count:.2f}GB") + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super(SafeConv3d, self).forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super(SafeConv3d, self).forward(input) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def get_string_from_tuple(s): + try: + # Check if the string starts and ends with parentheses + if s[0] == "(" and s[-1] == ")": + # Convert the string to a tuple + t = eval(s) + # Check if the type of t is tuple + if type(t) == tuple: + return t[0] + else: + pass + except: + pass + return s + + +def is_power_of_two(n): + """ + chat.openai.com/chat + Return True if n is a power of 2, otherwise return False. + + The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. + The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. + If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. + Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. + + """ + if n <= 0: + return False + return (n & (n - 1)) == 0 + + +def autocast(f, enabled=True): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast( + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def load_partial_from_config(config): + return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + if isinstance(xc[bi], list): + text_seq = xc[bi][0] + else: + text_seq = xc[bi] + lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. 
Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def isheatmap(x): + if not isinstance(x, torch.Tensor): + return False + + return x.ndim == 2 + + +def isneighbors(x): + if not isinstance(x, torch.Tensor): + return False + return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) + + +def exists(x): + return x is not None + + +def expand_dims_like(x, y): + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config, **extra_kwargs): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict()), **extra_kwargs) + + +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def load_model_from_config(config, ckpt, verbose=True, freeze=True): + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + model = instantiate_from_config(config.model) + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + model.eval() + return model + + +def get_configs_path() -> str: + """ + Get the `configs` directory. 
+ For a working copy, this is the one in the root of the repository, + but for an installed copy, it's in the `sgm` package (see pyproject.toml). + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "configs"), + os.path.join(this_dir, "..", "configs"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM configs in {candidates}") + + +def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): + """ + Will return the result of a recursive get attribute call. + E.g.: + a.b.c + = getattr(getattr(a, "b"), "c") + = get_nested_attribute(a, "b.c") + If any part of the attribute call is an integer x with current obj a, will + try to call a[x] instead of a.x first. + """ + attributes = attribute_path.split(".") + if depth is not None and depth > 0: + attributes = attributes[:depth] + assert len(attributes) > 0, "At least one attribute should be selected" + current_attribute = obj + current_key = None + for level, attribute in enumerate(attributes): + current_key = ".".join(attributes[: level + 1]) + try: + id_ = int(attribute) + current_attribute = current_attribute[id_] + except ValueError: + current_attribute = getattr(current_attribute, attribute) + + return (current_attribute, current_key) if return_key else current_attribute + + +from math import sqrt + + +class SeededNoise: + def __init__(self, seeds, weights): + self.seeds = seeds + self.weights = weights + weight_square_sum = 0 + for weight in weights: + weight_square_sum += weight**2 + self.weight_square_sum_sqrt = sqrt(weight_square_sum) + self.cnt = 0 + + def __call__(self, x): + self.cnt += 1 + randn_combined = torch.zeros_like(x) + for seed, weight in zip(self.seeds, self.weights): + randn = np.random.RandomState(seed + self.cnt).randn(*x.shape) + randn = torch.from_numpy(randn, dtype=x.dtype, device=x.device) + randn_combined += randn * weight + randn_combined /= self.weight_square_sum_sqrt + return randn_combined diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/webds.py b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/webds.py new file mode 100644 index 0000000000000000000000000000000000000000..b99f9f337e2f19532cda3237473ee069dbbc4f9b --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/sgm/webds.py @@ -0,0 +1,389 @@ +import sys +import io +import os +import re +import json +import tarfile +from functools import partial + +import webdataset as wds +from webdataset import ResampledShards, DataPipeline, tarfile_to_samples +from webdataset.filters import pipelinefilter +from webdataset.tariterators import url_opener, group_by_keys +from webdataset.handlers import reraise_exception +from webdataset.gopen import gopen_schemes, gopen + + +def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress + """Return node and worker info for PyTorch and some distributed environments.""" + rank = 0 + world_size = 1 + worker = 0 + num_workers = 1 + try: + import torch.distributed + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + group = group or torch.distributed.group.WORLD + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + except ModuleNotFoundError: + pass + try: + import torch.utils.data + + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + worker = worker_info.id + num_workers = worker_info.num_workers + except 
ModuleNotFoundError: + pass + + return rank, world_size, worker, num_workers + + +def pytorch_worker_seed(group=None): + """Compute a distinct, deterministic RNG seed for each worker and node.""" + rank, world_size, worker, num_workers = pytorch_worker_info(group=group) + return rank * 1000 + worker + + +def worker_seed_sat(group=None, seed=0): + return pytorch_worker_seed(group=group) + seed * 23 + + +class ConfiguredResampledShards(ResampledShards): + def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True): + from sat.helpers import print_rank0 + + try: + from megatron.core.parallel_state import get_data_parallel_group + + group = get_data_parallel_group() + print_rank0("Using megatron data parallel group.") + except: + from sat.mpu import get_data_parallel_group + + try: + group = get_data_parallel_group() + print_rank0("Using sat data parallel group.") + except AssertionError: + group = None + print_rank0("No data parallel group is specified!") + worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed) + super().__init__(urls, nshards, worker_seed_sat_this, deterministic) + + +class SimpleDistributedWebDataset(DataPipeline): + def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000): + # set shuffle_buffer = 1 to disable it, model-parallel will be different due to shuffle + try: + from sat.mpu import get_model_parallel_world_size + + if get_model_parallel_world_size() > 1: + shuffle_buffer = 1 + except Exception: + pass + super().__init__( + ConfiguredResampledShards(path, seed), # Lots of shards are recommended, or not evenly + tarfile_to_samples(), + wds.shuffle(shuffle_buffer), + process_fn, + ) + + +def tar_file_iterator_with_meta( + fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None +): + """Iterate over tar file, yielding filename, content pairs for the given tar stream. 
+ + :param fileobj: byte stream suitable for tarfile + :param meta_names: key of different items in meta file + :param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)") + + """ + stream = tarfile.open(fileobj=fileobj, mode="r|*") + data_dir, filename = fileobj.name.rsplit("/", 1) + meta_data = {} # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}} + + if meta_stream is None: + meta_file_name = filename.split(".")[0] + ".meta.jsonl" + meta_path = os.path.join(data_dir, meta_file_name) + if os.path.exists(meta_path): + meta_stream = open(meta_path, "r") + else: + meta_file_name = meta_stream.name + + if meta_stream is not None: + for lineno, line in enumerate(meta_stream): + meta_list = [] + try: + meta_list.append(json.loads(line)) + except Exception as exn: + from sat.helpers import print_rank0 + + print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG") + continue + for item in meta_list: + if not item["key"] in meta_data: + meta_data[item["key"]] = {} + for meta_name in meta_names: + if meta_name in item: + meta_data[item["key"]][meta_name] = item[meta_name] + meta_stream.close() + + try: + for tarinfo in stream: + fname = tarinfo.name + try: + if not tarinfo.isreg(): + continue + if fname is None: + continue + if "/" not in fname and fname.startswith("__") and fname.endswith("__"): + # skipping metadata for now + continue + if skip_meta is not None and re.match(skip_meta, fname): + continue + if fname.endswith(".txt") and suffix is not None: + data = (stream.extractfile(tarinfo).read().decode() + suffix).encode() + else: + data = stream.extractfile(tarinfo).read() + result = dict(fname=fname, data=data) + yield result + + if fname.endswith(".id"): + fid = fname.split(".")[0] + if "-$#%@&" in fid: + sfid = fid.split("-$#%@&")[0] + else: + sfid = fid + meta_data_fid = meta_data.get(sfid, {}) + for meta_name in meta_names: + meta_fname = fid + "." + meta_name + meta = meta_data_fid.get(meta_name, None) + yield dict(fname=meta_fname, data=meta) + stream.members = [] + except Exception as exn: + if hasattr(exn, "args") and len(exn.args) > 0: + exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + if handler(exn): + continue + else: + break + except Exception as exn: + print(exn) + del stream + + +def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception): + """Expand a stream of open tar files into a stream of tar file contents. + + This returns an iterator over (filename, file_contents). + """ + for source in data: + url = source["url"] + try: + assert isinstance(source, dict) + assert "stream" in source + for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]): + assert isinstance(sample, dict) and "data" in sample and "fname" in sample + sample["__url__"] = url + yield sample + except Exception as exn: + exn.args = exn.args + (source.get("stream"), source.get("url")) + if handler(exn): + continue + else: + break + + +def url_opener( + data, + handler, + **kw, +): + """Open URLs and yield a stream of url+stream pairs. + + Args: + data: iterator over dict(url=...) + handler: exception handler. + kw: keyword arguments for gopen.gopen. + + Yields: + a stream of url+stream pairs. 
+ """ + for sample in data: + assert isinstance(sample, dict), sample + assert "url" in sample + url = sample["url"] + try: + stream = gopen(url, **kw) + if hasattr(stream, "meta_stream"): + meta_stream = stream.meta_stream + del stream.meta_stream + else: + meta_stream = None + sample.update(stream=stream, meta_stream=meta_stream) + yield sample + except Exception as exn: + exn.args = exn.args + (url,) + if handler(exn): + continue + else: + break + + +def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception): + streams = url_opener(src, handler=handler) + files = tar_file_expander_with_meta(streams, meta_names, handler) + samples = group_by_keys(files, handler=handler) + return samples + + +class MetaDistributedWebDataset(DataPipeline): + """WebDataset with meta information files + Extra Format: + in webdataset (tar), for each sample there is a '.id'; + for each tar file, there is a '.meta.jsonl' file with the same name; + The '.meta.jsonl' file contains lines of json objects, each with a 'key' field to match '.id'. + """ + + def __init__( + self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None + ): + # os.environ['WDS_SHOW_SEED'] = '1' + import torch + + if torch.distributed.get_rank() == 0: + if include_dirs is not None: # /webdatasets/A,/webdatasets/C + other_paths = [] + include_dirs = include_dirs.split(",") + for include_dir in include_dirs: + if "*" in include_dir: + include_dir, n = include_dir.split("*") + n = int(n) + else: + n = 1 + for cur_dir, dirs, files in os.walk(include_dir): + for f in files: + if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0: + # other_paths.append(os.path.join(cur_dir,f)) + other_paths.extend([os.path.join(cur_dir, f)] * n) + # print(f'Adding dataset paths {",".join(other_paths)}') + from braceexpand import braceexpand + + if len(path) > 0: # not "" + path = list(braceexpand(path)) + other_paths + else: + path = other_paths + path = [path] + else: + path = [ + None, + ] + torch.distributed.broadcast_object_list(path, src=0) + path = path[0] + + tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names) + tarfile_to_samples = pipelinefilter(tarfile_samples) + + # if model parallel, shuffle_buffer should be 1 to disable shuffling + try: + from sat.mpu import get_model_parallel_world_size + + if get_model_parallel_world_size() > 1: + shuffle_buffer = 1 + except Exception: + pass + + super().__init__( + ConfiguredResampledShards(path, seed, nshards=nshards), + tarfile_to_samples(), + wds.shuffle(shuffle_buffer), + process_fn, + ) + + +# rclone support +from webdataset.gopen import Pipe + + +def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32): + """Open a URL with `curl`. + + :param url: rclone url, e.g. data:bucket1/foo.tar. data should be configured. + :param mode: file mode + :param bufsize: buffer size + """ + url = url.replace("rclone://", "") + if mode[0] == "r": + cmd = f"rclone cat '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 23], + ) # skipcq: BAN-B604 + elif mode[0] == "w": + cmd = f"rclone cp - '{url}'" + return Pipe( + cmd, + mode=mode, + shell=True, + bufsize=bufsize, + ignore_status=[141, 26], + ) # skipcq: BAN-B604 + else: + raise ValueError(f"{mode}: unknown mode") + + +def gopen_boto3(url, mode="rb", bufsize=8192 * 2): + """Open a URL with boto3 API. + + :param url: boto3 url, e.g. boto3://bucket1/foo.tar. data should be configured. 
+ :param mode: file mode + :param bufsize: buffer size + """ + import boto3 + + # boto3.set_stream_logger('botocore', level='DEBUG') + if url.startswith("boto3://"): + url = url.replace("boto3://", "") + need_meta = False + else: + url = url.replace("metaboto3://", "") + need_meta = True + endpoint_url = os.environ.get("S3_ENDPOINT_URL", None) + access_key = os.environ.get("S3_ACCESS_KEY_ID", None) + secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None) + + if mode[0] == "r": + s3_client = boto3.client( + "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key + ) + bucket, key = url.split("/", 1) + + if need_meta: + # download a meta json + meta_file_key = key.split(".")[0] + ".meta.jsonl" + meta_stream = io.BytesIO() + s3_client.download_fileobj(bucket, meta_file_key, meta_stream) + meta_stream.seek(0) + meta_stream.name = meta_file_key + else: + meta_stream = None + + # data tar stream + response = s3_client.get_object(Bucket=bucket, Key=key) # Range optional + response["Body"].name = key # actually not used + response["Body"].meta_stream = meta_stream + return response["Body"] + else: + raise ValueError(f"{mode}: unknown mode") + + +gopen_schemes["rclone"] = gopen_rclone +gopen_schemes["boto3"] = gopen_boto3 +gopen_schemes["metaboto3"] = gopen_boto3 diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/train_video.py b/PyTorch/contrib/cv/video/CogVideoX/sat/train_video.py new file mode 100644 index 0000000000000000000000000000000000000000..8b20c0a305f283a1f0c8db16a26c56c37137513d --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/train_video.py @@ -0,0 +1,235 @@ +import os +import argparse +from functools import partial +import numpy as np +import torch.distributed +from omegaconf import OmegaConf +import imageio + +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu + +from sat import mpu +from sat.training.deepspeed_training import training_main + +from sgm.util import get_obj_from_str, isheatmap + +from diffusion_video import SATVideoDiffusionEngine +from arguments import get_args + +from einops import rearrange + +try: + import wandb +except ImportError: + print("warning: wandb not installed") + + +def print_debug(args, s): + if args.debug: + s = f"RANK:[{torch.distributed.get_rank()}]:" + s + print(s) + + +def save_texts(texts, save_dir, iterations): + output_path = os.path.join(save_dir, f"{str(iterations).zfill(8)}") + with open(output_path, "w", encoding="utf-8") as f: + for text in texts: + f.write(text + "\n") + + +def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None): + os.makedirs(save_path, exist_ok=True) + + for i, vid in enumerate(video_batch): + gif_frames = [] + for frame in vid: + frame = rearrange(frame, "c h w -> h w c") + frame = (255.0 * frame).cpu().numpy().astype(np.uint8) + gif_frames.append(frame) + now_save_path = os.path.join(save_path, f"{i:06d}.mp4") + with imageio.get_writer(now_save_path, fps=fps) as writer: + for frame in gif_frames: + writer.append_data(frame) + if args is not None and args.wandb: + wandb.log( + {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1 + ) + + +def log_video(batch, model, args, only_log_video_latents=False): + texts = batch["txt"] + text_save_dir = os.path.join(args.save, "video_texts") + os.makedirs(text_save_dir, exist_ok=True) + save_texts(texts, text_save_dir, args.iteration) + + gpu_autocast_kwargs = { + "enabled": 
torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs): + videos = model.log_video(batch, only_log_video_latents=only_log_video_latents) + + if torch.distributed.get_rank() == 0: + root = os.path.join(args.save, "video") + + if only_log_video_latents: + root = os.path.join(root, "latents") + filename = "{}_gs-{:06}".format("latents", args.iteration) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + os.makedirs(path, exist_ok=True) + torch.save(videos["latents"], os.path.join(path, "latent.pt")) + else: + for k in videos: + N = videos[k].shape[0] + if not isheatmap(videos[k]): + videos[k] = videos[k][:N] + if isinstance(videos[k], torch.Tensor): + videos[k] = videos[k].detach().float().cpu() + if not isheatmap(videos[k]): + videos[k] = torch.clamp(videos[k], -1.0, 1.0) + + num_frames = batch["num_frames"][0] + fps = batch["fps"][0].cpu().item() + if only_log_video_latents: + root = os.path.join(root, "latents") + filename = "{}_gs-{:06}".format("latents", args.iteration) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + os.makedirs(path, exist_ok=True) + torch.save(videos["latents"], os.path.join(path, "latents.pt")) + else: + for k in videos: + samples = (videos[k] + 1.0) / 2.0 + filename = "{}_gs-{:06}".format(k, args.iteration) + + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + save_video_as_grid_and_mp4(samples, path, num_frames // fps, fps, args, k) + + +def broad_cast_batch(batch): + mp_size = mpu.get_model_parallel_world_size() + global_rank = torch.distributed.get_rank() // mp_size + src = global_rank * mp_size + + if batch["mp4"] is not None: + broadcast_shape = [batch["mp4"].shape, batch["fps"].shape, batch["num_frames"].shape] + else: + broadcast_shape = None + + txt = [batch["txt"], broadcast_shape] + torch.distributed.broadcast_object_list(txt, src=src, group=mpu.get_model_parallel_group()) + batch["txt"] = txt[0] + + mp4_shape = txt[1][0] + fps_shape = txt[1][1] + num_frames_shape = txt[1][2] + + if mpu.get_model_parallel_rank() != 0: + batch["mp4"] = torch.zeros(mp4_shape, device="cuda") + batch["fps"] = torch.zeros(fps_shape, device="cuda", dtype=torch.long) + batch["num_frames"] = torch.zeros(num_frames_shape, device="cuda", dtype=torch.long) + + torch.distributed.broadcast(batch["mp4"], src=src, group=mpu.get_model_parallel_group()) + torch.distributed.broadcast(batch["fps"], src=src, group=mpu.get_model_parallel_group()) + torch.distributed.broadcast(batch["num_frames"], src=src, group=mpu.get_model_parallel_group()) + return batch + + +def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None): + if mpu.get_model_parallel_rank() == 0: + timers("data loader").start() + batch_video = next(data_iterator) + timers("data loader").stop() + + if len(batch_video["mp4"].shape) == 6: + b, v = batch_video["mp4"].shape[:2] + batch_video["mp4"] = batch_video["mp4"].view(-1, *batch_video["mp4"].shape[2:]) + txt = [] + for i in range(b): + for j in range(v): + txt.append(batch_video["txt"][j][i]) + batch_video["txt"] = txt + + for key in batch_video: + if isinstance(batch_video[key], torch.Tensor): + batch_video[key] = batch_video[key].cuda() + else: + batch_video = {"mp4": None, "fps": None, "num_frames": None, "txt": None} + broad_cast_batch(batch_video) + if 
mpu.get_data_parallel_rank() == 0: + log_video(batch_video, model, args, only_log_video_latents=only_log_video_latents) + + batch_video["global_step"] = args.iteration + loss, loss_dict = model.shared_step(batch_video) + for k in loss_dict: + if loss_dict[k].dtype == torch.bfloat16: + loss_dict[k] = loss_dict[k].to(torch.float32) + return loss, loss_dict + + +def forward_step(data_iterator, model, args, timers, data_class=None): + if mpu.get_model_parallel_rank() == 0: + timers("data loader").start() + batch = next(data_iterator) + timers("data loader").stop() + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].cuda() + + if torch.distributed.get_rank() == 0: + if not os.path.exists(os.path.join(args.save, "training_config.yaml")): + configs = [OmegaConf.load(cfg) for cfg in args.base] + config = OmegaConf.merge(*configs) + os.makedirs(args.save, exist_ok=True) + OmegaConf.save(config=config, f=os.path.join(args.save, "training_config.yaml")) + else: + batch = {"mp4": None, "fps": None, "num_frames": None, "txt": None} + + batch["global_step"] = args.iteration + + broad_cast_batch(batch) + + loss, loss_dict = model.shared_step(batch) + + return loss, loss_dict + + +if __name__ == "__main__": + torch.npu.set_compile_mode(jit_compile=False) + torch.npu.config.allow_internal_format=False + if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] + os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] + os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] + + py_parser = argparse.ArgumentParser(add_help=False) + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + data_class = get_obj_from_str(args.data_config["target"]) + create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"]) + + import yaml + + configs = [] + for config in args.base: + with open(config, "r") as f: + base_config = yaml.safe_load(f) + configs.append(base_config) + args.log_config = configs + + training_main( + args, + model_cls=SATVideoDiffusionEngine, + forward_step_function=partial(forward_step, data_class=data_class), + forward_step_eval=partial( + forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents + ), + create_dataset_function=create_dataset_function, + ) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/attention.py b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..041df7700a4abe166089c7fbe52ca8e261badc55 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/attention.py @@ -0,0 +1,572 @@ +import math +from inspect import isfunction +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from packaging import version +from torch import nn + +if version.parse(torch.__version__) >= version.parse("2.0.0"): + SDP_IS_AVAILABLE = True + from torch.backends.cuda import SDPBackend, sdp_kernel + + BACKEND_MAP = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + }, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + None: {"enable_math": 
True, "enable_flash": True, "enable_mem_efficient": True}, + } +else: + from contextlib import nullcontext + + SDP_IS_AVAILABLE = False + sdp_kernel = nullcontext + BACKEND_MAP = {} + print( + f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, " + f"you are using PyTorch {torch.__version__}. You might want to consider upgrading." + ) + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + print("no module 'xformers'. Processing without...") + +from modules.utils import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + 
v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.backend = backend + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + h = self.heads + + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + n_cp = x.shape[0] // n_times_crossframe_attn_in_self + k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + # old + """ + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + """ + # new + with sdp_kernel(**BACKEND_MAP[self.backend]): + # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default + + del q, k, v + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs): + super().__init__() + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads with a dimension of {dim_head}." 
+ ) + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + # n_cp = x.shape[0]//n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, # ampere + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: + print( + f"Attention mode '{attn_mode}' is not available. Falling back to native attention. " + f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" + ) + attn_mode = "softmax" + elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: + print("We do not support vanilla attention anymore, as it is too expensive. Sorry.") + if not XFORMERS_IS_AVAILABLE: + assert False, "Please install xformers via e.g. 
'pip install xformers==0.0.16'" + else: + print("Falling back to xformers efficient attention.") + attn_mode = "softmax-xformers" + attn_cls = self.ATTENTION_MODES[attn_mode] + if version.parse(torch.__version__) >= version.parse("2.0.0"): + assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) + else: + assert sdp_backend is None + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + backend=sdp_backend, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + kwargs = {"x": x} + + if context is not None: + kwargs.update({"context": context}) + + if additional_tokens is not None: + kwargs.update({"additional_tokens": additional_tokens}) + + if n_times_crossframe_attn_in_self: + kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}) + + # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, + ) + + x + ) + x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerSingleLayerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version + # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context) + x + x = self.ff(self.norm2(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, + ): + super().__init__() + print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") + from omegaconf import ListConfig + + if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): + context_dim = [context_dim] + if exists(context_dim) and isinstance(context_dim, list): + if depth != len(context_dim): + print( + f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " + f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now." + ) + # depth does not match context dims. + assert all( + map(lambda x: x == context_dim[0], context_dim) + ), "need homogenous context_dim to match depth automatically" + context_dim = depth * [context_dim[0]] + elif context_dim is None: + context_dim = [None] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + if i > 0 and len(context) == 1: + i = 0 # use same context for each block + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/autoencoder.py b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a1305ace27aa2610b8690da5d0d56879ad7f1eb6 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/autoencoder.py @@ -0,0 +1,651 @@ +import logging +import math +import re +import random +from abc import abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.distributed +import torch.nn as nn +from einops import rearrange +from packaging import version + 
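+# Local helpers used below: LitEma maintains exponential-moving-average copies of the
+# autoencoder weights, and _conv_split / _conv_gather shard and reassemble tensors along
+# the frame dimension (dim 2 of NCTHW video batches) when the VAE runs with context
+# parallelism.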
+from vae_modules.ema import LitEma +from sgm.util import ( + instantiate_from_config, + get_obj_from_str, + default, + is_context_parallel_initialized, + initialize_context_parallel, + get_context_parallel_group, + get_context_parallel_group_rank, +) +from vae_modules.cp_enc_dec import _conv_split, _conv_gather + +logpy = logging.getLogger(__name__) + + +class AbstractAutoencoder(pl.LightningModule): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. + """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None + if monitor is not None: + self.monitor = monitor + + if self.use_ema: + self.model_ema = LitEma(self, decay=ema_decay) + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if version.parse(torch.__version__) >= version.parse("2.0.0"): + self.automatic_optimization = False + + # def apply_ckpt(self, ckpt: Union[None, str, dict]): + # if ckpt is None: + # return + # if isinstance(ckpt, str): + # ckpt = { + # "target": "sgm.modules.checkpoint.CheckpointEngine", + # "params": {"ckpt_path": ckpt}, + # } + # engine = instantiate_from_config(ckpt) + # engine(self) + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + self.init_from_ckpt(ckpt) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) + print("Missing keys: ", missing_keys) + print("Unexpected keys: ", unexpected_keys) + print(f"Restored from {path}") + + @abstractmethod + def get_input(self, batch) -> Any: + raise NotImplementedError() + + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + logpy.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + logpy.info(f"{context}: Restored training weights") + + @abstractmethod + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") + + @abstractmethod + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) + + def configure_optimizers(self) -> Any: + raise NotImplementedError() + + +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). 
+ Regularizations such as KL or VQ are moved to the regularizer class. + """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + additional_decode_keys: Optional[List[str]] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.automatic_optimization = False # pytorch lightning + + self.encoder = instantiate_from_config(encoder_config) + self.decoder = instantiate_from_config(decoder_config) + self.loss = instantiate_from_config(loss_config) + self.regularization = instantiate_from_config(regularizer_config) + self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"}) + self.diff_boost_factor = diff_boost_factor + self.disc_start_iter = disc_start_iter + self.lr_g_factor = lr_g_factor + self.trainable_ae_params = trainable_ae_params + if self.trainable_ae_params is not None: + self.ae_optimizer_args = default( + ae_optimizer_args, + [{} for _ in range(len(self.trainable_ae_params))], + ) + assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) + else: + self.ae_optimizer_args = [{}] # makes type consitent + + self.trainable_disc_params = trainable_disc_params + if self.trainable_disc_params is not None: + self.disc_optimizer_args = default( + disc_optimizer_args, + [{} for _ in range(len(self.trainable_disc_params))], + ) + assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) + else: + self.disc_optimizer_args = [{}] # makes type consitent + + if ckpt_path is not None: + assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" + logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + self.additional_decode_keys = set(default(additional_decode_keys, [])) + + def get_input(self, batch: Dict) -> torch.Tensor: + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 
1 and in channels-first + # format (e.g., bchw instead if bhwc) + return batch[self.input_key] + + def get_autoencoder_params(self) -> list: + params = [] + if hasattr(self.loss, "get_trainable_autoencoder_parameters"): + params += list(self.loss.get_trainable_autoencoder_parameters()) + if hasattr(self.regularization, "get_trainable_parameters"): + params += list(self.regularization.get_trainable_parameters()) + params = params + list(self.encoder.parameters()) + params = params + list(self.decoder.parameters()) + return params + + def get_discriminator_params(self) -> list: + if hasattr(self.loss, "get_trainable_parameters"): + params = list(self.loss.get_trainable_parameters()) # e.g., discriminator + else: + params = [] + return params + + def get_last_layer(self): + return self.decoder.get_last_layer() + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) + return x + + def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log + + def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: + x = self.get_input(batch) + additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + z, xrec, regularization_log = self(x, **additional_decode_kwargs) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": optimizer_idx, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "train", + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + + if optimizer_idx == 0: + # autoencode + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {"train/loss/rec": aeloss.detach()} + + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True, + sync_dist=False, + ) + self.log( + "loss", + aeloss.mean().detach(), + prog_bar=True, + logger=False, + on_epoch=False, + on_step=True, + ) + return aeloss + elif optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + # -> discriminator always needs to return a tuple + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + else: + raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") + + def training_step(self, batch: dict, batch_idx: int): + opts = self.optimizers() + if not isinstance(opts, list): + # Non-adversarial case + opts = [opts] + optimizer_idx = batch_idx % len(opts) + if self.global_step < self.disc_start_iter: + optimizer_idx = 0 + opt = opts[optimizer_idx] + opt.zero_grad() + with opt.toggle_model(): + loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx) + self.manual_backward(loss) + opt.step() + + def validation_step(self, batch: dict, batch_idx: int) -> Dict: + log_dict = 
self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + log_dict.update(log_dict_ema) + return log_dict + + def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict: + x = self.get_input(batch) + + z, xrec, regularization_log = self(x) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": 0, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "val" + postfix, + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()} + full_log_dict = log_dict_ae + + if "optimizer_idx" in extra_info: + extra_info["optimizer_idx"] = 1 + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + full_log_dict.update(log_dict_disc) + self.log( + f"val{postfix}/loss/rec", + log_dict_ae[f"val{postfix}/loss/rec"], + sync_dist=True, + ) + self.log_dict(full_log_dict, sync_dist=True) + return full_log_dict + + def get_param_groups( + self, parameter_names: List[List[str]], optimizer_args: List[dict] + ) -> Tuple[List[Dict[str, Any]], int]: + groups = [] + num_params = 0 + for names, args in zip(parameter_names, optimizer_args): + params = [] + for pattern_ in names: + pattern_params = [] + pattern = re.compile(pattern_) + for p_name, param in self.named_parameters(): + if re.match(pattern, p_name): + pattern_params.append(param) + num_params += param.numel() + if len(pattern_params) == 0: + logpy.warn(f"Did not find parameters for pattern {pattern_}") + params.extend(pattern_params) + groups.append({"params": params, **args}) + return groups, num_params + + def configure_optimizers(self) -> List[torch.optim.Optimizer]: + if self.trainable_ae_params is None: + ae_params = self.get_autoencoder_params() + else: + ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args) + logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") + if self.trainable_disc_params is None: + disc_params = self.get_discriminator_params() + else: + disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args) + logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}") + opt_ae = self.instantiate_optimizer_from_config( + ae_params, + default(self.lr_g_factor, 1.0) * self.learning_rate, + self.optimizer_config, + ) + opts = [opt_ae] + if len(disc_params) > 0: + opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config) + opts.append(opt_disc) + + return opts + + @torch.no_grad() + def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + log = dict() + additional_decode_kwargs = {} + x = self.get_input(batch) + additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)}) + + _, xrec, _ = self(x, **additional_decode_kwargs) + log["inputs"] = x + log["reconstructions"] = xrec + diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x) + diff.clamp_(0, 1.0) + log["diff"] = 2.0 * diff - 1.0 + # diff_boost shows location of small errors, by boosting their + # brightness. 
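+        # diff is in [0, 1]; multiplying by diff_boost_factor amplifies small errors,
+        # the clamp caps the result at 1, and the final 2*x - 1 maps the image back to
+        # the [-1, 1] range used by the other logged tensors.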
+ log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 + if hasattr(self.loss, "log_images"): + log.update(self.loss.log_images(x, xrec)) + with self.ema_scope(): + _, xrec_ema, _ = self(x, **additional_decode_kwargs) + log["reconstructions_ema"] = xrec_ema + diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) + diff_ema.clamp_(0, 1.0) + log["diff_ema"] = 2.0 * diff_ema - 1.0 + log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + if additional_log_kwargs: + additional_decode_kwargs.update(additional_log_kwargs) + _, xrec_add, _ = self(x, **additional_decode_kwargs) + log_str = "reconstructions-" + "-".join( + [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs] + ) + log[log_str] = xrec_add + return log + + +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + ckpt_path = kwargs.pop("ckpt_path", None) + ckpt_engine = kwargs.pop("ckpt_engine", None) + super().__init__( + encoder_config={ + "target": "sgm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "sgm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params + + def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) + + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class AutoencoderKL(AutoencodingEngineLegacy): + def __init__(self, **kwargs): + if "lossconfig" in kwargs: + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")}, + **kwargs, + ) + + +class IdentityFirstStage(AbstractAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_input(self, x: Any) -> Any: + return x + + def encode(self, x: Any, *args, **kwargs) -> Any: + return x + + def decode(self, x: Any, *args, **kwargs) -> Any: + return x + + +class VideoAutoencodingEngine(AutoencodingEngine): + def __init__( + self, + 
ckpt_path: Union[None, str] = None, + ignore_keys: Union[Tuple, list] = (), + image_video_weights=[1, 1], + only_train_decoder=False, + context_parallel_size=0, + **kwargs, + ): + super().__init__(**kwargs) + self.context_parallel_size = context_parallel_size + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + return self.log_images(batch, additional_log_kwargs, **kwargs) + + def get_input(self, batch: dict) -> torch.Tensor: + if self.context_parallel_size > 0: + if not is_context_parallel_initialized(): + initialize_context_parallel(self.context_parallel_size) + + batch = batch[self.input_key] + + global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size + torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group()) + + batch = _conv_split(batch, dim=2, kernel_size=1) + return batch + + return batch[self.input_key] + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + self.init_from_ckpt(ckpt) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) + print("Missing keys: ", missing_keys) + print("Unexpected keys: ", unexpected_keys) + print(f"Restored from {path}") + + +class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): + def __init__( + self, + cp_size=0, + *args, + **kwargs, + ): + self.cp_size = cp_size + super().__init__(*args, **kwargs) + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + input_cp: bool = False, + output_cp: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.cp_size > 0 and not input_cp: + if not is_context_parallel_initialized: + initialize_context_parallel(self.cp_size) + + global_src_rank = get_context_parallel_group_rank() * self.cp_size + torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group()) + + x = _conv_split(x, dim=2, kernel_size=1) + + if return_reg_log: + z, reg_log = super().encode(x, return_reg_log, unregularized) + else: + z = super().encode(x, return_reg_log, unregularized) + + if self.cp_size > 0 and not output_cp: + z = _conv_gather(z, dim=2, kernel_size=1) + + if return_reg_log: + return z, reg_log + return z + + def decode( + self, + z: torch.Tensor, + input_cp: bool = False, + output_cp: bool = False, + split_kernel_size: int = 1, + **kwargs, + ): + if self.cp_size > 0 and not input_cp: + if not is_context_parallel_initialized: + initialize_context_parallel(self.cp_size) + + global_src_rank = get_context_parallel_group_rank() * self.cp_size + torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group()) + + z = _conv_split(z, dim=2, kernel_size=split_kernel_size) + + x = super().decode(z, **kwargs) + + if self.cp_size > 0 and not output_cp: + x = _conv_gather(x, dim=2, kernel_size=split_kernel_size) + + return x + + def forward( + self, + x: torch.Tensor, + input_cp: bool = False, + latent_cp: bool = False, + output_cp: bool = False, + **additional_decode_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, 
output_cp=latent_cp) + dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs) + return z, dec, reg_log diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/cp_enc_dec.py b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/cp_enc_dec.py new file mode 100644 index 0000000000000000000000000000000000000000..d50720df052e5095ddc872da84017fc29f661db7 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/cp_enc_dec.py @@ -0,0 +1,987 @@ +import math +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from beartype import beartype +from beartype.typing import Union, Tuple, Optional, List +from einops import rearrange + +from sgm.util import ( + get_context_parallel_group, + get_context_parallel_rank, + get_context_parallel_world_size, + get_context_parallel_group_rank, +) + +# try: +from vae_modules.utils import SafeConv3d as Conv3d +# except: +# # Degrade to normal Conv3d if SafeConv3d is not available +# from torch.nn import Conv3d + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def exists(v): + return v is not None + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def leaky_relu(p=0.1): + return nn.LeakyReLU(p) + + +def _split(input_, dim): + cp_world_size = get_context_parallel_world_size() + + if cp_world_size == 1: + return input_ + + cp_rank = get_context_parallel_rank() + + # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) + + inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + dim_size = input_.size()[dim] // cp_world_size + + input_list = torch.split(input_, dim_size, dim=dim) + output = input_list[cp_rank] + + if cp_rank == 0: + output = torch.cat([inpu_first_frame_, output], dim=dim) + output = output.contiguous() + + # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) + + return output + + +def _gather(input_, dim): + cp_world_size = get_context_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + + # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) + + input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + + tensor_list = 
[torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] + + if cp_rank == 0: + input_ = torch.cat([input_first_frame_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + output = torch.cat(tensor_list, dim=dim).contiguous() + + # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape) + + return output + + +def _conv_split(input_, dim, kernel_size): + cp_world_size = get_context_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) + + cp_rank = get_context_parallel_rank() + + dim_size = (input_.size()[dim] - kernel_size) // cp_world_size + + if cp_rank == 0: + output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) + else: + # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) + output = input_.transpose(dim, 0)[ + cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size + ].transpose(dim, 0) + output = output.contiguous() + + # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) + + return output + + +def _conv_gather(input_, dim, kernel_size): + cp_world_size = get_context_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + + # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) + + input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() + else: + input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous() + + tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] + if cp_rank == 0: + input_ = torch.cat([input_first_kernel_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat already creates a contiguous tensor. 
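+    # Shape bookkeeping, with illustrative values (cp_world_size=2, kernel_size=1,
+    # 17 frames along dim, chosen only for this example): _conv_split leaves 9 frames
+    # on rank 0 and 8 on rank 1; after the all_gather above, the concatenation below
+    # restores all 17 frames on every rank.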
+ output = torch.cat(tensor_list, dim=dim).contiguous() + + # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) + + return output + + +def _pass_from_previous_rank(input_, dim, kernel_size): + # Bypass the function if kernel size is 1 + if kernel_size == 1: + return input_ + + group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + cp_group_rank = get_context_parallel_group_rank() + cp_world_size = get_context_parallel_world_size() + + # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + global_rank = torch.distributed.get_rank() + global_world_size = torch.distributed.get_world_size() + + input_ = input_.transpose(0, dim) + + # pass from last rank + send_rank = global_rank + 1 + recv_rank = global_rank - 1 + if send_rank % cp_world_size == 0: + send_rank -= cp_world_size + if recv_rank % cp_world_size == cp_world_size - 1: + recv_rank += cp_world_size + + if cp_rank < cp_world_size - 1: + req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + if cp_rank > 0: + recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() + req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + + if cp_rank == 0: + input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + else: + req_recv.wait() + input_ = torch.cat([recv_buffer, input_], dim=0) + + input_ = input_.transpose(0, dim).contiguous() + + # print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + return input_ + + +def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None): + # Bypass the function if kernel size is 1 + if kernel_size == 1: + return input_ + + group = get_context_parallel_group() + cp_rank = get_context_parallel_rank() + cp_group_rank = get_context_parallel_group_rank() + cp_world_size = get_context_parallel_world_size() + + # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + global_rank = torch.distributed.get_rank() + global_world_size = torch.distributed.get_world_size() + + input_ = input_.transpose(0, dim) + + # pass from last rank + send_rank = global_rank + 1 + recv_rank = global_rank - 1 + if send_rank % cp_world_size == 0: + send_rank -= cp_world_size + if recv_rank % cp_world_size == cp_world_size - 1: + recv_rank += cp_world_size + + # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() + # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group) + # req_recv.wait() + recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() + if cp_rank < cp_world_size - 1: + req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + if cp_rank > 0: + req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + + if cp_rank == 0: + if cache_padding is not None: + input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0) + else: + input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + else: + req_recv.wait() + input_ = torch.cat([recv_buffer, input_], dim=0) + + input_ = input_.transpose(0, dim).contiguous() + return input_ + + +def 
_drop_from_previous_rank(input_, dim, kernel_size): + input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim) + return input_ + + +class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _conv_split(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _conv_gather(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _ConvolutionPassFromPreviousRank(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _pass_from_previous_rank(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size, cache_padding): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding) + + @staticmethod + def backward(ctx, grad_output): + return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None + + +def conv_scatter_to_context_parallel_region(input_, dim, kernel_size): + return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size) + + +def conv_gather_from_context_parallel_region(input_, dim, kernel_size): + return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size) + + +def conv_pass_from_last_rank(input_, dim, kernel_size): + return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size) + + +def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding): + return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding) + + +class ContextParallelCausalConv3d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + time_pad = time_kernel_size - 1 + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_kernel_size = time_kernel_size + self.temporal_dim = 2 + + stride = (stride, stride, stride) + dilation = (1, 1, 1) + self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.cache_padding = None + + def forward(self, input_, clear_cache=True): + # if input_.shape[2] == 1: # handle image + # # first frame padding + # input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2) + # else: + # input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size) + + # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + # input_parallel = F.pad(input_parallel, 
padding_2d, mode = 'constant', value = 0)
+
+        # output_parallel = self.conv(input_parallel)
+        # output = output_parallel
+        # return output
+
+        input_parallel = fake_cp_pass_from_previous_rank(
+            input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
+        )
+
+        del self.cache_padding
+        self.cache_padding = None
+        if not clear_cache:
+            cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
+            global_rank = torch.distributed.get_rank()
+            if cp_world_size == 1:
+                self.cache_padding = (
+                    input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
+                )
+            else:
+                if cp_rank == cp_world_size - 1:
+                    torch.distributed.isend(
+                        input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(),
+                        global_rank + 1 - cp_world_size,
+                        group=get_context_parallel_group(),
+                    )
+                if cp_rank == 0:
+                    recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
+                    torch.distributed.recv(
+                        recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
+                    )
+                    self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
+
+        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+        input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
+
+        output_parallel = self.conv(input_parallel)
+        output = output_parallel
+        return output
+
+
+class ContextParallelGroupNorm(torch.nn.GroupNorm):
+    def forward(self, input_):
+        gather_flag = input_.shape[2] > 1
+        if gather_flag:
+            input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
+        output = super().forward(input_)
+        if gather_flag:
+            output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
+        return output
+
+
+def Normalize(in_channels, gather=False, **kwargs):  # same for 3D and 2D
+    if gather:
+        return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+    else:
+        return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialNorm3D(nn.Module):
+    def __init__(
+        self,
+        f_channels,
+        zq_channels,
+        freeze_norm_layer=False,
+        add_conv=False,
+        pad_mode="constant",
+        gather=False,
+        **norm_layer_params,
+    ):
+        super().__init__()
+        if gather:
+            self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
+        else:
+            self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
+        # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
+        if freeze_norm_layer:
+            for p in self.norm_layer.parameters():
+                p.requires_grad = False
+
+        self.add_conv = add_conv
+        if add_conv:
+            self.conv = ContextParallelCausalConv3d(
+                chan_in=zq_channels,
+                chan_out=zq_channels,
+                kernel_size=3,
+            )
+
+        self.conv_y = ContextParallelCausalConv3d(
+            chan_in=zq_channels,
+            chan_out=f_channels,
+            kernel_size=1,
+        )
+        self.conv_b = ContextParallelCausalConv3d(
+            chan_in=zq_channels,
+            chan_out=f_channels,
+            kernel_size=1,
+        )
+
+    def forward(self, f, zq, clear_fake_cp_cache=True):
+        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+            zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
+            zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
+            zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+            zq = torch.cat([zq_first, zq_rest], dim=2)
+        else:
+            zq =
torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") + + if self.add_conv: + zq = self.conv(zq, clear_cache=clear_fake_cp_cache) + + # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1) + norm_f = self.norm_layer(f) + # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1) + + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +def Normalize3D( + in_channels, + zq_ch, + add_conv, + gather=False, +): + return SpatialNorm3D( + in_channels, + zq_ch, + gather=gather, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, + affine=True, + ) + + +class Upsample3D(nn.Module): + def __init__( + self, + in_channels, + with_conv, + compress_time=False, + ): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time and x.shape[2] > 1: + if x.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") + x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + else: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + + else: + # only interpolate 2D + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.with_conv: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class DownSample3D(nn.Module): + def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None): + super().__init__() + self.with_conv = with_conv + if out_channels is None: + out_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time and x.shape[2] > 1: + h, w = x.shape[-2:] + x = rearrange(x, "b c t h w -> (b h w) c t") + + if x.shape[-1] % 2 == 1: + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] + + if x_rest.shape[-1] > 0: + x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + else: + x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + else: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class ContextParallelResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + add_conv=False, + gather_norm=False, + normalization=Normalize, + ): + 
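+        # Residual block built from causal temporal 3D convolutions. `normalization`
+        # defaults to plain GroupNorm (Normalize); the decoder passes Normalize3D
+        # together with zq_ch so features are additionally modulated by the latent zq.
+        # When in/out channel counts differ, either a causal 3x3x3 conv or a 1x1x1
+        # Conv3d (nin_shortcut) aligns the residual path.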
super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = normalization( + in_channels, + zq_ch=zq_ch, + add_conv=add_conv, + gather=gather_norm, + ) + + self.conv1 = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=out_channels, + kernel_size=3, + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = normalization( + out_channels, + zq_ch=zq_ch, + add_conv=add_conv, + gather=gather_norm, + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = ContextParallelCausalConv3d( + chan_in=out_channels, + chan_out=out_channels, + kernel_size=3, + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=out_channels, + kernel_size=3, + ) + else: + self.nin_shortcut = Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, x, temb, zq=None, clear_fake_cp_cache=True): + h = x + + # if isinstance(self.norm1, torch.nn.GroupNorm): + # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) + if zq is not None: + h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + else: + h = self.norm1(h) + # if isinstance(self.norm1, torch.nn.GroupNorm): + # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) + + h = nonlinearity(h) + h = self.conv1(h, clear_cache=clear_fake_cp_cache) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] + + # if isinstance(self.norm2, torch.nn.GroupNorm): + # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) + if zq is not None: + h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + else: + h = self.norm2(h) + # if isinstance(self.norm2, torch.nn.GroupNorm): + # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) + + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h, clear_cache=clear_fake_cp_cache) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x, clear_cache=clear_fake_cp_cache) + else: + x = self.nin_shortcut(x) + + return x + h + + +class ContextParallelEncoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + self.conv_in = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=self.ch, + kernel_size=3, + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + 
temb_channels=self.temb_ch, + gather_norm=gather_norm, + ) + ) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level < self.temporal_compress_level: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) + else: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + gather_norm=gather_norm, + ) + + self.mid.block_2 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + gather_norm=gather_norm, + ) + + # end + self.norm_out = Normalize(block_in, gather=gather_norm) + + self.conv_out = ContextParallelCausalConv3d( + chan_in=block_in, + chan_out=2 * z_channels if double_z else z_channels, + kernel_size=3, + ) + + def forward(self, x, **kwargs): + # timestep embedding + temb = None + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.block_2(h, temb) + + # end + # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) + h = self.norm_out(h) + # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) + + h = nonlinearity(h) + h = self.conv_out(h) + + return h + + +class ContextParallelDecoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + self.conv_in = ContextParallelCausalConv3d( + chan_in=z_channels, + chan_out=block_in, + kernel_size=3, + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + + self.mid.block_2 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + 
gather_norm=gather_norm, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + ) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) + self.up.insert(0, up) + + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm) + + self.conv_out = ContextParallelCausalConv3d( + chan_in=block_in, + chan_out=out_ch, + kernel_size=3, + ) + + def forward(self, z, clear_fake_cp_cache=True, **kwargs): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + t = z.shape[2] + # z to block_in + + zq = z + h = self.conv_in(z, clear_cache=clear_fake_cp_cache) + + # middle + h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = nonlinearity(h) + h = self.conv_out(h, clear_cache=clear_fake_cp_cache) + + return h + + def get_last_layer(self): + return self.conv_out.conv.weight diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/ema.py b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1f7606c2c9b68ebd2302215a9e08f9f31ed8ab --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with 
torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/utils.py b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8dba626f250a69c8c28e67f0c7c1c822bc6bc2 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/sat/vae_modules/utils.py @@ -0,0 +1,404 @@ +import functools +import importlib +import os +from functools import partial +from inspect import isfunction + +import fsspec +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file as load_safetensors +import torch.distributed + +_CONTEXT_PARALLEL_GROUP = None +_CONTEXT_PARALLEL_SIZE = None + + +def is_context_parallel_initialized(): + if _CONTEXT_PARALLEL_GROUP is None: + return False + else: + return True + + +def initialize_context_parallel(context_parallel_size): + global _CONTEXT_PARALLEL_GROUP + global _CONTEXT_PARALLEL_SIZE + + assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized" + _CONTEXT_PARALLEL_SIZE = context_parallel_size + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + for i in range(0, world_size, context_parallel_size): + ranks = range(i, i + context_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _CONTEXT_PARALLEL_GROUP = group + break + + +def get_context_parallel_group(): + assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized" + + return _CONTEXT_PARALLEL_GROUP + + +def get_context_parallel_world_size(): + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" + + return _CONTEXT_PARALLEL_SIZE + + +def get_context_parallel_rank(): + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized" + + rank = torch.distributed.get_rank() + cp_rank = rank % _CONTEXT_PARALLEL_SIZE + return cp_rank + + +def get_context_parallel_group_rank(): + assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel 
size is not initialized" + + rank = torch.distributed.get_rank() + cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE + + return cp_group_rank + + +class SafeConv3d(torch.nn.Conv3d): + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super(SafeConv3d, self).forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super(SafeConv3d, self).forward(input) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def get_string_from_tuple(s): + try: + # Check if the string starts and ends with parentheses + if s[0] == "(" and s[-1] == ")": + # Convert the string to a tuple + t = eval(s) + # Check if the type of t is tuple + if type(t) == tuple: + return t[0] + else: + pass + except: + pass + return s + + +def is_power_of_two(n): + """ + chat.openai.com/chat + Return True if n is a power of 2, otherwise return False. + + The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. + The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. + If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. + Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. + + """ + if n <= 0: + return False + return (n & (n - 1)) == 0 + + +def autocast(f, enabled=True): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast( + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def load_partial_from_config(config): + return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + if isinstance(xc[bi], list): + text_seq = xc[bi][0] + else: + text_seq = xc[bi] + lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. 
Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def isheatmap(x): + if not isinstance(x, torch.Tensor): + return False + + return x.ndim == 2 + + +def isneighbors(x): + if not isinstance(x, torch.Tensor): + return False + return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) + + +def exists(x): + return x is not None + + +def expand_dims_like(x, y): + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def load_model_from_config(config, ckpt, verbose=True, freeze=True): + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + model = instantiate_from_config(config.model) + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + model.eval() + return model + + +def get_configs_path() -> str: + """ + Get the `configs` directory. 
+ For a working copy, this is the one in the root of the repository, + but for an installed copy, it's in the `sgm` package (see pyproject.toml). + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "configs"), + os.path.join(this_dir, "..", "configs"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM configs in {candidates}") + + +def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): + """ + Will return the result of a recursive get attribute call. + E.g.: + a.b.c + = getattr(getattr(a, "b"), "c") + = get_nested_attribute(a, "b.c") + If any part of the attribute call is an integer x with current obj a, will + try to call a[x] instead of a.x first. + """ + attributes = attribute_path.split(".") + if depth is not None and depth > 0: + attributes = attributes[:depth] + assert len(attributes) > 0, "At least one attribute should be selected" + current_attribute = obj + current_key = None + for level, attribute in enumerate(attributes): + current_key = ".".join(attributes[: level + 1]) + try: + id_ = int(attribute) + current_attribute = current_attribute[id_] + except ValueError: + current_attribute = getattr(current_attribute, attribute) + + return (current_attribute, current_key) if return_key else current_attribute + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/convert_weight_sat2hf.py b/PyTorch/contrib/cv/video/CogVideoX/tools/convert_weight_sat2hf.py new file mode 100644 index 0000000000000000000000000000000000000000..183be6257b5d8a5a821c96d9653dce44e1628db2 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/convert_weight_sat2hf.py @@ -0,0 +1,304 @@ +""" +This script demonstrates how to convert and generate video from a text prompt +using CogVideoX with 🤗Huggingface Diffusers Pipeline. +This script requires the `diffusers>=0.30.2` library to be installed. + +Functions: + - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place. + - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place. + - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place. + - remove_keys_inplace: Removes specified keys from the state_dict in-place. + - replace_up_keys_inplace: Replaces keys in the "up" block in-place. + - get_state_dict: Extracts the state_dict from a saved checkpoint. + - update_state_dict_inplace: Updates the state_dict with new key assignments in-place. + - convert_transformer: Converts a transformer checkpoint to the CogVideoX format. + - convert_vae: Converts a VAE checkpoint to the CogVideoX format. + - get_args: Parses command-line arguments for the script. + - generate_video: Generates a video from a text prompt using the CogVideoX pipeline. 
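+
+Example invocation (a sketch only: checkpoint and output paths are placeholders, and
+the numeric values are the CogVideoX-2B settings noted in get_args; for the 5B model
+use --num_layers 42 --num_attention_heads 48 --use_rotary_positional_embeddings
+--scaling_factor 0.7 --snr_shift_scale 1.0 and --bf16 instead of --fp16):
+
+    python tools/convert_weight_sat2hf.py \
+        --transformer_ckpt_path /path/to/sat/transformer.pt \
+        --vae_ckpt_path /path/to/sat/vae.pt \
+        --output_path ./cogvideox-2b-diffusers \
+        --num_layers 30 \
+        --num_attention_heads 30 \
+        --scaling_factor 1.15258426 \
+        --snr_shift_scale 3.0 \
+        --fp16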
+""" + +import argparse +from typing import Any, Dict + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDDIMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) + + +def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]): + to_q_key = key.replace("query_key_value", "to_q") + to_k_key = key.replace("query_key_value", "to_k") + to_v_key = key.replace("query_key_value", "to_v") + to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0) + state_dict[to_q_key] = to_q + state_dict[to_k_key] = to_k + state_dict[to_v_key] = to_v + state_dict.pop(key) + + +def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]): + layer_id, weight_or_bias = key.split(".")[-2:] + + if "query" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}" + elif "key" in key: + new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}" + + state_dict[new_key] = state_dict.pop(key) + + +def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]): + layer_id, _, weight_or_bias = key.split(".")[-3:] + + weights_or_biases = state_dict[key].chunk(12, dim=0) + norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9]) + norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12]) + + norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}" + state_dict[norm1_key] = norm1_weights_or_biases + + norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}" + state_dict[norm2_key] = norm2_weights_or_biases + + state_dict.pop(key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]): + state_dict.pop(key) + + +def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): + key_split = key.split(".") + layer_index = int(key_split[2]) + replace_layer_index = 4 - 1 - layer_index + + key_split[1] = "up_blocks" + key_split[2] = str(replace_layer_index) + new_key = ".".join(key_split) + + state_dict[new_key] = state_dict.pop(key) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "transformer.final_layernorm": "norm_final", + "transformer": "transformer_blocks", + "attention": "attn1", + "mlp": "ff.net", + "dense_h_to_4h": "0.proj", + "dense_4h_to_h": "2", + ".layers": "", + "dense": "to_out.0", + "input_layernorm": "norm1.norm", + "post_attn1_layernorm": "norm2.norm", + "time_embed.0": "time_embedding.linear_1", + "time_embed.2": "time_embedding.linear_2", + "mixins.patch_embed": "patch_embed", + "mixins.final_layer.norm_final": "norm_out.norm", + "mixins.final_layer.linear": "proj_out", + "mixins.final_layer.adaLN_modulation.1": "norm_out.linear", + "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "query_key_value": reassign_query_key_value_inplace, + "query_layernorm_list": reassign_query_key_layernorm_inplace, + "key_layernorm_list": reassign_query_key_layernorm_inplace, + "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, + "embed_tokens": remove_keys_inplace, + "freqs_sin": remove_keys_inplace, + "freqs_cos": remove_keys_inplace, + "position_embedding": remove_keys_inplace, +} + +VAE_KEYS_RENAME_DICT = { + "block.": "resnets.", + "down.": "down_blocks.", + "downsample": "downsamplers.0", + "upsample": "upsamplers.0", + "nin_shortcut": "conv_shortcut", + "encoder.mid.block_1": "encoder.mid_block.resnets.0", + 
"encoder.mid.block_2": "encoder.mid_block.resnets.1", + "decoder.mid.block_1": "decoder.mid_block.resnets.0", + "decoder.mid.block_2": "decoder.mid_block.resnets.1", +} + +VAE_SPECIAL_KEYS_REMAP = { + "loss": remove_keys_inplace, + "up.": replace_up_keys_inplace, +} + +TOKENIZER_MAX_LENGTH = 226 + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def convert_transformer( + ckpt_path: str, + num_layers: int, + num_attention_heads: int, + use_rotary_positional_embeddings: bool, + i2v: bool, + dtype: torch.dtype, +): + PREFIX_KEY = "model.diffusion_model." + + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + transformer = CogVideoXTransformer3DModel( + in_channels=32 if i2v else 16, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + use_rotary_positional_embeddings=use_rotary_positional_embeddings, + use_learned_positional_embeddings=i2v, + ).to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[len(PREFIX_KEY) :] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + transformer.load_state_dict(original_state_dict, strict=True) + return transformer + + +def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint") + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16") + parser.add_argument( + "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" + ) + parser.add_argument( + "--text_encoder_cache_dir", 
type=str, default=None, help="Path to text encoder cache directory" + ) + # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 + parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") + # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 + parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads") + # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True + parser.add_argument( + "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not" + ) + # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7 + parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") + # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 + parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") + parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + transformer = None + vae = None + + if args.fp16 and args.bf16: + raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.") + + dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer( + args.transformer_ckpt_path, + args.num_layers, + args.num_attention_heads, + args.use_rotary_positional_embeddings, + args.i2v, + dtype, + ) + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) + + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + # Apparently, the conversion does not work anymore without this :shrug: + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + scheduler = CogVideoXDDIMScheduler.from_config( + { + "snr_shift_scale": args.snr_shift_scale, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": False, + "num_train_timesteps": 1000, + "prediction_type": "v_prediction", + "rescale_betas_zero_snr": True, + "set_alpha_to_one": True, + "timestep_spacing": "trailing", + } + ) + if args.i2v: + pipeline_cls = CogVideoXImageToVideoPipeline + else: + pipeline_cls = CogVideoXPipeline + + pipe = pipeline_cls( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + + if args.fp16: + pipe = pipe.to(dtype=torch.float16) + if args.bf16: + pipe = pipe.to(dtype=torch.bfloat16) + + # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird + # for users to specify variant when the default is not fp32 and they want to run with the correct default (which + # is either fp16/bf16 here). 
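+    # Note: --push_to_hub uploads the converted pipeline to the Hugging Face Hub and
+    # therefore assumes an authenticated environment (e.g. a configured HF access
+    # token), in addition to writing the local copy under --output_path.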
+ pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub) diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/export_sat_lora_weight.py b/PyTorch/contrib/cv/video/CogVideoX/tools/export_sat_lora_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..9340d8f2996d55f67a72ddc8c92541465d612c9b --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/export_sat_lora_weight.py @@ -0,0 +1,83 @@ +from typing import Any, Dict +import torch +import argparse +from diffusers.loaders.lora_base import LoraBaseMixin +from diffusers.models.modeling_utils import load_state_dict + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + +LORA_KEYS_RENAME = { + + 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', + 'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight', + 'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight', + 'attention.query_key_value.matrix_B.0': 'attn1.to_q.lora_B.weight', + 'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight', + 'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight', + 'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight', + 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' +} + + + +PREFIX_KEY = "model.diffusion_model." +SAT_UNIT_KEY = "layers" +LORA_PREFIX_KEY = "transformer_blocks" + + + +def export_lora_weight(ckpt_path,lora_save_directory): + + merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) + + + lora_state_dict = {} + for key in list(merge_original_state_dict.keys()): + new_key = key[len(PREFIX_KEY) :] + for special_key, lora_keys in LORA_KEYS_RENAME.items(): + if new_key.endswith(special_key): + new_key = new_key.replace(special_key, lora_keys) + new_key = new_key.replace(SAT_UNIT_KEY, LORA_PREFIX_KEY) + + lora_state_dict[new_key] = merge_original_state_dict[key] + + + + # final length should be 240 + if len(lora_state_dict) != 240: + raise ValueError("lora_state_dict length is not 240") + + lora_state_dict.keys() + + LoraBaseMixin.write_lora_layers( + state_dict=lora_state_dict, + save_directory=lora_save_directory, + is_main_process=True, + weight_name=None, + save_function=None, + safe_serialization=True + ) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--sat_pt_path", type=str, required=True, help="Path to original sat transformer checkpoint" + ) + parser.add_argument("--lora_save_directory", type=str, required=True, help="Path where converted lora should be saved") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + export_lora_weight(args.sat_pt_path, args.lora_save_directory) diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/generate.sh b/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/generate.sh new file mode 100644 index 0000000000000000000000000000000000000000..c455273d90d4468e2d1bf86053855345f3ee6411 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/generate.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +NUM_VIDEOS=10 +INFERENCE_STEPS=50 +GUIDANCE_SCALE=7.0 +OUTPUT_DIR_PREFIX="outputs/gpu_" +LOG_DIR_PREFIX="logs/gpu_" + 
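+# The model paths below point at site-specific local checkouts; adjust them to your
+# own local directories, or switch to the Hugging Face Hub IDs in the commented-out
+# block underneath.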
+VIDEO_MODEL_PATH="/share/official_pretrains/hf_home/CogVideoX-5b-I2V"
+LLM_MODEL_PATH="/share/home/zyx/Models/Meta-Llama-3.1-8B-Instruct"
+IMAGE_MODEL_PATH="/share/home/zyx/Models/FLUX.1-dev"
+
+#VIDEO_MODEL_PATH="THUDM/CogVideoX-5B-I2V"
+#LLM_MODEL_PATH="THUDM/glm-4-9b-chat"
+#IMAGE_MODEL_PATH="black-forest-labs/FLUX.1-dev"
+
+CUDA_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"}
+
+IFS=',' read -r -a GPU_ARRAY <<< "$CUDA_DEVICES"
+
+for i in "${!GPU_ARRAY[@]}"
+do
+    GPU=${GPU_ARRAY[$i]}
+    echo "Starting task on GPU $GPU..."
+    CUDA_VISIBLE_DEVICES=$GPU nohup python3 llm_flux_cogvideox.py \
+        --caption_generator_model_id $LLM_MODEL_PATH \
+        --image_generator_model_id $IMAGE_MODEL_PATH \
+        --model_path $VIDEO_MODEL_PATH \
+        --num_videos $NUM_VIDEOS \
+        --image_generator_num_inference_steps $INFERENCE_STEPS \
+        --guidance_scale $GUIDANCE_SCALE \
+        --use_dynamic_cfg \
+        --output_dir ${OUTPUT_DIR_PREFIX}${GPU} \
+        > ${LOG_DIR_PREFIX}${GPU}.log 2>&1 &
+done
\ No newline at end of file
diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/gradio_page.py b/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/gradio_page.py
new file mode 100644
index 0000000000000000000000000000000000000000..0116c19fea979c2392a7606272b6e96a47082ed0
--- /dev/null
+++ b/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/gradio_page.py
@@ -0,0 +1,194 @@
+import os
+import gradio as gr
+import gc
+import random
+import torch
+import numpy as np
+from PIL import Image
+import transformers
+from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
+from diffusers.utils import export_to_video
+from transformers import AutoTokenizer
+from datetime import datetime, timedelta
+import threading
+import time
+import moviepy.editor as mp
+
+torch.set_float32_matmul_precision("high")
+
+# Set default values
+caption_generator_model_id = "/share/home/zyx/Models/Meta-Llama-3.1-8B-Instruct"
+image_generator_model_id = "/share/home/zyx/Models/FLUX.1-dev"
+video_generator_model_id = "/share/official_pretrains/hf_home/CogVideoX-5b-I2V"
+seed = 1337
+
+os.makedirs("./output", exist_ok=True)
+os.makedirs("./gradio_tmp", exist_ok=True)
+
+tokenizer = AutoTokenizer.from_pretrained(caption_generator_model_id, trust_remote_code=True)
+caption_generator = transformers.pipeline(
+    "text-generation",
+    model=caption_generator_model_id,
+    device_map="balanced",
+    model_kwargs={
+        "local_files_only": True,
+        "torch_dtype": torch.bfloat16,
+    },
+    trust_remote_code=True,
+    tokenizer=tokenizer
+)
+
+image_generator = DiffusionPipeline.from_pretrained(
+    image_generator_model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="balanced"
+)
+# image_generator.to("cuda")
+
+video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
+    video_generator_model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="balanced"
+)
+
+video_generator.vae.enable_slicing()
+video_generator.vae.enable_tiling()
+
+video_generator.scheduler = CogVideoXDPMScheduler.from_config(
+    video_generator.scheduler.config, timestep_spacing="trailing"
+)
+
+# Define prompts
+SYSTEM_PROMPT = """
+You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
+
+For example, if you respond with "A beautiful morning in the woods with the sun peeking through the trees", the video generation model will create a video exactly as described.
Your task is to summarize the descriptions of videos provided by users and create detailed prompts to feed into the generative model. + +There are a few rules to follow: +- You will only ever output a single video description per request. +- If the user mentions to summarize the prompt in [X] words, make sure not to exceed the limit. + +Your responses should just be the video generation prompt. Here are examples: +- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." +- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall." +""".strip() + +USER_PROMPT = """ +Could you generate a prompt for a video generation model? Please limit the prompt to [{0}] words. +""".strip() + + +def generate_caption(prompt): + num_words = random.choice([25, 50, 75, 100]) + user_prompt = USER_PROMPT.format(num_words) + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt + "\n" + user_prompt}, + ] + + response = caption_generator( + messages, + max_new_tokens=226, + return_full_text=False + ) + caption = response[0]["generated_text"] + if caption.startswith("\"") and caption.endswith("\""): + caption = caption[1:-1] + return caption + + +def generate_image(caption, progress=gr.Progress(track_tqdm=True)): + image = image_generator( + prompt=caption, + height=480, + width=720, + num_inference_steps=30, + guidance_scale=3.5, + ).images[0] + return image, image # One for output One for State + + +def generate_video( + caption, + image, + progress=gr.Progress(track_tqdm=True) +): + generator = torch.Generator().manual_seed(seed) + video_frames = video_generator( + image=image, + prompt=caption, + height=480, + width=720, + num_frames=49, + num_inference_steps=50, + guidance_scale=6, + use_dynamic_cfg=True, + generator=generator, + ).frames[0] + video_path = save_video(video_frames) + gif_path = convert_to_gif(video_path) + return video_path, gif_path + + +def save_video(tensor): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + video_path = f"./output/{timestamp}.mp4" + os.makedirs(os.path.dirname(video_path), exist_ok=True) + export_to_video(tensor, video_path, fps=8) + return video_path + + +def convert_to_gif(video_path): + clip = mp.VideoFileClip(video_path) + clip = clip.set_fps(8) + clip = clip.resize(height=240) + gif_path = video_path.replace(".mp4", ".gif") + clip.write_gif(gif_path, fps=8) + return gif_path + + +def delete_old_files(): + while True: + now = datetime.now() + cutoff = now - timedelta(minutes=10) + directories = ["./output", "./gradio_tmp"] + + for directory in directories: + for filename in os.listdir(directory): + file_path = os.path.join(directory, filename) + if os.path.isfile(file_path): + file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path)) + if file_mtime < cutoff: + os.remove(file_path) + time.sleep(600) + + +threading.Thread(target=delete_old_files, 
daemon=True).start() + +with gr.Blocks() as demo: + gr.Markdown(""" +
+ LLM + FLUX + CogVideoX-I2V Space 🤗 +
+ """) + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=5) + generate_caption_button = gr.Button("Generate Caption") + caption = gr.Textbox(label="Caption", placeholder="Caption will appear here", lines=5) + generate_image_button = gr.Button("Generate Image") + image_output = gr.Image(label="Generated Image") + state_image = gr.State() + generate_caption_button.click(fn=generate_caption, inputs=prompt, outputs=caption) + generate_image_button.click(fn=generate_image, inputs=caption, outputs=[image_output, state_image]) + with gr.Column(): + video_output = gr.Video(label="Generated Video", width=720, height=480) + download_video_button = gr.File(label="📥 Download Video", visible=False) + download_gif_button = gr.File(label="📥 Download GIF", visible=False) + generate_video_button = gr.Button("Generate Video from Image") + generate_video_button.click(fn=generate_video, inputs=[caption, state_image], + outputs=[video_output, download_gif_button]) + +if __name__ == "__main__": + demo.launch() diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/llm_flux_cogvideox.py b/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/llm_flux_cogvideox.py new file mode 100644 index 0000000000000000000000000000000000000000..8e97888fb4bb6a45acdbf3873b450ea2ddbb33ed --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/llm_flux_cogvideox/llm_flux_cogvideox.py @@ -0,0 +1,257 @@ +""" +The original experimental code for this project can be found at: + +https://gist.github.com/a-r-r-o-w/d070cce059ab4ceab3a9f289ff83c69c + +By using this code, description prompts will be generated through a local large language model, and images will be +generated using the black-forest-labs/FLUX.1-dev model, followed by video generation via CogVideoX. +The entire process utilizes open-source solutions, without the need for any API keys. + +You can use the generate.sh file in the same folder to automate running this code +for batch generation of videos and images. + +bash generate.sh + +""" + +import argparse +import gc +import json +import os +import pathlib +import random +from typing import Any, Dict + +from transformers import AutoTokenizer + +os.environ["TORCH_LOGS"] = "+dynamo,recompiles,graph_breaks" +os.environ["TORCHDYNAMO_VERBOSE"] = "1" + +import numpy as np +import torch +import transformers +from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline +from diffusers.utils.logging import get_logger +from diffusers.utils import export_to_video + +torch.set_float32_matmul_precision("high") + +logger = get_logger(__name__) + +SYSTEM_PROMPT = """ +You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe. + +For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model. + +There are a few rules to follow: +- You will only ever output a single video description per request. +- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit. + +You responses should just be the video generation prompt. Here are examples: +- “A lone figure stands on a city rooftop at night, gazing up at the full moon. 
The moon glows brightly, casting a gentle light over the quiet cityscape. Below, the windows of countless homes shine with warm lights, creating a contrast between the bustling life below and the peaceful solitude above. The scene captures the essence of the Mid-Autumn Festival, where despite the distance, the figure feels connected to loved ones through the shared beauty of the moonlit sky.” +- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." +- "A street artist, clad in a worn-out denim jacket and a colorful banana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall" +""".strip() + +USER_PROMPT = """ +Could you generate a prompt for a video generation model? +Please limit the prompt to [{0}] words. +""".strip() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_videos", + type=int, + default=5, + help="Number of unique videos you would like to generate." + ) + parser.add_argument( + "--model_path", + type=str, + default="THUDM/CogVideoX-5B", + help="The path of Image2Video CogVideoX-5B", + ) + parser.add_argument( + "--caption_generator_model_id", + type=str, + default="THUDM/glm-4-9b-chat", + help="Caption generation model. default GLM-4-9B", + ) + parser.add_argument( + "--caption_generator_cache_dir", + type=str, + default=None, + help="Cache directory for caption generation model." + ) + parser.add_argument( + "--image_generator_model_id", + type=str, + default="black-forest-labs/FLUX.1-dev", + help="Image generation model." + ) + parser.add_argument( + "--image_generator_cache_dir", + type=str, + default=None, + help="Cache directory for image generation model." + ) + parser.add_argument( + "--image_generator_num_inference_steps", + type=int, + default=50, + help="Caption generation model." + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7, + help="Guidance scale to be use for generation." + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + help="Whether or not to use cosine dynamic guidance for generation [Recommended].", + ) + parser.add_argument( + "--output_dir", + type=str, + default="outputs/", + help="Location where generated images and videos should be stored.", + ) + parser.add_argument( + "--compile", + action="store_true", + help="Whether or not to compile the transformer of image and video generators." + ) + parser.add_argument( + "--enable_vae_tiling", + action="store_true", + help="Whether or not to use VAE tiling when encoding/decoding." + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Seed for reproducibility." 
+ ) + return parser.parse_args() + + +def reset_memory(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_accumulated_memory_stats() + + +@torch.no_grad() +def main(args: Dict[str, Any]) -> None: + output_dir = pathlib.Path(args.output_dir) + os.makedirs(output_dir.as_posix(), exist_ok=True) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + reset_memory() + tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True) + caption_generator = transformers.pipeline( + "text-generation", + model=args.caption_generator_model_id, + device_map="auto", + model_kwargs={ + "local_files_only": True, + "cache_dir": args.caption_generator_cache_dir, + "torch_dtype": torch.bfloat16, + }, + trust_remote_code=True, + tokenizer=tokenizer + ) + + captions = [] + for i in range(args.num_videos): + num_words = random.choice([50, 75, 100]) + user_prompt = USER_PROMPT.format(num_words) + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + + outputs = caption_generator(messages, max_new_tokens=226) + caption = outputs[0]["generated_text"][-1]["content"] + if caption.startswith("\"") and caption.endswith("\""): + caption = caption[1:-1] + captions.append(caption) + logger.info(f"Generated caption: {caption}") + + with open(output_dir / "captions.json", "w") as file: + json.dump(captions, file) + + del caption_generator + reset_memory() + + image_generator = DiffusionPipeline.from_pretrained( + args.image_generator_model_id, + cache_dir=args.image_generator_cache_dir, + torch_dtype=torch.bfloat16 + ) + image_generator.to("cuda") + + if args.compile: + image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True) + + if args.enable_vae_tiling: + image_generator.vae.enable_tiling() + + images = [] + for index, caption in enumerate(captions): + image = image_generator( + prompt=caption, + height=480, + width=720, + num_inference_steps=args.image_generator_num_inference_steps, + guidance_scale=3.5, + ).images[0] + filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_") + image.save(output_dir / f"{index}_{filename}.png") + images.append(image) + + del image_generator + reset_memory() + + video_generator = CogVideoXImageToVideoPipeline.from_pretrained( + args.model_path, torch_dtype=torch.bfloat16).to("cuda") + video_generator.scheduler = CogVideoXDPMScheduler.from_config( + video_generator.scheduler.config, + timestep_spacing="trailing") + + if args.compile: + video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True) + + if args.enable_vae_tiling: + video_generator.vae.enable_tiling() + + generator = torch.Generator().manual_seed(args.seed) + for index, (caption, image) in enumerate(zip(captions, images)): + video = video_generator( + image=image, + prompt=caption, + height=480, + width=720, + num_frames=49, + num_inference_steps=50, + guidance_scale=args.guidance_scale, + use_dynamic_cfg=args.use_dynamic_cfg, + generator=generator, + ).frames[0] + filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_") + export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/load_cogvideox_lora.py 
b/PyTorch/contrib/cv/video/CogVideoX/tools/load_cogvideox_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b129755f1a261e55ff6e9a118179cecbbbd4f52
--- /dev/null
+++ b/PyTorch/contrib/cv/video/CogVideoX/tools/load_cogvideox_lora.py
@@ -0,0 +1,125 @@
+# Copyright 2024 The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+import random
+import time
+from diffusers.utils import export_to_video
+from diffusers.image_processor import VaeImageProcessor
+from datetime import datetime, timedelta
+from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+import os
+import torch
+import argparse
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--lora_weights_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to lora weights.",
+    )
+    parser.add_argument(
+        "--lora_r",
+        type=int,
+        default=128,
+        help="""LoRA rank. The default is 128 for the 2B transformer and 256 for the 5B transformer.
+        The rank is used to calculate lora_scale: the alpha value is divided by the rank,
+        which keeps learning stable and prevents underflow. In the SAT training framework,
+        alpha is set to 1 by default. A higher rank gives more expressive capability,
+        but requires more memory and training time; increasing it blindly isn't always better.
+        The formula for lora_scale is: lora_alpha / lora_r.
+        """,
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=1,
+        help="""LoRA alpha. Together with the rank (default 128 for the 2B transformer, 256 for the 5B transformer),
+        it is used to calculate lora_scale: the alpha value is divided by the rank,
+        which keeps learning stable and prevents underflow. In the SAT training framework,
+        alpha is set to 1 by default. A higher rank gives more expressive capability,
+        but requires more memory and training time; increasing it blindly isn't always better.
+        The formula for lora_scale is: lora_alpha / lora_r.
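+        For example, with the defaults used in this script (lora_alpha=1, lora_r=128), lora_scale = 1 / 128 ≈ 0.0078.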
+ """, + ) + parser.add_argument( + "--prompt", + type=str, + help="prompt", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="The output directory where the model predictions and checkpoints will be written.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + pipe = CogVideoXPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device) + pipe.load_lora_weights(args.lora_weights_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora") + # pipe.fuse_lora(lora_scale=args.lora_alpha/args.lora_r, ['transformer']) + lora_scaling=args.lora_alpha/args.lora_r + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + + os.makedirs(args.output_dir, exist_ok=True) + + latents = pipe( + prompt=args.prompt, + num_videos_per_prompt=1, + num_inference_steps=50, + num_frames=49, + use_dynamic_cfg=True, + output_type="pt", + guidance_scale=3.0, + generator=torch.Generator(device="cpu").manual_seed(42), + ).frames + batch_size = latents.shape[0] + batch_video_frames = [] + for batch_idx in range(batch_size): + pt_image = latents[batch_idx] + pt_image = torch.stack([pt_image[i] for i in range(pt_image.shape[0])]) + + image_np = VaeImageProcessor.pt_to_numpy(pt_image) + image_pil = VaeImageProcessor.numpy_to_pil(image_np) + batch_video_frames.append(image_pil) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + video_path = f"{args.output_dir}/{timestamp}.mp4" + os.makedirs(os.path.dirname(video_path), exist_ok=True) + tensor = batch_video_frames[0] + fps=math.ceil((len(batch_video_frames[0]) - 1) / 6) + + export_to_video(tensor, video_path, fps=fps) \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/parallel_inference/parallel_inference_xdit.py b/PyTorch/contrib/cv/video/CogVideoX/tools/parallel_inference/parallel_inference_xdit.py new file mode 100644 index 0000000000000000000000000000000000000000..e4caf3316d656fe0e9751b0f330963cdb3b6aa91 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/parallel_inference/parallel_inference_xdit.py @@ -0,0 +1,105 @@ +""" +This is a parallel inference script for CogVideo. The original script +can be found from the xDiT project at + +https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py + +By using this code, the inference process is parallelized on multiple GPUs, +and thus speeded up. + +Usage: +1. pip install xfuser +2. mkdir results +3. run the following command to generate video +torchrun --nproc_per_node=4 parallel_inference_xdit.py \ + --model --ulysses_degree 1 --ring_degree 2 \ + --use_cfg_parallel --height 480 --width 720 --num_frames 9 \ + --prompt 'A small dog.' 
+ +You can also use the run.sh file in the same folder to automate running this +code for batch generation of videos, by running: + +sh ./run.sh + +""" + +import time +import torch +import torch.distributed +from diffusers import AutoencoderKLTemporalDecoder +from xfuser import xFuserCogVideoXPipeline, xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_data_parallel_rank, + get_data_parallel_world_size, + get_runtime_state, + is_dp_last_group, +) +from diffusers.utils import export_to_video + + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + + # Check if ulysses_degree is valid + num_heads = 30 + if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0: + raise ValueError( + f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})" + ) + + engine_config, input_config = engine_args.create_config() + local_rank = get_world_group().local_rank + + pipe = xFuserCogVideoXPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + engine_config=engine_config, + torch_dtype=torch.bfloat16, + ) + if args.enable_sequential_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=local_rank) + pipe.vae.enable_tiling() + else: + device = torch.device(f"cuda:{local_rank}") + pipe = pipe.to(device) + + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + + output = pipe( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + prompt=input_config.prompt, + num_inference_steps=input_config.num_inference_steps, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + guidance_scale=6, + ).frames[0] + + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + parallel_info = ( + f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" + f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" + f"tp{engine_args.tensor_parallel_degree}_" + f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" + ) + if is_dp_last_group(): + world_size = get_data_parallel_world_size() + resolution = f"{input_config.width}x{input_config.height}" + output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4" + export_to_video(output, output_filename, fps=8) + print(f"output saved to {output_filename}") + + if get_world_group().rank == get_world_group().world_size - 1: + print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") + get_runtime_state().destory_distributed_env() + + +if __name__ == "__main__": + main() diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/parallel_inference/run.sh b/PyTorch/contrib/cv/video/CogVideoX/tools/parallel_inference/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f9d5a8136a3873fa4118fc15954cf379a521b3f --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/parallel_inference/run.sh @@ -0,0 +1,51 @@ +set -x + +export PYTHONPATH=$PWD:$PYTHONPATH + +# Select the model type +# The model is downloaded to a specified location on disk, +# or you can simply use the model's ID on Hugging Face, +# which will then be downloaded to the default cache path on Hugging Face. 
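+# For example (hypothetical IDs/paths), the MODEL_CONFIGS entry below could point either at a local
+# checkout such as /cfs/dit/CogVideoX-2b or at the Hugging Face hub ID THUDM/CogVideoX-2b.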
+ +export MODEL_TYPE="CogVideoX" +# Configuration for different model types +# script, model_id, inference_step +declare -A MODEL_CONFIGS=( + ["CogVideoX"]="parallel_inference_xdit.py /cfs/dit/CogVideoX-2b 20" +) + +if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then + IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}" + export SCRIPT MODEL_ID INFERENCE_STEP +else + echo "Invalid MODEL_TYPE: $MODEL_TYPE" + exit 1 +fi + +mkdir -p ./results + +# task args +if [ "$MODEL_TYPE" = "CogVideoX" ]; then + TASK_ARGS="--height 480 --width 720 --num_frames 9" +fi + +# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree. +if [ "$MODEL_TYPE" = "CogVideoX" ]; then +N_GPUS=4 +PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1" +CFG_ARGS="--use_cfg_parallel" +fi + + +torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--warmup_steps 0 \ +--prompt "A small dog." \ +$CFG_ARGS \ +$PARALLLEL_VAE \ +$COMPILE_FLAG diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/cog.yaml b/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/cog.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2de2ddbce6bf0c8a6df528ec06cc7817e236e4e1 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/cog.yaml @@ -0,0 +1,37 @@ +# Configuration for Cog ⚙️ +# Reference: https://cog.run/yaml + +build: + # set to true if your model requires a GPU + gpu: true + + # a list of ubuntu apt packages to install + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + + # python version in the form '3.11' or '3.11.4' + python_version: "3.11" + + # a list of packages in the format == + python_packages: + - diffusers>=0.30.3 + - accelerate>=0.34.2 + - transformers>=4.44.2 + - numpy==1.26.0 + - torch>=2.4.0 + - torchvision>=0.19.0 + - sentencepiece>=0.2.0 + - SwissArmyTransformer>=0.4.12 + - imageio>=2.35.1 + - imageio-ffmpeg>=0.5.1 + - openai>=1.45.0 + - moviepy>=1.0.3 + - pillow==9.5.0 + - pydantic==1.10.7 + run: + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget + +# predict.py defines how predictions are run on your model +predict: "predict_t2v.py:Predictor" +# predict: "predict_i2v.py:Predictor" diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/predict_i2v.py b/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/predict_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..5e457961d3b1be28dc53e331bd4b41021098a4ce --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/predict_i2v.py @@ -0,0 +1,89 @@ +# Prediction interface for Cog ⚙️ +# https://cog.run/python + +import os +import subprocess +import time +import torch +from diffusers import CogVideoXImageToVideoPipeline +from diffusers.utils import export_to_video, load_image +from cog import BasePredictor, Input, Path + + +MODEL_CACHE = "model_cache_i2v" +MODEL_URL = ( + f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar" +) +os.environ["HF_DATASETS_OFFLINE"] = "1" +os.environ["TRANSFORMERS_OFFLINE"] = "1" +os.environ["HF_HOME"] = MODEL_CACHE +os.environ["TORCH_HOME"] = MODEL_CACHE +os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE +os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE +os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE + + +def download_weights(url, dest): + start = 
time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + + +class Predictor(BasePredictor): + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + + if not os.path.exists(MODEL_CACHE): + download_weights(MODEL_URL, MODEL_CACHE) + + # model_id: THUDM/CogVideoX-5b-I2V + self.pipe = CogVideoXImageToVideoPipeline.from_pretrained( + MODEL_CACHE, torch_dtype=torch.bfloat16 + ).to("cuda") + + self.pipe.enable_model_cpu_offload() + self.pipe.vae.enable_tiling() + + def predict( + self, + prompt: str = Input( + description="Input prompt", default="Starry sky slowly rotating." + ), + image: Path = Input(description="Input image"), + num_inference_steps: int = Input( + description="Number of denoising steps", ge=1, le=500, default=50 + ), + guidance_scale: float = Input( + description="Scale for classifier-free guidance", ge=1, le=20, default=6 + ), + num_frames: int = Input( + description="Number of frames for the output video", default=49 + ), + seed: int = Input( + description="Random seed. Leave blank to randomize the seed", default=None + ), + ) -> Path: + """Run a single prediction on the model""" + + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + + img = load_image(image=str(image)) + + video = self.pipe( + prompt=prompt, + image=img, + num_videos_per_prompt=1, + num_inference_steps=num_inference_steps, + num_frames=num_frames, + guidance_scale=guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(seed), + ).frames[0] + + out_path = "/tmp/out.mp4" + + export_to_video(video, out_path, fps=8) + return Path(out_path) diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/predict_t2v.py b/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/predict_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..cadeee25641e922da5c9be32c2006caed0e2845d --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/replicate/predict_t2v.py @@ -0,0 +1,87 @@ +# Prediction interface for Cog ⚙️ +# https://cog.run/python + +import os +import subprocess +import time +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video +from cog import BasePredictor, Input, Path + + +MODEL_CACHE = "model_cache" +MODEL_URL = ( + f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar" +) +os.environ["HF_DATASETS_OFFLINE"] = "1" +os.environ["TRANSFORMERS_OFFLINE"] = "1" +os.environ["HF_HOME"] = MODEL_CACHE +os.environ["TORCH_HOME"] = MODEL_CACHE +os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE +os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE +os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE + + +def download_weights(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-x", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + + +class Predictor(BasePredictor): + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + + if not os.path.exists(MODEL_CACHE): + download_weights(MODEL_URL, MODEL_CACHE) + + # model_id: THUDM/CogVideoX-5b + self.pipe = CogVideoXPipeline.from_pretrained( + MODEL_CACHE, + torch_dtype=torch.bfloat16, + ).to("cuda") + + self.pipe.enable_model_cpu_offload() + self.pipe.vae.enable_tiling() + + def 
predict( + self, + prompt: str = Input( + description="Input prompt", + default="A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.", + ), + num_inference_steps: int = Input( + description="Number of denoising steps", ge=1, le=500, default=50 + ), + guidance_scale: float = Input( + description="Scale for classifier-free guidance", ge=1, le=20, default=6 + ), + num_frames: int = Input( + description="Number of frames for the output video", default=49 + ), + seed: int = Input( + description="Random seed. Leave blank to randomize the seed", default=None + ), + ) -> Path: + """Run a single prediction on the model""" + + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + + video = self.pipe( + prompt=prompt, + num_videos_per_prompt=1, + num_inference_steps=num_inference_steps, + num_frames=num_frames, + guidance_scale=guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(seed), + ).frames[0] + + out_path = "/tmp/out.mp4" + + export_to_video(video, out_path, fps=8) + return Path(out_path) diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README.md b/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cc6f45c632b1310908e305e398e7fcb9c36ac29e --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README.md @@ -0,0 +1,98 @@ +# Enhance CogVideoX Generated Videos with VEnhancer + +This tutorial will guide you through using the VEnhancer tool to enhance videos generated by CogVideoX, including +achieving higher frame rates and higher resolutions. + +## Model Introduction + +VEnhancer implements spatial super-resolution, temporal super-resolution (frame interpolation), and video refinement in +a unified framework. It can flexibly adapt to different upsampling factors (e.g., 1x~8x) for spatial or temporal +super-resolution. Additionally, it provides flexible control to modify the refinement strength, enabling it to handle +diverse video artifacts. + +VEnhancer follows the design of ControlNet, copying the architecture and weights of the multi-frame encoder and middle +block from a pre-trained video diffusion model to build a trainable conditional network. This video ControlNet accepts +low-resolution keyframes and noisy full-frame latents as inputs. In addition to the time step t and prompt, our proposed +video-aware conditioning also includes noise augmentation level σ and downscaling factor s as additional network +conditioning inputs. + +## Hardware Requirements + ++ Operating System: Linux (requires xformers dependency) ++ Hardware: NVIDIA GPU with at least 60GB of VRAM per card. Machines such as H100, A100 are recommended. + +## Quick Start + +1. Clone the repository and install dependencies as per the official instructions: + +```shell +git clone https://github.com/Vchitect/VEnhancer.git +cd VEnhancer +## Torch and other dependencies can use those from CogVideoX. 
If you need to create a new environment, use the following commands:
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
+
+## Install required dependencies
+pip install -r requirements.txt
+```
+
+2. Run the code, for example:
+
+   `python enhance_a_video.py --up_scale 4 --target_fps 24 --noise_aug 250 --solver_mode 'fast' --steps 15 --input_path inputs/000000.mp4 --prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' --save_dir 'results/'`
+
+Where:
+
+- `input_path` is the path to the input video.
+- `prompt` is the description of the video content. The prompt used by this tool should be short, no more than 77
+  words. You may need to simplify the prompt that was used to generate the CogVideoX video.
+- `target_fps` is the target frame rate of the output video. 16 fps is usually already smooth; 24 fps is the default.
+- `up_scale` is recommended to be set to 2, 3, or 4. The target resolution is limited to roughly 2K and below.
+- `noise_aug` depends on the input video quality. Lower-quality input needs a higher noise level, which corresponds to
+  stronger refinement: 250~300 for very low-quality videos, <= 200 for good ones.
+- `steps`: if you want fewer steps, change `solver_mode` to "normal" first, then reduce the number of steps; the
+  "fast" solver mode uses a fixed number of steps (15).
+
+The required models are downloaded automatically from Hugging Face during execution.
+
+Typical runtime logs are as follows:
+
+```shell
+/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
+  @torch.library.impl_abstract("xformers_flash::flash_fwd")
+/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
+  @torch.library.impl_abstract("xformers_flash::flash_bwd")
+2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt
+/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly.
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + load_dict = torch.load(cfg.model_path, map_location='cpu') +2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status +2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion +2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4 +2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere +2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49 +2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0 +2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0 +2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720) +2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982) +2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250 +2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8 +2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982]) +/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. + with amp.autocast(enabled=True): +2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0 +2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1 +2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2 +2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3 +2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4 +2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5 +2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6 +2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7 +2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8 +2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9 +2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10 +2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11 +2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12 +2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13 +2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished. + +``` + +Running on a single A100 GPU, enhancing each 6-second CogVideoX generated video with default settings will consume 60GB +of VRAM and take 40-50 minutes. 
\ No newline at end of file diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README_ja.md b/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README_ja.md new file mode 100644 index 0000000000000000000000000000000000000000..70f2d74d04e25139bc9972abab51b82d783c4225 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README_ja.md @@ -0,0 +1,92 @@ + +# VEnhancer で CogVideoX によって生成されたビデオを強化する + +このチュートリアルでは、VEnhancer ツールを使用して、CogVideoX で生成されたビデオを強化し、より高いフレームレートと高い解像度を実現する方法を説明します。 + +## モデルの紹介 + +VEnhancer は、空間超解像、時間超解像(フレーム補間)、およびビデオのリファインメントを統一されたフレームワークで実現します。空間または時間の超解像のために、さまざまなアップサンプリング係数(例:1x〜8x)に柔軟に対応できます。さらに、多様なビデオアーティファクトを処理するために、リファインメント強度を変更する柔軟な制御を提供します。 + +VEnhancer は ControlNet の設計に従い、事前訓練されたビデオ拡散モデルのマルチフレームエンコーダーとミドルブロックのアーキテクチャとウェイトをコピーして、トレーニング可能な条件ネットワークを構築します。このビデオ ControlNet は、低解像度のキーフレームとノイズを含む完全なフレームを入力として受け取ります。さらに、タイムステップ t とプロンプトに加えて、提案されたビデオ対応条件により、ノイズ増幅レベル σ およびダウンスケーリングファクター s が追加のネットワーク条件として使用されます。 + +## ハードウェア要件 + ++ オペレーティングシステム: Linux (xformers 依存関係が必要) ++ ハードウェア: 単一カードあたり少なくとも 60GB の VRAM を持つ NVIDIA GPU。H100、A100 などのマシンを推奨します。 + +## クイックスタート + +1. 公式の指示に従ってリポジトリをクローンし、依存関係をインストールします。 + +```shell +git clone https://github.com/Vchitect/VEnhancer.git +cd VEnhancer +## Torch などの依存関係は CogVideoX の依存関係を使用できます。新しい環境を作成する必要がある場合は、以下のコマンドを使用してください。 +pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 + +## 必須の依存関係をインストールします。 +pip install -r requirements.txt +``` + +2. コードを実行します。 + +```shell +python enhance_a_video.py --up_scale 4 --target_fps 24 --noise_aug 250 --solver_mode 'fast' --steps 15 --input_path inputs/000000.mp4 --prompt 'Wide-angle aerial shot at dawn, soft morning light casting long shadows, an elderly man walking his dog through a quiet, foggy park, trees and benches in the background, peaceful and serene atmosphere' --save_dir 'results/' +``` + +次の設定を行います: + +- `input_path` 是输入视频的路径 +- `prompt` 是视频内容的描述。此工具使用的提示词应更短,不超过77个字。您可能需要简化用于生成CogVideoX视频的提示词。 +- `target_fps` 是视频的目标帧率。通常,16 fps已经很流畅,默认值为24 fps。 +- `up_scale` 推荐设置为2、3或4。目标分辨率限制在2k左右及以下。 +- `noise_aug` 的值取决于输入视频的质量。质量较低的视频需要更高的噪声级别,这对应于更强的优化。250~300适用于非常低质量的视频。对于高质量视频,设置为≤200。 +- `steps` 如果想减少步数,请先将solver_mode改为“normal”,然后减少步数。“fast”模式的步数是固定的(15步)。 + 代码在执行过程中会自动从Hugging Face下载所需的模型。 + +コードの実行中に、必要なモデルは Hugging Face から自動的にダウンロードされます。 + +```shell +/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch. + @torch.library.impl_abstract("xformers_flash::flash_fwd") +/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch. + @torch.library.impl_abstract("xformers_flash::flash_bwd") +2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt +/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. 
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + checkpoint = torch.load(checkpoint_path, map_location=map_location) +2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder +/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + load_dict = torch.load(cfg.model_path, map_location='cpu') +2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status +2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion +2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4 +2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere +2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49 +2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0 +2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0 +2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720) +2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982) +2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250 +2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8 +2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982]) +/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. 
+ with amp.autocast(enabled=True): +2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0 +2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1 +2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2 +2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3 +2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4 +2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5 +2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6 +2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7 +2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8 +2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9 +2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10 +2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11 +2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12 +2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13 +2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished. + +``` + +A100 GPU を単一で使用している場合、CogVideoX によって生成された 6 秒間のビデオを強化するには、デフォルト設定で 60GB の VRAM を消費し、40〜50 分かかります。 diff --git a/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README_zh.md b/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..a481cd179347c3ca446f8f0f572e6eaa10ecbf09 --- /dev/null +++ b/PyTorch/contrib/cv/video/CogVideoX/tools/venhancer/README_zh.md @@ -0,0 +1,101 @@ +# 使用 VEnhancer 对 CogVdieoX 生成视频进行增强 + +本教程将要使用 VEnhancer 工具 对 CogVdieoX 生成视频进行增强, 包括更高的帧率和更高的分辨率 + +## 模型介绍 + +VEnhancer 在一个统一的框架中实现了空间超分辨率、时间超分辨率(帧插值)和视频优化。它可以灵活地适应不同的上采样因子(例如,1x~ +8x)用于空间或时间超分辨率。此外,它提供了灵活的控制,以修改优化强度,从而处理多样化的视频伪影。 + +VEnhancer 遵循 ControlNet 的设计,复制了预训练的视频扩散模型的多帧编码器和中间块的架构和权重,构建了一个可训练的条件网络。这个视频 +ControlNet 接受低分辨率关键帧和包含噪声的完整帧作为输入。此外,除了时间步 t 和提示词外,我们提出的视频感知条件还将噪声增强的噪声级别 +σ 和降尺度因子 s 作为附加的网络条件输入。 + +## 硬件需求 + ++ 操作系统: Linux (需要依赖xformers) ++ 硬件: NVIDIA GPU 并至少保证单卡显存超过60G,推荐使用 H100,A100等机器。 + +## 快速上手 + +1. 按照官方指引克隆仓库并安装依赖 + +```shell +git clone https://github.com/Vchitect/VEnhancer.git +cd VEnhancer +## torch等依赖可以使用CogVideoX的依赖,如果你需要创建一个新的环境,可以使用以下命令 +pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 + +## 安装必须的依赖 +pip install -r requirements.txt +``` + +2. 运行代码 + +```shell +python enhance_a_video.py \ +--up_scale 4 --target_fps 24 --noise_aug 250 \ +--solver_mode 'fast' --steps 15 \ +--input_path inputs/000000.mp4 \ +--prompt 'Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere' \ +--save_dir 'results/' +``` + +其中: + +- `input_path` 是输入视频的路径 +- `prompt` 是视频内容的描述。此工具使用的提示词应更短,不超过77个字。您可能需要简化用于生成CogVideoX视频的提示词。 +- `target_fps` 是视频的目标帧率。通常,16 fps已经很流畅,默认值为24 fps。 +- `up_scale` 推荐设置为2、3或4。目标分辨率限制在2k左右及以下。 +- `noise_aug` 的值取决于输入视频的质量。质量较低的视频需要更高的噪声级别,这对应于更强的优化。250~300适用于非常低质量的视频。对于高质量视频,设置为≤200。 +- `steps` 如果想减少步数,请先将solver_mode改为“normal”,然后减少步数。“fast”模式的步数是固定的(15步)。 + 代码在执行过程中会自动从Hugging Face下载所需的模型。 + +代码运行过程中,会自动从Huggingface拉取需要的模型 + +运行日志通常如下: + +```shell +/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch. 
+ @torch.library.impl_abstract("xformers_flash::flash_fwd") +/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch. + @torch.library.impl_abstract("xformers_flash::flash_bwd") +2024-08-20 13:25:17,553 - video_to_video - INFO - checkpoint_path: ./ckpts/venhancer_paper.pt +/share/home/zyx/.conda/envs/cogvideox/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + checkpoint = torch.load(checkpoint_path, map_location=map_location) +2024-08-20 13:25:37,486 - video_to_video - INFO - Build encoder with FrozenOpenCLIPEmbedder +/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ load_dict = torch.load(cfg.model_path, map_location='cpu') +2024-08-20 13:25:55,391 - video_to_video - INFO - Load model path ./ckpts/venhancer_paper.pt, with local status +2024-08-20 13:25:55,392 - video_to_video - INFO - Build diffusion with GaussianDiffusion +2024-08-20 13:26:16,092 - video_to_video - INFO - input video path: inputs/000000.mp4 +2024-08-20 13:26:16,093 - video_to_video - INFO - text: Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere +2024-08-20 13:26:16,156 - video_to_video - INFO - input frames length: 49 +2024-08-20 13:26:16,156 - video_to_video - INFO - input fps: 8.0 +2024-08-20 13:26:16,156 - video_to_video - INFO - target_fps: 24.0 +2024-08-20 13:26:16,311 - video_to_video - INFO - input resolution: (480, 720) +2024-08-20 13:26:16,312 - video_to_video - INFO - target resolution: (1320, 1982) +2024-08-20 13:26:16,312 - video_to_video - INFO - noise augmentation: 250 +2024-08-20 13:26:16,312 - video_to_video - INFO - scale s is set to: 8 +2024-08-20 13:26:16,399 - video_to_video - INFO - video_data shape: torch.Size([145, 3, 1320, 1982]) +/share/home/zyx/Code/VEnhancer/video_to_video/video_to_video_model.py:108: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead. + with amp.autocast(enabled=True): +2024-08-20 13:27:19,605 - video_to_video - INFO - step: 0 +2024-08-20 13:30:12,020 - video_to_video - INFO - step: 1 +2024-08-20 13:33:04,956 - video_to_video - INFO - step: 2 +2024-08-20 13:35:58,691 - video_to_video - INFO - step: 3 +2024-08-20 13:38:51,254 - video_to_video - INFO - step: 4 +2024-08-20 13:41:44,150 - video_to_video - INFO - step: 5 +2024-08-20 13:44:37,017 - video_to_video - INFO - step: 6 +2024-08-20 13:47:30,037 - video_to_video - INFO - step: 7 +2024-08-20 13:50:22,838 - video_to_video - INFO - step: 8 +2024-08-20 13:53:15,844 - video_to_video - INFO - step: 9 +2024-08-20 13:56:08,657 - video_to_video - INFO - step: 10 +2024-08-20 13:59:01,648 - video_to_video - INFO - step: 11 +2024-08-20 14:01:54,541 - video_to_video - INFO - step: 12 +2024-08-20 14:04:47,488 - video_to_video - INFO - step: 13 +2024-08-20 14:10:13,637 - video_to_video - INFO - sampling, finished. + +``` + +使用A100单卡运行,对于每个CogVideoX生产的6秒视频,按照默认配置,会消耗60G显存,并用时40-50分钟。 \ No newline at end of file