diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md b/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md
index e161ab5c1018a5ab7ac79eaacac77eefaba2e798..a16bf5a0b90b5245fd247c040ad960e4838a1e06 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/README.md
@@ -5,7 +5,7 @@
 | Component | Version | Setup guide |
 | ----- | ----- |-----|
 | Python | 3.10.12 | - |
-| torch | 2.4.0 | - |
+| torch | 2.1.0 | - |
 ### 1.1 Obtain the CANN & MindIE packages and prepare the environment
 - [800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32)
@@ -70,15 +70,15 @@ pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl
 ## 3. Using CogView3
 ### 3.1 Weights and configuration files
-1. CogView3 weight path:
+1. Main path of the CogView3 weights:
 ```shell
 https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main
 ```
-- Modify the model_index.json of the weights
+- Modify the model_index.json file under the main path
 ```shell
 {
   "_class_name": "CogView3PlusPipeline",
-  "_diffusers_version": "0.31.0",
+  "_diffusers_version": "0.31.0.dev0",
   "scheduler": [
     "cogview3plus",
     "CogVideoXDDIMScheduler"
@@ -117,6 +117,30 @@
 ```shell
 https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/tokenizer
 ```
 ```shell
 https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/transformer
 ```
+- Modify the config.json file under this path (the file must remain valid JSON: lowercase booleans, and every key needs a colon)
+```shell
+{
+  "_class_name": "CogView3PlusTransformer2DModel",
+  "_diffusers_version": "0.31.0.dev0",
+  "attention_head_dim": 40,
+  "condition_dim": 256,
+  "in_channels": 16,
+  "num_attention_heads": 64,
+  "num_layers": 30,
+  "out_channels": 16,
+  "patch_size": 2,
+  "pooled_projection_dim": 1536,
+  "pos_embed_max_size": 128,
+  "sample_size": 128,
+  "text_embed_dim": 4096,
+  "time_embed_dim": 512,
+  "use_cache": false,
+  "cache_interval": 2,
+  "cache_start": 1,
+  "num_cache_layer": 11,
+  "cache_start_steps": 10
+}
+```
 6. Link to the vae weights:
 ```shell
 https://huggingface.co/THUDM/CogView3-Plus-3B/tree/main/vae
@@ -142,25 +166,169 @@
 |   |   |---- model weights
 ```
-### 3.2 Single-device, single-prompt functional test
-Set the weight path
+### 3.2 Downloading the weights
+Download the weights in advance and place them in the dataset directory (/data).
+```shell
+# git-lfs is required (https://git-lfs.com)
+git lfs install
+
+# Download the CogView3 weights
+git clone https://huggingface.co/THUDM/CogView3-Plus-3B
+```
+
+### 3.3 Performance test
+1. Enter the main directory
 ```shell
-model_path='/data/CogView3B'
+cd cogview3
 ```
-Run the command:
+2. Inference:
 ```shell
 python inference_cogview3plus.py \
-    --model_path ${model_path} \
-    --device_id 0 \
+    --model_path /data/CogView3B \
+    --prompt_file ./prompts/example_prompts.txt \
     --width 1024 \
     --height 1024 \
     --num_inference_steps 50 \
-    --dtype bf16
+    --dtype bf16 \
+    --device_id 0
 ```
 Parameter description:
 - model_path: weight path; it contains the configuration files and weights of the five models scheduler, text_encoder, tokenizer, transformer, and vae.
-- device_id: inference device ID.
+- prompt_file: prompt file.
 - width: width of the generated images.
 - height: height of the generated images.
 - num_inference_steps: number of inference iteration steps.
 - dtype: data type. Currently only bf16 is supported.
+- device_id: inference device ID.
+
+3. You can switch the DiT-cache algorithm on or off through the `use_cache` field in `/data/CogView3B/transformer/config.json`: `true` enables DiT cache and `false` disables it. A sketch of how the cache parameters interact follows this list.
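+
+For intuition, the sketch below shows one plausible way the `cache_*` fields combine. It is an illustrative assumption only, not the shipped logic, which lives in `cogview3plus/models/transformer_cogview3plus.py`; the parameter names and defaults mirror the config.json fields above.
+```python
+# Hypothetical sketch, not the actual implementation: decide whether a
+# transformer block may reuse its cached output at a given denoising step.
+def reuse_cached_output(step, layer, use_cache=True, cache_interval=2,
+                        cache_start=1, num_cache_layer=11, cache_start_steps=10):
+    if not use_cache or step < cache_start_steps:
+        return False  # warm-up steps: recompute every layer
+    if not cache_start <= layer < cache_start + num_cache_layer:
+        return False  # layers outside the cached range always recompute
+    # Inside the cached range, recompute on every `cache_interval`-th step
+    # and reuse the cached block output on the steps in between.
+    return (step - cache_start_steps) % cache_interval != 0
+```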
+
+### 3.4 Accuracy test
+
+1. Because image generation is stochastic, two accuracy-verification methods are provided (a sketch of the underlying scoring math appears at the end of this section):
+   1. CLIP-score (text-image matching metric): measures how well an image matches its input text; scores lie in [-1, 1], higher is better. Verified with the Parti dataset.
+   2. HPSv2 (image aesthetics metric): estimates a human-preference score for the generated images; scores lie in [0, 1], higher is better. Verified with the HPSv2 dataset.
+
+   Note that a large number of images has to be generated, so a full accuracy verification takes a long time.
+
+2. Download the Parti and HPSv2 datasets
+   Place all datasets in the cogview3/prompts directory.
+   ```bash
+   # Download the Parti dataset
+   wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate
+   ```
+   HPSv2 dataset download link: https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/hpsv2_benchmark_prompts.json
+
+3. Download the scoring model weights
+
+   ```bash
+   # Weights required by both CLIP-score and HPSv2; GIT_LFS_SKIP_SMUDGE=1 must
+   # be on the same line as git clone so that the large LFS files are skipped
+   GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K
+
+   # HPSv2 weights
+   wget https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt --no-check-certificate
+   ```
+   Alternatively, download the [CLIP weights](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/open_clip_pytorch_model.bin) manually and place them in the `CLIP-ViT-H-14-laion2B-s32B-b79K` directory, and download the [HPSv2 weights](https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt) manually and place them in the current directory.
+
+4. Use the inference script to read the Parti dataset and generate images
+   ```bash
+   mkdir ./results_PartiPrompts
+   python3 inference_cogview3plus.py \
+       --model_path /data/CogView3B \
+       --prompt_file ./prompts/PartiPrompts.tsv \
+       --prompt_file_type parti \
+       --info_file_save_path ./image_info_PartiPrompts.json \
+       --save_dir ./results_PartiPrompts \
+       --num_images_per_prompt 4 \
+       --height 1024 \
+       --width 1024 \
+       --batch_size 1 \
+       --seed 42 \
+       --device_id 0
+   ```
+   Parameter description:
+   - model_path: weight path; it contains the configuration files and weights of the five models scheduler, text_encoder, tokenizer, transformer, and vae.
+   - prompt_file: prompt file.
+   - prompt_file_type: prompt file type, which selects the reader; one of plain, parti, hpsv2.
+   - info_file_save_path: path of the json file that records information about the generated images.
+   - save_dir: directory where the generated images are stored.
+   - num_images_per_prompt: number of images generated per prompt. Note that for hpsv2 it is sufficient to set num_images_per_prompt=1.
+   - height: height of the generated images.
+   - width: width of the generated images.
+   - batch_size: model batch size.
+   - seed: random seed.
+   - device_id: inference device ID.
+
+   After the run completes, the generated images are placed in the `./results_PartiPrompts` directory, an `image_info_PartiPrompts.json` file recording the mapping between images and prompts is written to the current directory, and the inference time is printed to the terminal.
+
+5. Use the inference script to read the HPSv2 dataset and generate images
+   ```bash
+   mkdir ./results_hpsv2
+   python3 inference_cogview3plus.py \
+       --model_path /data/CogView3B \
+       --prompt_file ./prompts/hpsv2_benchmark_prompts.json \
+       --prompt_file_type hpsv2 \
+       --info_file_save_path ./image_info_hpsv2.json \
+       --save_dir ./results_hpsv2 \
+       --num_images_per_prompt 1 \
+       --height 1024 \
+       --width 1024 \
+       --batch_size 1 \
+       --seed 42 \
+       --device_id 0
+   ```
+   Parameter description:
+   - model_path: weight path; it contains the configuration files and weights of the five models scheduler, text_encoder, tokenizer, transformer, and vae.
+   - prompt_file: prompt file.
+   - prompt_file_type: prompt file type, which selects the reader; one of plain, parti, hpsv2.
+   - info_file_save_path: path of the json file that records information about the generated images.
+   - save_dir: directory where the generated images are stored.
+   - num_images_per_prompt: number of images generated per prompt. Note that for hpsv2 it is sufficient to set num_images_per_prompt=1.
+   - height: height of the generated images.
+   - width: width of the generated images.
+   - batch_size: model batch size.
+   - seed: random seed.
+   - device_id: inference device ID.
+
+   After the run completes, the generated images are placed in the `./results_hpsv2` directory, an `image_info_hpsv2.json` file recording the mapping between images and prompts is written to the current directory, and the inference time is printed to the terminal.
+
+6. Compute the accuracy metrics
+   1. CLIP-score
+      ```bash
+      python3 clip_score.py \
+          --device=cuda \
+          --image_info="./image_info_PartiPrompts.json" \
+          --model_name="ViT-H-14" \
+          --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
+      ```
+      Parameter description:
+      - --device: inference device; clip_score.py only accepts cpu or cuda.
+      - --image_info: the `image_info.json` file generated in the previous step.
+      - --model_name: CLIP model name.
+      - --model_weights_path: path of the CLIP model weight file.
+
+      After the run completes, the accuracy result is printed to the screen.
+
+   2. HPSv2
+      ```bash
+      python3 hpsv2_score.py \
+          --image_info="image_info_hpsv2.json" \
+          --HPSv2_checkpoint="./HPS_v2_compressed.pt" \
+          --clip_checkpoint="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
+      ```
+
+      Parameter description:
+      - --image_info: the `image_info.json` file generated in the previous step.
+      - --HPSv2_checkpoint: path of the HPSv2 model weight file.
+      - --clip_checkpoint: path of the CLIP model weight file.
+
+      After the run completes, the accuracy result is printed to the screen.
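+
+For reference, both metrics reduce to similarities between CLIP image and text embeddings. The sketch below mirrors the core of `clip_score.py` (added in this change); it is a simplified illustration in which `model`, `tokenizer`, and `preprocess` come from `open_clip.create_model_and_transforms`.
+```python
+import torch
+import torch.nn.functional as F
+
+# Best CLIP score for one prompt over several candidate PIL images:
+# embed both sides, take the cosine similarity, and keep the best image.
+def best_clip_score(model, tokenizer, preprocess, prompt, pil_images, device):
+    images = torch.cat([preprocess(im).unsqueeze(0) for im in pil_images]).to(device)
+    texts = tokenizer([prompt] * len(pil_images)).to(device)
+    with torch.no_grad():
+        image_features = model.encode_image(images).float()
+        text_features = model.encode_text(texts).float()
+        scores = F.cosine_similarity(image_features, text_features)  # in [-1, 1]
+    return scores.max().item()
+```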
+
+### 3.5 CogView3Plus results
+
+| Hardware | Inference steps | DiT cache | Average latency | CLIP_score | HPSV2_score |
+| :------: |:----:|:----:|:----:|:----:|:----:|
+| Atlas 800T A2 (64G), single card | 50 | False | 27.588s | 0.367 | 0.2879729 |
+| Atlas 800T A2 (64G), single card | 50 | True | 23.639s | 0.367 | 0.2878573 |
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/clip_score.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/clip_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b758f7e41aa4d36151bcee4423a0b0019035561
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/clip_score.py
@@ -0,0 +1,141 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import time
+import argparse
+
+import open_clip
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn.functional as F
+
+
+def clip_score(models, prompt, image_files, device):
+    model_clip = models[0]
+    tokenizer = models[1]
+    preprocess = models[2]
+    imgs = []
+    texts = []
+    for image_file in image_files:
+        img = preprocess(Image.open(image_file)).unsqueeze(0).to(device)
+        imgs.append(img)
+        text = tokenizer([prompt]).to(device)
+        texts.append(text)
+
+    img = torch.cat(imgs)  # [bs, 3, 224, 224]
+    text = torch.cat(texts)  # [bs, 77]
+
+    with torch.no_grad():
+        text_ft = model_clip.encode_text(text).float()
+        img_ft = model_clip.encode_image(img).float()
+        score = F.cosine_similarity(img_ft, text_ft).squeeze()
+
+    return score.cpu()
+
+
+def main():
+    args = parse_arguments()
+
+    if args.device is None:
+        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
+    else:
+        device = torch.device(args.device)
+
+    t_b = time.time()
+    print(f"Load clip model...")
+    model_clip, _, preprocess = open_clip.create_model_and_transforms(
+        args.model_name, pretrained=args.model_weights_path, device=device)
+    model_clip.eval()
+    print(f">done. 
elapsed time: {(time.time() - t_b):.3f} s") + + tokenizer = open_clip.get_tokenizer(args.model_name) + + with os.fdopen(os.open(args.image_info, os.O_RDONLY), "r") as f: + image_info = json.load(f) + + t_b = time.time() + print(f"Calc clip score...") + all_scores = [] + cat_scores = {} + + for i, info in enumerate(image_info): + image_files = info['images'] + category = info['category'] + prompt = info['prompt'] + + print(f"[{i + 1}/{len(image_info)}] {prompt}") + + image_scores = clip_score((model_clip, tokenizer, preprocess), + prompt, + image_files, + device) + if len(image_files) > 1: + best_score = max(image_scores) + else: + best_score = image_scores + + print(f"image scores: {image_scores}") + print(f"best score: {best_score}") + + all_scores.append(best_score) + if category not in cat_scores: + cat_scores[category] = [] + cat_scores[category].append(best_score) + print(f">done. elapsed time: {(time.time() - t_b):.3f} s") + + average_score = np.average(all_scores) + print(f"====================================") + print(f"average score: {average_score:.3f}") + print(f"category average scores:") + cat_average_scores = {} + for category, scores in cat_scores.items(): + cat_average_scores[category] = np.average(scores) + print(f"[{category}], average score: {cat_average_scores[category]:.3f}") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="device for torch.", + ) + parser.add_argument( + "--image_info", + type=str, + default="./image_info.json", + help="Image_info.json file.", + ) + parser.add_argument( + "--model_name", + type=str, + default="ViT-H-14", + help="open clip model name", + ) + parser.add_argument( + "--model_weights_path", + type=str, + default="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + help="open clip model weights", + ) + return parser.parse_args() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py index 1139593a36dcd44b33934bf240adf6ab4a477613..e5bd9d5fa93e9131eea8c8fdeaf2d03e7cb330de 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/__init__.py @@ -1,3 +1,4 @@ from .pipeline import CogView3PlusPipeline, DiffusionPipeline from .schedulers import CogVideoXDDIMScheduler, SchedulerMixin -from .models import CogView3PlusTransformer2DModel, ModelMixin \ No newline at end of file +from .models import CogView3PlusTransformer2DModel, ModelMixin +from .utils import set_random_seed \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py index f704e22589437bb8533e428f8cf1bbab7375f2ae..9515d865be0cebf651110fbbb159f1028f1409b2 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/models/transformer_cogview3plus.py @@ -130,9 +130,9 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): pos_embed_max_size: int = 128, use_cache: bool = True, cache_interval: int = 2, - cache_start: int = 3, - num_cache_layer: int = 13, - cache_start_steps: int = 5, + 
cache_start: int = 1,
+        num_cache_layer: int = 11,
+        cache_start_steps: int = 10,
     ):
         super().__init__()
         self.out_channels = out_channels
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py
index fe2bd5cfcd33a2a7e99cdf4f79b1130dc86e5cf5..2f14fdd7c334b8517d54e32730b68db6dbe9914e 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/pipeline/pipeline_cogview3plus.py
@@ -27,7 +27,6 @@
 from diffusers import AutoencoderKL
 
 from ..models import CogView3PlusTransformer2DModel
 from ..schedulers import CogVideoXDDIMScheduler
-from .pipeline_output import CogView3PipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -224,7 +223,7 @@ class CogView3PlusPipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
         num_images_per_prompt: int = 1,
-    ) -> Union[CogView3PipelineOutput, Tuple]:
+    ) -> Tuple:
         if image_size is None:
             height = self.transformer.config.sample_size * self.vae_scale_factor
             width = self.transformer.config.sample_size * self.vae_scale_factor
@@ -336,4 +335,4 @@ class CogView3PlusPipeline(DiffusionPipeline):
 
         # Offload all models
         self.maybe_free_model_hooks()
-        return CogView3PipelineOutput(images=image)
\ No newline at end of file
+        return (image,)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/utils/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f35da6dceac87094e7ca1c0afb182a8f119b8c36
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/utils/__init__.py
@@ -0,0 +1 @@
+from .utils import set_random_seed
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de985c14534241e01913c891139c712ce1d990dc
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/cogview3plus/utils/utils.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import random
+import torch
+import numpy as np
+
+
+def set_random_seed(seed):
+    """Set the random seed for Python, NumPy, and PyTorch.
+
+    Args:
+        seed (int): Seed to be used.
+ + """ + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + return seed \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/hpsv2_score.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/hpsv2_score.py new file mode 100644 index 0000000000000000000000000000000000000000..e535a2ffcdfd7c731bc249befa750ee34eefdebc --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/hpsv2_score.py @@ -0,0 +1,123 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from typing import Union +import json + +from clint.textui import progress +import hpsv2 +from hpsv2.utils import root_path, hps_version_map +from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer +import huggingface_hub +from PIL import Image +import requests +import torch + + +def initialize_model(pretrained_path, device): + model, _, preprocess_val = create_model_and_transforms( + "ViT-H-14", pretrained=pretrained_path, precision='amp', + device=device, + jit=False, + force_quick_gelu=False, + force_custom_text=False, + force_patch_dropout=False, + force_image_size=None, + pretrained_image=False, + image_mean=None, + image_std=None, + light_augmentation=True, + aug_cfg={}, + output_dict=True, + with_score_predictor=False, + with_region_predictor=False + ) + return model, preprocess_val + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--image_info", + type=str, + default="./image_info.json", + help="Image_info.json file.", + ) + parser.add_argument( + "--HPSv2_checkpoint", + type=str, + default="./HPS_v2_compressed.pt", + help="HPS_v2 model weights", + ) + parser.add_argument( + "--clip_checkpoint", + type=str, + default="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + help="open clip model weights", + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + model, preprocess_val = initialize_model(args.clip_checkpoint, device) + + checkpoint = torch.load(args.HPSv2_checkpoint, map_location=device) + model.load_state_dict(checkpoint['state_dict']) + tokenizer = get_tokenizer('ViT-H-14') + model = model.to(device) + model.eval() + + with os.fdopen(os.open(args.image_info, os.O_RDONLY), "r") as f: + image_info = json.load(f) + + result = [] + for i, info in enumerate(image_info): + image_file = info['images'][0] + prompt = info['prompt'] + + # Load your image and prompt + with torch.no_grad(): + # Process the image + if isinstance(image_file, str): + image = preprocess_val(Image.open(image_file)) + elif isinstance(image_file, Image.Image): + image = preprocess_val(image_file) + else: + raise TypeError('The type of parameter img_path is illegal.') + image = image.unsqueeze(0).to(device=device, non_blocking=True) + # Process the prompt + text = tokenizer([prompt]).to(device=device, non_blocking=True) + # Calculate the HPS + with 
torch.cuda.amp.autocast(): + outputs = model(image, text) + image_features = outputs["image_features"] + text_features = outputs["text_features"] + logits_per_image = image_features @ text_features.T + + hps_score = torch.diagonal(logits_per_image).cpu().numpy() + print(f"image {i} hps_score: ", hps_score[0]) + + result.append(hps_score[0]) + + print('avg HPSv2 score:', sum(result) / len(result)) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py b/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py index c3bb1f2ebbf59fe976dab8533c293a6d2c76afe9..0f8796713ad16c3cab6a528b189b6294d164aee4 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/inference_cogview3plus.py @@ -17,34 +17,150 @@ import argparse import logging import time +import os +import csv +import json import torch -from cogview3plus import CogView3PlusPipeline +from cogview3plus import CogView3PlusPipeline, set_random_seed logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.categories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file, max_num_prompts) + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file, max_num_prompts) + elif prompt_file_type == 'hpsv2': + self.load_prompts_hpsv2(prompt_file, max_num_prompts) + else: + print("This operation is not supported!") + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'categories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['categories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, category_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['categories'].append(self.categories[category_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line[0] + category = line[1] + if category not in self.categories: + self.categories.append(category) + + category_id = self.categories.index(category) + self.prompts.append((prompt, category_id)) + + def 
load_prompts_hpsv2(self, file_path: str, max_num_prompts: int): + with open(file_path, 'r') as file: + all_prompts = json.load(file) + count = 0 + for style, prompts in all_prompts.items(): + for prompt in prompts: + count += 1 + if max_num_prompts and count >= max_num_prompts: + break + + if style not in self.categories: + self.categories.append(style) + + category_id = self.categories.index(style) + self.prompts.append((prompt, category_id)) + + def parse_arguments(): parser = argparse.ArgumentParser(description="Generate an image using the CogView3-Plus-3B model.") # Define arguments for prompt, model path, etc. parser.add_argument( - "--prompt", - type=list, - default=[ - "A vibrant cherry red sports car sits proudly under the gleaming sun, \ - its polished exterior smooth and flawless, casting a mirror-like reflection. \ - The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, \ - and a set of black, high-gloss racing rims that contrast starkly with the red. \ - A subtle hint of chrome embellishes the grille and exhaust, \ - while the tinted windows suggest a luxurious and private interior. \ - he scene conveys a sense of speed and elegance, \ - the car appearing as if it's about to burst into a sprint along a coastal road, \ - with the ocean's azure waves crashing in the background." - ], - help="The text description for generating the image." + "--prompt_file", + type=str, + default="./prompts/example_prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", ) parser.add_argument( "--model_path", type=str, default="/data/CogView3B", help="Path to the pre-trained model." @@ -55,12 +171,24 @@ def parse_arguments(): parser.add_argument( "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt." ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size." 
+    )
     parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.")
     parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")
     parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")
-    parser.add_argument("--output_path", type=str, default="cogview3.png", help="Path to save the generated image.")
     parser.add_argument("--dtype", type=str, default="bf16", help="bf16 or fp16")
-    parser.add_argument("--device_id", type=int, default=7, help="NPU device id")
+    parser.add_argument("--seed", type=int, default=None, help="Random seed")
+    parser.add_argument("--device_id", type=int, default=0, help="NPU device id")
 
     return parser.parse_args()
 
@@ -69,37 +197,68 @@ def infer(args):
     torch.npu.set_device(args.device_id)
 
     dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
+    if not os.path.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    if args.seed is not None:
+        set_random_seed(args.seed)
+
     # Load the pre-trained model with the specified precision
     pipe = CogView3PlusPipeline.from_pretrained(args.model_path, torch_dtype=dtype).to("npu")
 
     use_time = 0
-    loops = 5
-    for i in range(loops):
+    prompt_loader = PromptLoader(args.prompt_file,
+                                 args.prompt_file_type,
+                                 args.batch_size,
+                                 args.num_images_per_prompt,
+                                 args.max_num_prompts)
+
+    infer_num = 0
+    image_info = []
+    current_prompt = None
+    for i, input_info in enumerate(prompt_loader):
+        prompts = input_info['prompts']
+        categories = input_info['categories']
+        save_names = input_info['save_names']
+        n_prompts = input_info['n_prompts']
+
+        print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}")
+        infer_num += args.batch_size
+
         start_time = time.time()
-        # Generate the image based on the prompt
-        image = pipe(
-            prompt=args.prompt[0],
+        images = pipe(
+            prompt=prompts,
             guidance_scale=args.guidance_scale,
-            num_images_per_prompt=args.num_images_per_prompt,
             num_inference_steps=args.num_inference_steps,
             image_size=(args.height, args.width),
-        ).images[0]
-
-        if i >= 2:
+        )
+
+        if i > 1:  # skip the first two iterations (warm-up) when measuring time
             use_time += time.time() - start_time
-            logger.info("current_time is %.3f )", time.time() - start_time)
-            torch.npu.empty_cache()
-
-    logger.info("use_time is %.3f)", use_time / 3)
+        for j in range(n_prompts):
+            image_save_path = os.path.join(args.save_dir, f"{save_names[j]}.png")
+            image = images[0][j]
+            image.save(image_save_path)
+
+            if current_prompt != prompts[j]:
+                current_prompt = prompts[j]
+                image_info.append({'images': [], 'prompt': current_prompt, 'category': categories[j]})
+
+            image_info[-1]['images'].append(image_save_path)
 
-    # Save the generated image to the local file system
-    image.save(args.output_path)
+    infer_num = infer_num - 2  # exclude the two warm-up iterations from the average
+    print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n"
+          f"average time: {use_time / infer_num:.3f}s\n")
 
-    print(f"Image saved to {args.output_path}")
+    # Save image information to a json file
+    if os.path.exists(args.info_file_save_path):
+        os.remove(args.info_file_save_path)
+
+    with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f:
+        json.dump(image_info, f)
 
 
 if __name__ == "__main__":
     inference_args = parse_arguments()
     infer(inference_args)
-
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/prompts/example_prompts.txt 
b/MindIE/MindIE-Torch/built-in/foundation/cogview3/prompts/example_prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..7291dde080b65546a06310cb35b8d7631f66d155 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/prompts/example_prompts.txt @@ -0,0 +1,5 @@ +A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background. +A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background. +A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background. +A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background. +A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background. 
\ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt b/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt index 1600434700000cd99216b7d6179326b0d54380e0..b3b2501d42af863ecadb74ea3a24d879adc4d56e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt +++ b/MindIE/MindIE-Torch/built-in/foundation/cogview3/requirents.txt @@ -4,5 +4,5 @@ gradio==5.9.1 accelerate==1.0.1 diffusers==0.31.0 sentencepiece==0.2.0 -torch==2.4.0 +torch==2.1.0 openai==1.58.1 \ No newline at end of file