From e9858dfa05759af44c08b82373d8914e0ff5c5a8 Mon Sep 17 00:00:00 2001 From: MunchLau Date: Tue, 22 Mar 2022 13:39:14 +0000 Subject: [PATCH 1/8] self_check --- .../contrib/cv/classfication/volo/env.sh | 11 + .../contrib/cv/classfication/volo/modify.py | 29 + .../contrib/cv/classfication/volo/readme.md | 78 ++ .../cv/classfication/volo/requirements.txt | 5 + .../classfication/volo/test/eval_acc_perf.sh | 1 + .../cv/classfication/volo/test/pth2om.sh | 10 + .../contrib/cv/classfication/volo/volo.py | 791 ++++++++++++++++++ .../cv/classfication/volo/volo_postprocess.py | 83 ++ .../cv/classfication/volo/volo_preprocess.py | 59 ++ .../cv/classfication/volo/volo_pth2onnx.py | 44 + 10 files changed, 1111 insertions(+) create mode 100755 ACL_PyTorch/contrib/cv/classfication/volo/env.sh create mode 100644 ACL_PyTorch/contrib/cv/classfication/volo/modify.py create mode 100644 ACL_PyTorch/contrib/cv/classfication/volo/readme.md create mode 100644 ACL_PyTorch/contrib/cv/classfication/volo/requirements.txt create mode 100644 ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh create mode 100644 ACL_PyTorch/contrib/cv/classfication/volo/test/pth2om.sh create mode 100755 ACL_PyTorch/contrib/cv/classfication/volo/volo.py create mode 100755 ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py create mode 100755 ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py create mode 100755 ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/env.sh b/ACL_PyTorch/contrib/cv/classfication/volo/env.sh new file mode 100755 index 0000000000..1694c0387f --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/env.sh @@ -0,0 +1,11 @@ +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/opskernel:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/nnengine:$LD_LIBRARY_PATH +export PYTHONPATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe:$PYTHONPATH +export PATH=/usr/local/Ascend/ascend-toolkit/latest/bin:/usr/local/Ascend/ascend-toolkit/latest/compiler/ccec_compiler/bin:$PATH +export ASCEND_AICPU_PATH=/usr/local/Ascend/ascend-toolkit/latest +export ASCEND_OPP_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp +export TOOLCHAIN_HOME=/usr/local/Ascend/ascend-toolkit/latest/toolkit +export ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest +export DDK_PATH=/usr/local/Ascend/ascend-toolkit/latest +export NPU_HOST_LIB=/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64/stub:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/runtime/lib64/stub +export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/:${LD_LIBRARY_PATH} +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/fwkacllib/lib64:${LD_LIBRARY_PATH} \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/modify.py b/ACL_PyTorch/contrib/cv/classfication/volo/modify.py new file mode 100644 index 0000000000..9cbe145294 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/modify.py @@ -0,0 +1,29 @@ +import numpy as np +from MagicONNX.magiconnx import OnnxGraph +import argparse + +INT32_MAX = 2147483647 +INT32_MIN = -2147483648 + +def modify(path, output): + graph = OnnxGraph(path) + col2ims = graph.get_nodes("Col2im") + for idx, node in enumerate(col2ims): + attr = node['output_size'] + node.attrs.pop("output_size") + new_init = 
graph.add_initializer(f'output_size_{node.name}', np.array(attr).astype(np.int32))
+        node.inputs = [node.inputs[0], f'output_size_{node.name}']
+
+    graph.save(output)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='modify the onnx node')
+    parser.add_argument('--src', type=str, default='./d1_224_84.2.pth.tar',
+                        help='weights of pytorch dir')
+    parser.add_argument('--des', type=str, default='./volo_d1_224_Col2im.onnx',
+                        help='weights of onnx dir')
+    args = parser.parse_args()
+    modify(args.src, args.des)
+    print("modify the onnx successfully!")
+
+
diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/readme.md b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md
new file mode 100644
index 0000000000..da4f3d09a2
--- /dev/null
+++ b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md
@@ -0,0 +1,78 @@
+# VOLO
+
+This implements offline inference of volo_d1, trained on the ImageNet-2012 dataset with token labeling, mainly modified from [sail-sg/volo](https://github.com/sail-sg/volo).
+
+## VOLO Detail
+
+Exporting the Col2im (Fold) operator with pth2onnx fails, so a custom Col2im op with its own symbolic function is defined in volo.py.
+- The ONNX checker must be commented out, because the exported Col2im node is not a standard ONNX operator.
+Example:
+File "python3.8/site-packages/torch/onnx/utils.py", line 785, in _export
+```python
+#if (operator_export_type is OperatorExportTypes.ONNX) and (not val_use_external_data_format):
+    #try:
+        #_check_onnx_proto(proto)
+    #except RuntimeError as e:
+        #raise CheckerError(e)
+```
+## Requirements
+
+- Prepare the PyTorch checkpoint
+- `pip install -r requirements.txt`
+- Download the ImageNet-2012 dataset. Refer to the original repository https://github.com/rwightman/pytorch-image-models
+- Install MagicONNX
+  ```bash
+  git clone https://gitee.com/Ronnie_zheng/MagicONNX.git
+  cd MagicONNX
+  pip install .
+  ```
+- Compile msame
+  refer to https://gitee.com/ascend/tools/tree/master/msame
+```bash
+  git clone https://gitee.com/ascend/tools.git
+  # The following is an example of setting the environment variables. Replace /home/HwHiAiUser/Ascend/ascend-toolkit/latest with the actual installation path of the Ascend ACLlib package.
+  export DDK_PATH=/home/HwHiAiUser/Ascend/ascend-toolkit/latest
+  export NPU_HOST_LIB=/home/HwHiAiUser/Ascend/ascend-toolkit/latest/acllib/lib64/stub
+
+  cd $HOME/AscendProjects/tools/msame/
+  ./build.sh g++ $HOME/AscendProjects/tools/msame/out
+```
+
+## Preprocess the dataset
+
+Because msame is used for inference, the original dataset must be preprocessed into `.bin` files.
+Each batch size needs its own set of binary files.
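+As a quick sanity check (a minimal sketch, assuming batch size 1 and the `/opt/npu/data_bs1` output path used below), every `.bin` file is simply the raw float32 tensor of one preprocessed batch and can be read back with numpy:
+```python
+import numpy as np
+
+# volo_preprocess.py saves each batch with ndarray.tofile(), so a .bin file is a
+# flat float32 buffer that reshapes to (batchsize, 3, 224, 224).
+data = np.fromfile("/opt/npu/data_bs1/test_0.bin", dtype=np.float32)
+print(data.reshape(1, 3, 224, 224).shape)  # expect (1, 3, 224, 224)
+```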
The command is below: + +```bash +python volo_preprocess.py --src /opt/npu/val --des /opt/npu/data_bs1 --batchsize 1 +python volo_preprocess.py --src /opt/npu/val --des /opt/npu/data_bs16 --batchsize 16 +``` +Then we get the binary dataset in `/opt/npu/data_bs1` or `/opt/npu/data_bs16` and also the label txt.The file named `volo_val_bs1.txt` or `volo_val_bs16.txt` + +## Inference +```bash +# pth2om for batchsize 1 +bash test/pth2om.sh d1_224_84.pth.tar volo_bs1.onnx volo_modify_bs1.onnx volo_bs1 1 "input:1,3,224,224" +# pth2om for batchsize 16 +bash test/pth2om.sh d1_224_84.pth.tar volo_bs16.onnx volo_modify_bs16.onnx volo_bs16 16 "input:16,3,224,224" + +# inference with batchsize 1 with performance +./msame --model "volo_bs1.om" --input "/opt/npu/data_bs1" --output "./" --outfmt TXT + +# inference with batchsize 16 with performance +./msame --model "volo_bs16.om" --input "/opt/npu/data_bs16" --output "./" --outfmt TXT + +# compute the val accuracy, modify the batchsize, result dir and label dir +bash eval_acc_perf.sh +``` + +## Volo inference result +| accuracy | top1 | top2 | top3 | top4 | top5 | +| :------: | :---: | :---: | :---: | :---: | :---: | +| bs1 | - | O2 | 1 | 152.37 | +| bs16 | - | O2 | 1 | 23.26 | + +| performance | average time | average time without first | +| :---------: | :-----------: | :-------------------------: | +| bs1 | 396.46ms | 396.46ms | +| bs16 | 3635.25ms | 3635.25ms | diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/requirements.txt b/ACL_PyTorch/contrib/cv/classfication/volo/requirements.txt new file mode 100644 index 0000000000..06ad02f761 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/requirements.txt @@ -0,0 +1,5 @@ +timm +torch==1.7.0 +torchvision==0.8.0 +numpy +pillow diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh b/ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh new file mode 100644 index 0000000000..fac1b5971a --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh @@ -0,0 +1 @@ +python volo_postprocess.py --batchsize 1 --result 2022321_14_50_42_791955 --label ./volo_val_bs1.txt \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/test/pth2om.sh b/ACL_PyTorch/contrib/cv/classfication/volo/test/pth2om.sh new file mode 100644 index 0000000000..11ebf0ccd8 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/test/pth2om.sh @@ -0,0 +1,10 @@ +python volo_pth2onnx.py --src $1 --des $2 --batchsize $5 +python modify.py --src $2 --des $3 +atc --model=$3 \ + --framework=5 \ + --output=$4 \ + --input_format=NCHW \ + --input_shape=$6 \ + --log=debug \ + --soc_version=Ascend310 \ + --buffer_optimize=off_optimize diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo.py new file mode 100755 index 0000000000..69b042d369 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo.py @@ -0,0 +1,791 @@ +# Copyright 2021 Sea Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Vision OutLOoker (VOLO) implementation +""" +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +import math +import numpy as np +from torch.nn.modules.utils import _pair + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'volo': _cfg(crop_pct=0.96), + 'volo_large': _cfg(crop_pct=1.15), +} + +def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1): + return col2im_op(input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)) + + +def col2im_op(x, output_size, kernel_size, dilation, padding, stride): + out = Col2ImOp.apply(x, output_size, kernel_size, dilation, padding, stride) + return out + +class Col2ImOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, output_size, kernel_size, dilation, padding, stride): + out = torch.randn(x.shape[0], x.shape[1], output_size[0], output_size[1]).to(x.dtype) + return out + @staticmethod + def symbolic(g, x, output_size, kernel_size, dilation, padding, stride): + out = g.op('Col2im', x, output_size_i = output_size, kernel_size_i = kernel_size, dilation_i = dilation, padding_i = padding, stride_i = stride) + return out + +class OutlookAttention(nn.Module): + """ + Implementation of outlook attention + --dim: hidden dim + --num_heads: number of heads + --kernel_size: kernel size in each window for outlook attention + return: token features after outlook attention + """ + + def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, + qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + head_dim = dim // num_heads + self.num_heads = num_heads + self.kernel_size = kernel_size + self.padding = padding + self.stride = stride + self.scale = qk_scale or head_dim**-0.5 + + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn = nn.Linear(dim, kernel_size**4 * num_heads) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride) + self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True) + + def forward(self, x): + B, H, W, C = x.shape + + v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W + + h, w = math.ceil(H / self.stride), math.ceil(W / self.stride) + v = self.unfold(v).reshape(B, self.num_heads, C // self.num_heads, + self.kernel_size * self.kernel_size, + h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H + + attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + attn = self.attn(attn).reshape( + B, h * w, self.num_heads, self.kernel_size * self.kernel_size, + self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk + attn = attn * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + #ori + #x = (attn @ v).permute(0, 1, 4, 3, 2).reshape( + #B, C * self.kernel_size * self.kernel_size, h * w) + x = (attn @ v).permute(0, 1, 4, 3, 2).reshape( + B, C, self.kernel_size * self.kernel_size , h * w) + x = 
fold(x, output_size=(H, W), kernel_size=self.kernel_size, + padding=self.padding, stride=self.stride) + + x = self.proj(x.permute(0, 2, 3, 1)) + x = self.proj_drop(x) + + return x + + +class Outlooker(nn.Module): + """ + Implementation of outlooker layer: which includes outlook attention + MLP + Outlooker is the first stage in our VOLO + --dim: hidden dim + --num_heads: number of heads + --mlp_ratio: mlp ratio + --kernel_size: kernel size in each window for outlook attention + return: outlooker layer + """ + def __init__(self, dim, kernel_size, padding, stride=1, + num_heads=1,mlp_ratio=3., attn_drop=0., + drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, qkv_bias=False, + qk_scale=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = OutlookAttention(dim, num_heads, kernel_size=kernel_size, + padding=padding, stride=stride, + qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Mlp(nn.Module): + "Implementation of MLP" + + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + "Implementation of self-attention" + + def __init__(self, dim, num_heads=8, qkv_bias=False, + qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, H, W, C = x.shape + + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, H, W, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Transformer(nn.Module): + """ + Implementation of Transformer, + Transformer is the second stage in our VOLO + """ + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, + qk_scale=None, attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, + qk_scale=qk_scale, attn_drop=attn_drop) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class ClassAttention(nn.Module): + """ + Class attention layer from CaiT, see details in CaiT + Class attention is the post stage in our VOLO, which is optional. + """ + def __init__(self, dim, num_heads=8, head_dim=None, qkv_bias=False, + qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + if head_dim is not None: + self.head_dim = head_dim + else: + head_dim = dim // num_heads + self.head_dim = head_dim + self.scale = qk_scale or head_dim**-0.5 + + self.kv = nn.Linear(dim, + self.head_dim * self.num_heads * 2, + bias=qkv_bias) + self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.head_dim * self.num_heads, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + + kv = self.kv(x).reshape(B, N, 2, self.num_heads, + self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[ + 1] # make torchscript happy (cannot use tensor as tuple) + q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) + attn = ((q * self.scale) @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + cls_embed = (attn @ v).transpose(1, 2).reshape( + B, 1, self.head_dim * self.num_heads) + cls_embed = self.proj(cls_embed) + cls_embed = self.proj_drop(cls_embed) + return cls_embed + + +class ClassBlock(nn.Module): + """ + Class attention block from CaiT, see details in CaiT + We use two-layers class attention in our VOLO, which is optional. + """ + + def __init__(self, dim, num_heads, head_dim=None, mlp_ratio=4., + qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = ClassAttention( + dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, + qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + cls_embed = x[:, :1] + cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x))) + cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed))) + return torch.cat([cls_embed, x[:, 1:]], dim=1) + + +def get_block(block_type, **kargs): + """ + get block by name, specifically for class attention block in here + """ + if block_type == 'ca': + return ClassBlock(**kargs) + + +def rand_bbox(size, lam, scale=1): + """ + get bounding box as token labeling (https://github.com/zihangJiang/TokenLabeling) + return: bounding box + """ + W = size[1] // scale + H = size[2] // scale + cut_rat = np.sqrt(1. 
- lam) + cut_w = np.int(W * cut_rat) + cut_h = np.int(H * cut_rat) + + # uniform + cx = np.random.randint(W) + cy = np.random.randint(H) + + bbx1 = np.clip(cx - cut_w // 2, 0, W) + bby1 = np.clip(cy - cut_h // 2, 0, H) + bbx2 = np.clip(cx + cut_w // 2, 0, W) + bby2 = np.clip(cy + cut_h // 2, 0, H) + + return bbx1, bby1, bbx2, bby2 + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + Different with ViT use 1 conv layer, we use 4 conv layers to do patch embedding + """ + + def __init__(self, img_size=224, stem_conv=False, stem_stride=1, + patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384): + super().__init__() + assert patch_size in [4, 8, 16] + + self.stem_conv = stem_conv + if stem_conv: + self.conv = nn.Sequential( + nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, + padding=3, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, + padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, + padding=1, bias=False), # 112x112 + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + ) + + self.proj = nn.Conv2d(hidden_dim, + embed_dim, + kernel_size=patch_size // stem_stride, + stride=patch_size // stem_stride) + self.num_patches = (img_size // patch_size) * (img_size // patch_size) + + def forward(self, x): + if self.stem_conv: + x = self.conv(x) + x = self.proj(x) # B, C, H, W + return x + + +class Downsample(nn.Module): + """ + Image to Patch Embedding, downsampling between stage1 and stage2 + """ + def __init__(self, in_embed_dim, out_embed_dim, patch_size): + super().__init__() + self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, + kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = x.permute(0, 3, 1, 2) + x = self.proj(x) # B, C, H, W + x = x.permute(0, 2, 3, 1) + return x + + +def outlooker_blocks(block_fn, index, dim, layers, num_heads=1, kernel_size=3, + padding=1,stride=1, mlp_ratio=3., qkv_bias=False, qk_scale=None, + attn_drop=0, drop_path_rate=0., **kwargs): + """ + generate outlooker layer in stage1 + return: outlooker layers + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(block_fn(dim, kernel_size=kernel_size, padding=padding, + stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + drop_path=block_dpr)) + + blocks = nn.Sequential(*blocks) + + return blocks + + +def transformer_blocks(block_fn, index, dim, layers, num_heads, mlp_ratio=3., + qkv_bias=False, qk_scale=None, attn_drop=0, + drop_path_rate=0., **kwargs): + """ + generate transformer layers in stage2 + return: transformer layers + """ + blocks = [] + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + + sum(layers[:index])) / (sum(layers) - 1) + blocks.append( + block_fn(dim, num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + drop_path=block_dpr)) + + blocks = nn.Sequential(*blocks) + + return blocks + + +class VOLO(nn.Module): + """ + Vision Outlooker, the main class of our model + --layers: [x,x,x,x], four blocks in two stages, the first block is outlooker, the + other three are transformer, we set four blocks, which are easily + applied to downstream tasks + --img_size, --in_chans, --num_classes: these three are very easy 
to understand + --patch_size: patch_size in outlook attention + --stem_hidden_dim: hidden dim of patch embedding, d1-d4 is 64, d5 is 128 + --embed_dims, --num_heads: embedding dim, number of heads in each block + --downsamples: flags to apply downsampling or not + --outlook_attention: flags to apply outlook attention or not + --mlp_ratios, --qkv_bias, --qk_scale, --drop_rate: easy to undertand + --attn_drop_rate, --drop_path_rate, --norm_layer: easy to undertand + --post_layers: post layers like two class attention layers using [ca, ca], + if yes, return_mean=False + --return_mean: use mean of all feature tokens for classification, if yes, no class token + --return_dense: use token labeling, details are here: + https://github.com/zihangJiang/TokenLabeling + --mix_token: mixing tokens as token labeling, details are here: + https://github.com/zihangJiang/TokenLabeling + --pooling_scale: pooling_scale=2 means we downsample 2x + --out_kernel, --out_stride, --out_padding: kerner size, + stride, and padding for outlook attention + """ + def __init__(self, layers, img_size=224, in_chans=3, num_classes=1000, patch_size=8, + stem_hidden_dim=64, embed_dims=None, num_heads=None, downsamples=None, + outlook_attention=None, mlp_ratios=None, qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + post_layers=None, return_mean=False, return_dense=True, mix_token=True, + pooling_scale=2, out_kernel=3, out_stride=2, out_padding=1): + + super().__init__() + self.num_classes = num_classes + self.patch_embed = PatchEmbed(stem_conv=True, stem_stride=2, patch_size=patch_size, + in_chans=in_chans, hidden_dim=stem_hidden_dim, + embed_dim=embed_dims[0]) + + # inital positional encoding, we add positional encoding after outlooker blocks + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size // pooling_scale, + img_size // patch_size // pooling_scale, + embed_dims[-1])) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # set the main block in network + network = [] + for i in range(len(layers)): + if outlook_attention[i]: + # stage 1 + stage = outlooker_blocks(Outlooker, i, embed_dims[i], layers, + downsample=downsamples[i], num_heads=num_heads[i], + kernel_size=out_kernel, stride=out_stride, + padding=out_padding, mlp_ratio=mlp_ratios[i], + qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop_rate, norm_layer=norm_layer) + network.append(stage) + else: + # stage 2 + stage = transformer_blocks(Transformer, i, embed_dims[i], layers, + num_heads[i], mlp_ratio=mlp_ratios[i], + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop_path_rate=drop_path_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer) + network.append(stage) + + if downsamples[i]: + # downsampling between two stages + network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2)) + + self.network = nn.ModuleList(network) + + # set post block, for example, class attention layers + self.post_network = None + if post_layers is not None: + self.post_network = nn.ModuleList([ + get_block(post_layers[i], + dim=embed_dims[-1], + num_heads=num_heads[-1], + mlp_ratio=mlp_ratios[-1], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + drop_path=0., + norm_layer=norm_layer) + for i in range(len(post_layers)) + ]) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) + trunc_normal_(self.cls_token, std=.02) + + # set output type + self.return_mean = return_mean # if yes, return mean, not use class token + self.return_dense = return_dense # if yes, return 
class token and all feature tokens + if return_dense: + assert not return_mean, "cannot return both mean and dense" + self.mix_token = mix_token + self.pooling_scale = pooling_scale + if mix_token: # enable token mixing, see token labeling for details. + self.beta = 1.0 + assert return_dense, "return all tokens if mix_token is enabled" + if return_dense: + self.aux_head = nn.Linear( + embed_dims[-1], + num_classes) if num_classes > 0 else nn.Identity() + self.norm = norm_layer(embed_dims[-1]) + + # Classifier head + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + self.head = nn.Linear( + self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_embeddings(self, x): + # patch embedding + x = self.patch_embed(x) + # B,C,H,W-> B,H,W,C + x = x.permute(0, 2, 3, 1) + return x + + def forward_tokens(self, x): + for idx, block in enumerate(self.network): + if idx == 2: # add positional encoding after outlooker blocks + x = x + self.pos_embed + x = self.pos_drop(x) + x = block(x) + + B, H, W, C = x.shape + x = x.reshape(B, -1, C) + return x + + def forward_cls(self, x): + B, N, C = x.shape + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + for block in self.post_network: + x = block(x) + return x + + def forward(self, x): + # step1: patch embedding + x = self.forward_embeddings(x) + + # mix token, see token labeling for details. + if self.mix_token and self.training: + lam = np.random.beta(self.beta, self.beta) + patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[ + 2] // self.pooling_scale + bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale) + temp_x = x.clone() + sbbx1,sbby1,sbbx2,sbby2=self.pooling_scale*bbx1,self.pooling_scale*bby1,\ + self.pooling_scale*bbx2,self.pooling_scale*bby2 + temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :] + x = temp_x + else: + bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0 + + # step2: tokens learning in the two stages + x = self.forward_tokens(x) + + # step3: post network, apply class attention or not + if self.post_network is not None: + x = self.forward_cls(x) + x = self.norm(x) + + if self.return_mean: # if no class token, return mean + return self.head(x.mean(1)) + + x_cls = self.head(x[:, 0]) + if not self.return_dense: + return x_cls + + x_aux = self.aux_head( + x[:, 1:] + ) # generate classes in all feature tokens, see token labeling + + if not self.training: + return x_cls + 0.5 * x_aux.max(1)[0] + + if self.mix_token and self.training: # reverse "mix token", see token labeling for details. + x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1]) + + temp_x = x_aux.clone() + temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :] + x_aux = temp_x + + x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1]) + + # return these: 1. class token, 2. 
classes from all feature tokens, 3. bounding box + return x_cls, x_aux, (bbx1, bby1, bbx2, bby2) + + +@register_model +def volo_d1(pretrained=False, **kwargs): + """ + VOLO-D1 model, Params: 27M + --layers: [x,x,x,x], four blocks in two stages, the first stage(block) is outlooker, + the other three blocks are transformer, we set four blocks, which are easily + applied to downstream tasks + --embed_dims, --num_heads,: embedding dim, number of heads in each block + --downsamples: flags to apply downsampling or not in four blocks + --outlook_attention: flags to apply outlook attention or not + --mlp_ratios: mlp ratio in four blocks + --post_layers: post layers like two class attention layers using [ca, ca] + See detail for all args in the class VOLO() + """ + layers = [4, 4, 8, 2] # num of layers in the four blocks + embed_dims = [192, 384, 384, 384] + num_heads = [6, 12, 12, 12] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, False, False, False] # do downsampling after first block + outlook_attention = [True, False, False, False ] + # first block is outlooker (stage1), the other three are transformer (stage2) + model = VOLO(layers, + embed_dims=embed_dims, + num_heads=num_heads, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + outlook_attention=outlook_attention, + post_layers=['ca', 'ca'], + **kwargs) + model.default_cfg = default_cfgs['volo'] + return model + + +@register_model +def volo_d2(pretrained=False, **kwargs): + """ + VOLO-D2 model, Params: 59M + """ + layers = [6, 4, 10, 4] + embed_dims = [256, 512, 512, 512] + num_heads = [8, 16, 16, 16] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, False, False, False] + outlook_attention = [True, False, False, False] + model = VOLO(layers, + embed_dims=embed_dims, + num_heads=num_heads, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + outlook_attention=outlook_attention, + post_layers=['ca', 'ca'], + **kwargs) + model.default_cfg = default_cfgs['volo'] + return model + + +@register_model +def volo_d3(pretrained=False, **kwargs): + """ + VOLO-D3 model, Params: 86M + """ + layers = [8, 8, 16, 4] + embed_dims = [256, 512, 512, 512] + num_heads = [8, 16, 16, 16] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, False, False, False] + outlook_attention = [True, False, False, False] + model = VOLO(layers, + embed_dims=embed_dims, + num_heads=num_heads, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + outlook_attention=outlook_attention, + post_layers=['ca', 'ca'], + **kwargs) + model.default_cfg = default_cfgs['volo'] + return model + + +@register_model +def volo_d4(pretrained=False, **kwargs): + """ + VOLO-D4 model, Params: 193M + """ + layers = [8, 8, 16, 4] + embed_dims = [384, 768, 768, 768] + num_heads = [12, 16, 16, 16] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, False, False, False] + outlook_attention = [True, False, False, False] + model = VOLO(layers, + embed_dims=embed_dims, + num_heads=num_heads, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + outlook_attention=outlook_attention, + post_layers=['ca', 'ca'], + **kwargs) + model.default_cfg = default_cfgs['volo_large'] + return model + + +@register_model +def volo_d5(pretrained=False, **kwargs): + """ + VOLO-D5 model, Params: 296M + stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5 + """ + layers = [12, 12, 20, 4] + embed_dims = [384, 768, 768, 768] + num_heads = [12, 16, 16, 16] + mlp_ratios = [4, 4, 4, 4] + downsamples = [True, False, False, False] + outlook_attention = [True, False, False, False] + model = VOLO(layers, + 
embed_dims=embed_dims, + num_heads=num_heads, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + outlook_attention=outlook_attention, + post_layers=['ca', 'ca'], + stem_hidden_dim=128, + **kwargs) + model.default_cfg = default_cfgs['volo_large'] + return model diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py new file mode 100755 index 0000000000..2d8d8ce00e --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py @@ -0,0 +1,83 @@ +import os +import numpy as np +import argparse + +def read_txt_data(path): + line = "" + with open(path, 'r') as f: + line = f.read() + if line != "": + return np.array([float(s) for s in line.split(" ") if s != "" and s != "\n"]) + return None + +def read_label(path, bs): + with open(path, 'r') as f: + content = f.read() + lines = [line for line in content.split('\n')] + if lines[-1] == "": + lines = lines[:-1] + if bs == 16: + total_label = np.zeros((len(files) * bs)) + base = 0 + for line in lines: + labels = line.split(' ')[1:-1] + labels = [int(label) for label in labels] + for i in range(len(labels)): + total_label[base * bs + i] = labels[i] + base = base + 1 + total_label = np.expand_dims(total_label, 1) + return total_label + if bs == 1: + labels = [int(line.split(' ')[-2]) for line in lines] + labels = np.array(labels) + labels = np.expand_dims(labels, 1) + return labels + +def get_topK(files, topk, bs): + if bs == 1: + matrix = np.zeros((len(files), topk)) + if bs ==16: + matrix = np.zeros((len(files) * bs, topk)) + for file in files: + data = read_txt_data(root + file) + if bs == 1: + line = np.argsort(data)[-topk:][::-1] + index = int(file.split('_')[1]) + matrix[index-1, :] = line[:topk] + if bs == 16: + base_index = int(file.split('_')[1]) + newdata = data.reshape(bs, 1000) + for i in range(bs): + line = np.argsort(newdata[i,:])[-topk:][::-1] + matrix[base_index * bs + i, :] = line[:topk] + return matrix.astype(np.int64) + +def get_topK_acc(matrix, labels, k): + matrix_tmp = matrix[:, :k] + match_array = np.logical_or.reduce(matrix_tmp==labels, axis=1) + topk_acc = match_array.sum() / match_array.shape[0] + return topk_acc + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='VOLO validation') + parser.add_argument('--batchsize', type=int, default='1', + help='batchsize.') + parser.add_argument('--result', type=str, default='./', + help='output dir of msame') + parser.add_argument('--label', type=str, default='./volo_val_bs1.txt', + help='label txt dir') + args = parser.parse_args() + root = args.result + bs = args.batchsize + label_dir = args.label + files = None + if os.path.exists(root): + files=os.listdir(root) + else: + print('this path not exist') + exit(0) + matrix = get_topK(files, 6, bs) + labels = read_label(label_dir, bs) + for i in range(1, 6): + acc = get_topK_acc(matrix, labels, i) + print("acc@top{}: {:.3f}%".format(i, 100*acc)) \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py new file mode 100755 index 0000000000..95ddd03c7f --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py @@ -0,0 +1,59 @@ +import sys +from timm.data import create_loader, ImageDataset +import os +import numpy as np +import argparse + +os.environ['device'] = 'cpu' + +def preprocess_volo(data_dir, save_path, batch_size): + f = open("volo_val_bs"+str(batch_size)+".txt", "w") + + loader 
= create_loader( + ImageDataset(data_dir), + input_size=(3, 224, 224), + batch_size=batch_size, + use_prefetcher=True, + interpolation="bicubic", + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + num_workers=4, + crop_pct=0.96, + pin_memory=False, + tf_preprocessing=False) + + for batch_idx, (input, target) in enumerate(loader): + img = np.array(input).astype(np.float32) + save_name = os.path.join(save_path, "test_" + str(batch_idx) + ".bin") + img.tofile(save_name) + if batch_size == 1: + info = "%s %d \n" % ("test_" + str(batch_idx) + ".bin", target) + if batch_size == 16: + info = "%s %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", \ + target[0], target[1], target[2], target[3], target[4], target[5], target[6], target[7], target[8], \ + target[9], target[10], target[11], target[12], target[13], target[14], target[15]) + f.write(info) + + f.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Imagenet val_dataset preprocess') + parser.add_argument('--src', type=str, default='./', + help='imagenet val dir.') + parser.add_argument('--des', type=str, default='./', + help='preprocess dataset dir.') + parser.add_argument('--batchsize', type=int, default='1', + help='batchsize.') + args = parser.parse_args() + src = args.src + des = args.des + bs = args.batchsize + files = None + if not os.path.exists(src): + print('this path not exist') + exit(0) + os.makedirs(des, exist_ok=True) + preprocess_volo(src, des, bs) + + # python volo_224_preprocess.py --src /opt/npu/val --des /opt/npu/data_bs1 --batchsize 1 \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py new file mode 100755 index 0000000000..0f3f83d90a --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py @@ -0,0 +1,44 @@ +import torch +import torch.onnx +from timm.models import create_model, load_checkpoint +import os +from volo import * +import argparse + +def pth_to_onnx(input, checkpoint, onnx_path, input_names=['input'], output_names=['output'], device='cpu'): + if not onnx_path.endswith('.onnx'): + print('Warning! 
The onnx model name is not correct,\ + please give a name that ends with \'.onnx\'!') + return 0 + + model = create_model( + 'volo_d1', + pretrained=False, + num_classes=None, + in_chans=3, + global_pool=None, + scriptable=False, + img_size=224) + load_checkpoint(model, checkpoint, False, strict=False) + model.eval() + + torch.onnx.export(model, input, onnx_path, verbose=True, input_names=input_names, output_names=output_names, opset_version=12, export_modules_as_functions=False) + print("Exporting .pth model to onnx model has been successful!") + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch to onnx') + parser.add_argument('--src', type=str, default='./d1_224_84.2.pth.tar', + help='weights of pytorch dir') + parser.add_argument('--des', type=str, default='./volo_d1_224_Col2im.onnx', + help='weights of onnx dir') + parser.add_argument('--batchsize', type=int, default='1', + help='batchsize.') + args = parser.parse_args() + checkpoint = args.src + onnx_path = args.des + bs = args.batchsize + input = torch.randn(bs, 3, 224, 224) + pth_to_onnx(input, checkpoint, onnx_path) + + + -- Gitee From cb6987b7a2ca1e17a1ba2c17080c1bf760547d93 Mon Sep 17 00:00:00 2001 From: MunchLau Date: Tue, 22 Mar 2022 13:46:41 +0000 Subject: [PATCH 2/8] readme_top1 --- ACL_PyTorch/contrib/cv/classfication/volo/readme.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/readme.md b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md index da4f3d09a2..9d031545ad 100644 --- a/ACL_PyTorch/contrib/cv/classfication/volo/readme.md +++ b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md @@ -67,10 +67,10 @@ bash eval_acc_perf.sh ``` ## Volo inference result -| accuracy | top1 | top2 | top3 | top4 | top5 | -| :------: | :---: | :---: | :---: | :---: | :---: | -| bs1 | - | O2 | 1 | 152.37 | -| bs16 | - | O2 | 1 | 23.26 | +| accuracy | top1 | +| :------: | :--------: | +| bs1 | 80.619 | +| bs16 | 82.275 | | performance | average time | average time without first | | :---------: | :-----------: | :-------------------------: | -- Gitee From ce6de362619e3f3cfb73b44883b5385148f9b618 Mon Sep 17 00:00:00 2001 From: MunchLau Date: Wed, 23 Mar 2022 02:08:30 +0000 Subject: [PATCH 3/8] eval.sh --- .../contrib/cv/classfication/volo/readme.md | 2 +- .../classfication/volo/test/eval_acc_perf.sh | 2 +- .../cv/classfication/volo/test/perf_g.sh | 4 + .../contrib/cv/classfication/volo/validate.py | 344 ++++++++++++++++++ 4 files changed, 350 insertions(+), 2 deletions(-) create mode 100644 ACL_PyTorch/contrib/cv/classfication/volo/test/perf_g.sh create mode 100644 ACL_PyTorch/contrib/cv/classfication/volo/validate.py diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/readme.md b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md index 9d031545ad..1e2ffac6cf 100644 --- a/ACL_PyTorch/contrib/cv/classfication/volo/readme.md +++ b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md @@ -63,7 +63,7 @@ bash test/pth2om.sh d1_224_84.pth.tar volo_bs16.onnx volo_modify_bs16.onnx volo_ ./msame --model "volo_bs16.om" --input "/opt/npu/data_bs16" --output "./" --outfmt TXT # compute the val accuracy, modify the batchsize, result dir and label dir -bash eval_acc_perf.sh +bash eval_acc_perf.sh 1 /path/to/result /path/to/label.txt ``` ## Volo inference result diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh b/ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh index fac1b5971a..3ed475338e 100644 --- 
a/ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh +++ b/ACL_PyTorch/contrib/cv/classfication/volo/test/eval_acc_perf.sh @@ -1 +1 @@ -python volo_postprocess.py --batchsize 1 --result 2022321_14_50_42_791955 --label ./volo_val_bs1.txt \ No newline at end of file +python volo_postprocess.py --batchsize $1 --result $2 --label $3 \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/test/perf_g.sh b/ACL_PyTorch/contrib/cv/classfication/volo/test/perf_g.sh new file mode 100644 index 0000000000..f4aaf1ab68 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/test/perf_g.sh @@ -0,0 +1,4 @@ +python3 validate.py "$@" \ + "/mnt/data/imagenet" \ + --model volo_d1 --img-size 224 \ + --apex-amp diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/validate.py b/ACL_PyTorch/contrib/cv/classfication/volo/validate.py new file mode 100644 index 0000000000..85638ea792 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/classfication/volo/validate.py @@ -0,0 +1,344 @@ +""" +ImageNet Validation Script +Adapted from https://github.com/rwightman/pytorch-image-models +The script is further extend to evaluate VOLO +""" +import argparse +import os +import csv +import glob +import time +import logging +import torch +import torch.nn as nn +import torch.nn.parallel +from collections import OrderedDict +from contextlib import suppress + +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models +from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy +import models + +has_apex = False +try: + from apex import amp + has_apex = True +except ImportError: + pass + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +torch.backends.cudnn.benchmark = True +_logger = logging.getLogger('validate') + + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--dataset', '-d', metavar='NAME', default='', + help='dataset type (default: ImageFolder/ImageTar if empty)') +parser.add_argument('--split', metavar='NAME', default='validation', + help='dataset split (default: validation)') +parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', + help='model architecture (default: dpn92)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--input-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224),' + ' uses model default if empty') +parser.add_argument('--crop-pct', default=None, type=float, + metavar='N', help='Input image center crop pct') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--num-classes', type=int, default=None, + help='Number classes in dataset') +parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', + help='path to class to idx mapping file (default: "")') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') +parser.add_argument('--log-freq', default=50, type=int, + metavar='N', help='batch logging frequency (default: 10)') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--num-gpu', type=int, default=1, + help='Number of GPUS to use') +parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', + help='disable test time pool') +parser.add_argument('--no-prefetcher', action='store_true', default=False, + help='disable fast prefetcher') +parser.add_argument('--pin-mem', action='store_true', default=False, + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--amp', action='store_true', default=False, + help='Use AMP mixed precision. 
Defaults to Apex, fallback to native Torch AMP.') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') +parser.add_argument('--tf-preprocessing', action='store_true', default=False, + help='Use Tensorflow preprocessing pipeline (require CPU TF installed') +parser.add_argument('--use-ema', dest='use_ema', action='store_true', + help='use ema version of weights if present') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') +parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true', + help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance') +parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', + help='Output csv file for validation results (summary)') +parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', + help='Real labels JSON file for imagenet evaluation') +parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', + help='Valid label indices txt file for validation of partial label space') + + +def validate(args): + # might as well try to validate something + args.pretrained = args.pretrained or not args.checkpoint + args.prefetcher = not args.no_prefetcher + amp_autocast = suppress # do nothing + if args.amp: + if has_native_amp: + args.native_amp = True + elif has_apex: + args.apex_amp = True + else: + _logger.warning("Neither APEX or Native Torch AMP is available.") + assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." + if args.native_amp: + amp_autocast = torch.cuda.amp.autocast + _logger.info('Validating in mixed precision with native PyTorch AMP.') + elif args.apex_amp: + _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') + else: + _logger.info('Validating in float32. AMP not enabled.') + + if args.legacy_jit: + set_jit_legacy() + + # create model + model = create_model( + args.model, + pretrained=args.pretrained, + num_classes=args.num_classes, + in_chans=3, + global_pool=args.gp, + scriptable=args.torchscript, + img_size=args.img_size) + if args.num_classes is None: + assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
+ args.num_classes = model.num_classes + + if args.checkpoint: + load_checkpoint(model, args.checkpoint, args.use_ema, strict=False) + + param_count = sum([m.numel() for m in model.parameters()]) + _logger.info('Model %s created, param count: %d' % (args.model, param_count)) + + data_config = resolve_data_config(vars(args), model=model, use_test_size=True) + test_time_pool = False + if not args.no_test_pool: + model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) + + if args.torchscript: + torch.jit.optimized_execution(True) + model = torch.jit.script(model) + + model = model.cuda() + if args.apex_amp: + model = amp.initialize(model, opt_level='O1') + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + if args.num_gpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) + + criterion = nn.CrossEntropyLoss().cuda() + + dataset = create_dataset( + root=args.data, name=args.dataset, split=args.split, + load_bytes=args.tf_preprocessing, class_map=args.class_map) + + if args.valid_labels: + with open(args.valid_labels, 'r') as f: + valid_labels = {int(line.rstrip()) for line in f} + valid_labels = [i in valid_labels for i in range(args.num_classes)] + else: + valid_labels = None + + if args.real_labels: + real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) + else: + real_labels = None + + crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] + loader = create_loader( + dataset, + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + pin_memory=args.pin_mem, + tf_preprocessing=args.tf_preprocessing) + + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + model.eval() + with torch.no_grad(): + # warmup, reduce variability of first batch time, especially for comparing torchscript vs non + input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + model(input) + end = time.time() + for batch_idx, (input, target) in enumerate(loader): + if args.no_prefetcher: + target = target.cuda() + input = input.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + # compute output + with amp_autocast(): + output = model(input) + if isinstance(output, (tuple, list)): + output = output[0] + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) + + if real_labels is not None: + real_labels.add_result(output) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(acc1.item(), input.size(0)) + top5.update(acc5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if batch_idx % args.log_freq == 0: + _logger.info( + 'Test: [{0:>4d}/{1}] ' + 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' + 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( + batch_idx, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, 
top5=top5)) + + if real_labels is not None: + # real labels mode replaces topk values at the end + top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) + else: + top1a, top5a = top1.avg, top5.avg + results = OrderedDict( + top1=round(top1a, 4), top1_err=round(100 - top1a, 4), + top5=round(top5a, 4), top5_err=round(100 - top5a, 4), + param_count=round(param_count / 1e6, 2), + img_size=data_config['input_size'][-1], + cropt_pct=crop_pct, + interpolation=data_config['interpolation']) + + _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( + results['top1'], results['top1_err'], results['top5'], results['top5_err'])) + + return results + + +def main(): + setup_default_logging() + args = parser.parse_args() + model_cfgs = [] + model_names = [] + if os.path.isdir(args.checkpoint): + # validate all checkpoints in a path with same model + checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') + checkpoints += glob.glob(args.checkpoint + '/*.pth') + model_names = list_models(args.model) + model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)] + else: + if args.model == 'all': + # validate all models in a list of names with pretrained checkpoints + args.pretrained = True + model_names = list_models(pretrained=True, exclude_filters=['*in21k']) + model_cfgs = [(n, '') for n in model_names] + elif not is_model(args.model): + # model name doesn't exist, try as wildcard filter + model_names = list_models(args.model) + model_cfgs = [(n, '') for n in model_names] + + if len(model_cfgs): + results_file = args.results_file or './results-all.csv' + _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) + results = [] + try: + start_batch_size = args.batch_size + for m, c in model_cfgs: + batch_size = start_batch_size + args.model = m + args.checkpoint = c + result = OrderedDict(model=args.model) + r = {} + while not r and batch_size >= args.num_gpu: + torch.cuda.empty_cache() + try: + args.batch_size = batch_size + print('Validating with batch size: %d' % args.batch_size) + r = validate(args) + except RuntimeError as e: + if batch_size <= args.num_gpu: + print("Validation failed with no ability to reduce batch size. 
Exiting.") + raise e + batch_size = max(batch_size // 2, args.num_gpu) + print("Validation failed, reducing batch size by 50%") + result.update(r) + if args.checkpoint: + result['checkpoint'] = args.checkpoint + results.append(result) + except KeyboardInterrupt as e: + pass + results = sorted(results, key=lambda x: x['top1'], reverse=True) + if len(results): + write_results(results_file, results) + else: + validate(args) + +def write_results(results_file, results): + with open(results_file, mode='w') as cf: + dw = csv.DictWriter(cf, fieldnames=results[0].keys()) + dw.writeheader() + for r in results: + dw.writerow(r) + cf.flush() + +if __name__ == '__main__': + main() -- Gitee From 2d478fa7a656428a8968c77885975f931fe01def Mon Sep 17 00:00:00 2001 From: MunchLau Date: Wed, 6 Apr 2022 07:08:17 +0000 Subject: [PATCH 4/8] check --- ACL_PyTorch/contrib/cv/classfication/volo/readme.md | 10 ++++++---- .../contrib/cv/classfication/volo/volo_preprocess.py | 4 ++-- .../contrib/cv/classfication/volo/volo_pth2onnx.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/readme.md b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md index 1e2ffac6cf..f553e3bee1 100644 --- a/ACL_PyTorch/contrib/cv/classfication/volo/readme.md +++ b/ACL_PyTorch/contrib/cv/classfication/volo/readme.md @@ -72,7 +72,9 @@ bash eval_acc_perf.sh 1 /path/to/result /path/to/label.txt | bs1 | 80.619 | | bs16 | 82.275 | -| performance | average time | average time without first | -| :---------: | :-----------: | :-------------------------: | -| bs1 | 396.46ms | 396.46ms | -| bs16 | 3635.25ms | 3635.25ms | +|batchsize| performance | average time | average time without first | +| :-----: | :---------: | :-----------: | :-------------------------: | +| bs1 | 10.08fps | 396.46ms | 396.46ms | +| bs16 | 17.6fps | 3635.25ms | 3635.25ms | + + diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py index 95ddd03c7f..1e57a48fb6 100755 --- a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py @@ -13,7 +13,7 @@ def preprocess_volo(data_dir, save_path, batch_size): ImageDataset(data_dir), input_size=(3, 224, 224), batch_size=batch_size, - use_prefetcher=True, + use_prefetcher=False, interpolation="bicubic", mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), @@ -56,4 +56,4 @@ if __name__ == '__main__': os.makedirs(des, exist_ok=True) preprocess_volo(src, des, bs) - # python volo_224_preprocess.py --src /opt/npu/val --des /opt/npu/data_bs1 --batchsize 1 \ No newline at end of file + # python volo_224_preprocess.py --src /opt/npu/val --des /opt/npu/data_bs1 --batchsize 1 diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py index 0f3f83d90a..1289d3440c 100755 --- a/ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_pth2onnx.py @@ -22,7 +22,7 @@ def pth_to_onnx(input, checkpoint, onnx_path, input_names=['input'], output_name load_checkpoint(model, checkpoint, False, strict=False) model.eval() - torch.onnx.export(model, input, onnx_path, verbose=True, input_names=input_names, output_names=output_names, opset_version=12, export_modules_as_functions=False) + torch.onnx.export(model, input, onnx_path, verbose=True, input_names=input_names, output_names=output_names, opset_version=12) 
print("Exporting .pth model to onnx model has been successful!") if __name__ == '__main__': -- Gitee From bfde0c9dca4156f6289ba44ef49b03838e0481d1 Mon Sep 17 00:00:00 2001 From: MunchLau Date: Wed, 6 Apr 2022 07:49:42 +0000 Subject: [PATCH 5/8] check_precess --- ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py index 1e57a48fb6..7ca9c57e5d 100755 --- a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py @@ -25,6 +25,7 @@ def preprocess_volo(data_dir, save_path, batch_size): for batch_idx, (input, target) in enumerate(loader): img = np.array(input).astype(np.float32) save_name = os.path.join(save_path, "test_" + str(batch_idx) + ".bin") + print(save_name) img.tofile(save_name) if batch_size == 1: info = "%s %d \n" % ("test_" + str(batch_idx) + ".bin", target) -- Gitee From 00bc19d5e1e00c61e49a6bcf270fc812586d387d Mon Sep 17 00:00:00 2001 From: MunchLau Date: Wed, 6 Apr 2022 08:39:21 +0000 Subject: [PATCH 6/8] check_precess_batch --- .../cv/classfication/volo/volo_postprocess.py | 18 +++++++++--------- .../cv/classfication/volo/volo_preprocess.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py index 2d8d8ce00e..1eae6503ed 100755 --- a/ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_postprocess.py @@ -16,7 +16,12 @@ def read_label(path, bs): lines = [line for line in content.split('\n')] if lines[-1] == "": lines = lines[:-1] - if bs == 16: + if bs == 1: + labels = [int(line.split(' ')[-2]) for line in lines] + labels = np.array(labels) + labels = np.expand_dims(labels, 1) + return labels + else: total_label = np.zeros((len(files) * bs)) base = 0 for line in lines: @@ -27,16 +32,11 @@ def read_label(path, bs): base = base + 1 total_label = np.expand_dims(total_label, 1) return total_label - if bs == 1: - labels = [int(line.split(' ')[-2]) for line in lines] - labels = np.array(labels) - labels = np.expand_dims(labels, 1) - return labels def get_topK(files, topk, bs): if bs == 1: matrix = np.zeros((len(files), topk)) - if bs ==16: + else: matrix = np.zeros((len(files) * bs, topk)) for file in files: data = read_txt_data(root + file) @@ -44,7 +44,7 @@ def get_topK(files, topk, bs): line = np.argsort(data)[-topk:][::-1] index = int(file.split('_')[1]) matrix[index-1, :] = line[:topk] - if bs == 16: + else: base_index = int(file.split('_')[1]) newdata = data.reshape(bs, 1000) for i in range(bs): @@ -80,4 +80,4 @@ if __name__ == "__main__": labels = read_label(label_dir, bs) for i in range(1, 6): acc = get_topK_acc(matrix, labels, i) - print("acc@top{}: {:.3f}%".format(i, 100*acc)) \ No newline at end of file + print("acc@top{}: {:.3f}%".format(i, 100*acc)) diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py index 7ca9c57e5d..cc898f92ed 100755 --- a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py @@ -29,10 +29,22 @@ def preprocess_volo(data_dir, save_path, batch_size): img.tofile(save_name) if batch_size == 1: info = "%s %d \n" % ("test_" + str(batch_idx) + 
".bin", target) + if batch_size == 2: + info = "%s %d %d \n" % ("test_" + str(batch_idx) + ".bin", target[0], target[1]) + if batch_size == 4: + info = "%s %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", target[0], target[1], target[2], target[3]) + if batch_size == 8: + info = "%s %d %d %d %d %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", target[0], target[1], target[2], target[3], \ + target[4], target[5], target[6], target[7]) if batch_size == 16: info = "%s %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", \ target[0], target[1], target[2], target[3], target[4], target[5], target[6], target[7], target[8], \ target[9], target[10], target[11], target[12], target[13], target[14], target[15]) + if batch_size == 32: + info = "%s %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", \ + target[0], target[1], target[2], target[3], target[4], target[5], target[6], target[7], target[8], target[9], target[10], target[11], target[12], \ + target[13], target[14], target[15], target[16], target[17], target[18], target[19], target[20], target[21], target[22], target[23], target[24], \ + target[25], target[26], target[27], target[28], target[29], target[30], target[31]) f.write(info) f.close() -- Gitee From 019be4bc41f9af2400a135cb3c5d4f19800bacd8 Mon Sep 17 00:00:00 2001 From: MunchLau Date: Wed, 6 Apr 2022 09:13:31 +0000 Subject: [PATCH 7/8] process_any_batchsize --- .../cv/classfication/volo/volo_preprocess.py | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py index cc898f92ed..6699f61a73 100755 --- a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py @@ -27,24 +27,10 @@ def preprocess_volo(data_dir, save_path, batch_size): save_name = os.path.join(save_path, "test_" + str(batch_idx) + ".bin") print(save_name) img.tofile(save_name) - if batch_size == 1: - info = "%s %d \n" % ("test_" + str(batch_idx) + ".bin", target) - if batch_size == 2: - info = "%s %d %d \n" % ("test_" + str(batch_idx) + ".bin", target[0], target[1]) - if batch_size == 4: - info = "%s %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", target[0], target[1], target[2], target[3]) - if batch_size == 8: - info = "%s %d %d %d %d %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", target[0], target[1], target[2], target[3], \ - target[4], target[5], target[6], target[7]) - if batch_size == 16: - info = "%s %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", \ - target[0], target[1], target[2], target[3], target[4], target[5], target[6], target[7], target[8], \ - target[9], target[10], target[11], target[12], target[13], target[14], target[15]) - if batch_size == 32: - info = "%s %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d %d \n" % ("test_" + str(batch_idx) + ".bin", \ - target[0], target[1], target[2], target[3], target[4], target[5], target[6], target[7], target[8], target[9], target[10], target[11], target[12], \ - target[13], target[14], target[15], target[16], target[17], target[18], target[19], target[20], target[21], target[22], target[23], target[24], \ - target[25], target[26], target[27], target[28], target[29], target[30], target[31]) + info = "%s " % ("test_" + 
str(batch_idx) + ".bin") + for i in range(batch_size): + info = info + str(int(target[i])) + " " + info = info + "\n" f.write(info) f.close() -- Gitee From a03e702cdf46da5b81eeb4538eb1a29d5586fd94 Mon Sep 17 00:00:00 2001 From: MunchLau Date: Wed, 6 Apr 2022 09:28:55 +0000 Subject: [PATCH 8/8] process_droplast --- ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py index 6699f61a73..3ef58ebd3f 100755 --- a/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py +++ b/ACL_PyTorch/contrib/cv/classfication/volo/volo_preprocess.py @@ -24,6 +24,8 @@ def preprocess_volo(data_dir, save_path, batch_size): for batch_idx, (input, target) in enumerate(loader): img = np.array(input).astype(np.float32) + if img.shape[0] < batch_size: + continue save_name = os.path.join(save_path, "test_" + str(batch_idx) + ".bin") print(save_name) img.tofile(save_name) -- Gitee
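
Note (not part of the patch series above): after PATCH 7/8 and PATCH 8/8, `volo_preprocess.py` writes one label-info line per saved `.bin` file in the form `test_<batch_idx>.bin <label_0> ... <label_{batch_size-1}> ` and skips any trailing batch smaller than `batch_size`. The snippet below is a minimal standalone sketch of that format under those assumptions; `write_label_info`, `volo_label.txt`, and the toy data are hypothetical names used only for illustration, not code from the repository.

```python
# Minimal sketch of the label-info format produced by the patched
# volo_preprocess.py: one line per full batch, "test_<idx>.bin l0 l1 ... ",
# with incomplete trailing batches skipped (drop-last behaviour).
# write_label_info is a hypothetical helper, not part of the repo.
import numpy as np

def write_label_info(batches, batch_size, info_path="volo_label.txt"):
    # batches: iterable of (batch_idx, 1-D integer label array) pairs
    with open(info_path, "w") as f:
        for batch_idx, target in batches:
            if len(target) < batch_size:      # skip the incomplete final batch
                continue
            info = "%s " % ("test_" + str(batch_idx) + ".bin")
            for i in range(batch_size):
                info = info + str(int(target[i])) + " "
            f.write(info + "\n")

# Toy usage: two full batches of 4 labels; the short third batch is dropped.
write_label_info(
    [(0, np.array([3, 7, 1, 9])),
     (1, np.array([2, 2, 5, 0])),
     (2, np.array([4]))],
    batch_size=4,
)
```

The trailing space before the newline is load-bearing: `read_label` in `volo_postprocess.py` parses each line with a plain `split(' ')` (taking `[-2]` for batch size 1 and `[1:-1]` otherwise), so the empty final token produced by that space is what keeps the label slice aligned.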