diff --git a/PyTorch/contrib/cv/video/MDNet/LICENSE b/PyTorch/contrib/cv/video/MDNet/LICENSE new file mode 100755 index 0000000000000000000000000000000000000000..4fa3cd95570a91b6ad5d77437b96e38a9e56907c --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/LICENSE @@ -0,0 +1,57 @@ +Copyright Pohang University of Science and Technology. All rights reserved. + +Contact person: +Hyeonseob Nam (namhs09 postech.ac.kr) + +This software is being made available for individual research use only. +Any commercial use or redistribution of this software requires a license from +the Pohang University of Science and Technology. + +You may use this work subject to the following conditions: + +1. This work is provided "as is" by the copyright holder, with +absolutely no warranties of correctness, fitness, intellectual property +ownership, or anything else whatsoever. You use the work +entirely at your own risk. The copyright holder will not be liable for +any legal damages whatsoever connected with the use of this work. + +2. The copyright holder retain all copyright to the work. All copies of +the work and all works derived from it must contain (1) this copyright +notice, and (2) additional notices describing the content, dates and +copyright holder of modifications or additions made to the work, if +any, including distribution and use conditions and intellectual property +claims. Derived works must be clearly distinguished from the original +work, both by name and by the prominent inclusion of explicit +descriptions of overlaps and differences. + +3. The names and trademarks of the copyright holder may not be used in +advertising or publicity related to this work without specific prior +written permission. + +4. In return for the free use of this work, you are requested, but not +legally required, to do the following: + +* If you become aware of factors that may significantly affect other + users of the work, for example major bugs or + deficiencies or possible intellectual property issues, you are + requested to report them to the copyright holder, if possible + including redistributable fixes or workarounds. + +* If you use the work in scientific research or as part of a larger + software system, you are requested to cite the use in any related + publications or technical documentation. The work is based upon: + + Hyeonseob Nam, Bohyung Han. + Learning Multi-Domain Convolutional Neural Networks for Visual Tracking + CVPR, 2016. + + @InProceedings{nam2016mdnet, + author = {Nam, Hyeonseob and Han, Bohyung}, + title = {Learning Multi-Domain Convolutional Neural Networks for Visual Tracking}, + booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2016} + } + +This copyright notice must be retained with all copies of the software, +including any modified or derived versions. 
diff --git a/PyTorch/contrib/cv/video/MDNet/README.md b/PyTorch/contrib/cv/video/MDNet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..86c000dbfcb5fec89f960312bb1d58f44c85f742
--- /dev/null
+++ b/PyTorch/contrib/cv/video/MDNet/README.md
@@ -0,0 +1,158 @@
+# MDNet for PyTorch
+
+- [Overview](#overview)
+- [Preparing the Training Environment](#preparing-the-training-environment)
+- [Starting Training](#starting-training)
+- [Training Results](#training-results)
+- [Release Notes](#release-notes)
+
+
+
+# Overview
+
+## Summary
+
+MDNet is a visual tracking algorithm built on the representations of a discriminatively trained convolutional neural network (CNN). The CNN is pretrained on a large number of videos with tracking ground truth to obtain a generic target representation. The network consists of shared layers and multiple branches of domain-specific layers, where each domain corresponds to one training sequence and each branch performs binary classification to identify the target in its domain. The network is trained iteratively over the domains so that the shared layers learn a generic target representation. When tracking a target in a new sequence, a new network is built by combining the shared layers of the pretrained CNN with a new binary classification layer that is updated online. Online tracking is performed by evaluating candidate windows randomly sampled around the previous target state.
+
+- Reference implementation:
+
+  ```
+  url=https://github.com/hyeonseobnam/py-MDNet
+  commit_id=680fa4d
+  ```
+
+- Implementation adapted for Ascend AI processors:
+
+  ```
+  url=https://gitee.com/ascend/ModelZoo-PyTorch.git
+  code_path=PyTorch/contrib/cv/video
+  ```
+
+- Get the code via Git as follows:
+
+  ```
+  git clone https://gitee.com/ascend/ModelZoo-PyTorch.git  # clone the repository
+  cd PyTorch/contrib/cv/video/MDNet                        # switch to the model code path
+  ```
+
+- Alternatively, click "立即下载" (Download Now) to download the source package.
+
+# Preparing the Training Environment
+
+## Preparing the Environment
+
+- The firmware/driver, CANN, and PyTorch versions supported by this model are listed in the table below.
+
+  **Table 1** Version compatibility
+
+  | Component       | Version                                                       |
+  | --------------- | ------------------------------------------------------------ |
+  | Firmware/Driver | [1.0.15](https://www.hiascend.com/hardware/firmware-drivers?tag=commercial) |
+  | CANN            | [5.1.RC1](https://www.hiascend.com/software/cann/commercial?version=5.1.RC1) |
+  | PyTorch         | [1.8.1](https://gitee.com/ascend/pytorch/tree/master/) or [1.5.0](https://gitee.com/ascend/pytorch/tree/v1.5.0/) |
+
+- Environment setup guide.
+
+  See [Preparing the PyTorch Training Environment](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes).
+
+- Install dependencies.
+
+  ```
+  pip install -r requirements.txt
+  ```
+
+
+## Preparing the Dataset
+
+1. Get the dataset.
+
+   Obtain the original datasets yourself; usable open-source datasets include VOT, ImageNet-VID, OTB, etc. Upload the dataset to any path on the server and extract it.
+
+   Training uses the VOT dataset as an example; place it under datasets/VOT in the source package root. The directory structure is shown below. The VOT data consists of VOT2013, VOT2014, and VOT2015.
+
+   ```
+   ├── datasets
+        ├──VOT
+            ├──vot2013
+               │──video1
+               │──video2
+               │   ...
+            ├──vot2014
+               │──video1
+               │──video2
+               │   ...
+            ├──vot2015
+               │──video1
+               │──video2
+               │   ...
+   ```
+
+   Inference uses the OTB dataset as an example; place it under datasets/OTB50 in the source package root. The directory structure is shown below. The OTB data is OTB50.
+
+   ```
+   ├── datasets
+        ├──OTB50
+            │──video1
+            │──video2
+            │   ...
+   ```
+
+   > **Note:**
+   > The training scripts for this dataset are provided only as a reference example.
+
+2. Preprocess the data.
+
+   Run
+
+   ```
+   python3 pretrain/prepro_vot.py
+   ```
+
+   to preprocess the VOT dataset.
+
+## Getting the Pretrained Model
+
+Follow README_raw.md from the original repository to obtain the pretrained model, and place the downloaded imagenet-vgg-m.mat file under the models directory in the source package root.
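+
+The snippet below is a minimal, illustrative sketch (it is not required for training); it assumes the default paths from pretrain/options_vot.yaml and shows what the preprocessing step produces and how pretrain/train_mdnet.py consumes it together with the VGG-M weights.
+
+```python
+import pickle
+from modules.model import MDNet
+
+# vot-otb.pkl maps each sequence name to its frame paths and (x, y, w, h) ground-truth boxes
+with open('pretrain/data/vot-otb.pkl', 'rb') as fp:
+    data = pickle.load(fp)
+print(len(data), 'training sequences (domains)')
+
+# one fc6 branch per domain; the conv layers are initialized from the converted VGG-M .mat file
+model = MDNet('models/imagenet-vgg-m.mat', K=len(data))
+```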
+
+# Starting Training
+
+## Training the Model
+
+1. Enter the root directory of the extracted source package.
+
+   ```
+   cd /${model_folder_name}
+   ```
+
+2. Run the training script.
+
+   The model supports single-machine, single-card training.
+
+   - Single-machine single-card training
+
+     Start 1p training.
+
+     ```
+     bash ./test/train_full_1p.sh
+     ```
+
+   After training finishes, the weight file is saved under the models directory in the source package root, and the training accuracy and performance are printed.
+   The training and evaluation configurations can be modified in pretrain/options_*.yaml and tracking/options.yaml in the source package root, respectively.
+
+# Training Results
+
+**Table 2** Training results
+
+| NAME   | Acc   | FPS   | Epochs | AMP_Type |
+| ------ | ----- | ----- | ------ | -------- |
+| 1p-GPU | 0.904 | 61.68 | 50     | O1       |
+| 1p-NPU | 0.911 | 13.64 | 50     | O1       |
+
+# Release Notes
+
+## Changes
+
+2020.11.30: First release.
+
+## Known Issues
+
+LocalResponseNorm currently has a precision issue on the NPU, so it is computed on the CPU instead, which reduces the model's performance.
diff --git a/PyTorch/contrib/cv/video/MDNet/README_raw.md b/PyTorch/contrib/cv/video/MDNet/README_raw.md
new file mode 100755
index 0000000000000000000000000000000000000000..7f172f9a9240612683312cad26b101a4d38cd35c
--- /dev/null
+++ b/PyTorch/contrib/cv/video/MDNet/README_raw.md
@@ -0,0 +1,62 @@
+# py-MDNet
+
+by [Hyeonseob Nam](https://hyeonseobnam.github.io/) and [Bohyung Han](http://cvlab.postech.ac.kr/~bhhan/) at POSTECH
+
+**Update (April, 2019)**
+- Migration to python 3.6 & pyTorch 1.0
+- Efficiency improvement (~5fps)
+- ImageNet-VID pretraining
+- Code refactoring
+
+## Introduction
+PyTorch implementation of MDNet, which runs at ~5fps with a single CPU core and a single GPU (GTX 1080 Ti).
+#### [[Project]](http://cvlab.postech.ac.kr/research/mdnet/) [[Paper]](https://arxiv.org/abs/1510.07945) [[Matlab code]](https://github.com/HyeonseobNam/MDNet)
+
+If you're using this code for your research, please cite:
+
+    @InProceedings{nam2016mdnet,
+    author = {Nam, Hyeonseob and Han, Bohyung},
+    title = {Learning Multi-Domain Convolutional Neural Networks for Visual Tracking},
+    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+    month = {June},
+    year = {2016}
+    }
+
+## Results on OTB
+- Raw results of MDNet pretrained on **VOT-OTB** (VOT13,14,15 excluding OTB): [Google drive link](https://drive.google.com/open?id=1ZSCj1UEn4QhoRypgH28hVxSgWbI8q8Hl)
+- Raw results of MDNet pretrained on **ImageNet-VID**: [Google drive link](https://drive.google.com/open?id=14lJGcumtBRmtpZhmgY1BsrbEQixfhIpP)
+
+
+
+## Prerequisites
+- python 3.6+
+- opencv 3.0+
+- [PyTorch 1.0+](http://pytorch.org/) and its dependencies
+- for GPU support: a GPU with ~3G memory
+
+## Usage
+
+### Tracking
+```bash
+ python tracking/run_tracker.py -s DragonBaby [-d (display fig)] [-f (save fig)]
+```
+ - You can provide a sequence configuration in two ways (see tracking/gen_config.py):
+   - ```python tracking/run_tracker.py -s [seq name]```
+   - ```python tracking/run_tracker.py -j [json path]```
+
+### Pretraining
+ - Download [VGG-M](http://www.vlfeat.org/matconvnet/models/imagenet-vgg-m.mat) (matconvnet model) and save as "models/imagenet-vgg-m.mat"
+ - Pretraining on VOT-OTB
+   - Download [VOT](http://www.votchallenge.net/) datasets into "datasets/VOT/vot201x"
+   ``` bash
+   python pretrain/prepro_vot.py
+   python pretrain/train_mdnet.py -d vot
+   ```
+ - Pretraining on ImageNet-VID
+   - Download [ImageNet-VID](http://bvisionweb1.cs.unc.edu/ilsvrc2015/download-videos-3j16.php#vid) dataset into "datasets/ILSVRC"
+   ``` bash
+   python pretrain/prepro_imagenet.py
+   python pretrain/train_mdnet.py -d imagenet
+   ```
diff --git a/PyTorch/contrib/cv/video/MDNet/datasets/list/vot-otb.txt b/PyTorch/contrib/cv/video/MDNet/datasets/list/vot-otb.txt
new file mode 100755
index 0000000000000000000000000000000000000000..768656babfaf78bfbd33a37f1a6220849295c44e
--- /dev/null
+++ b/PyTorch/contrib/cv/video/MDNet/datasets/list/vot-otb.txt
@@ -0,0 +1,58 @@
+vot2013/cup
+vot2013/iceskater
+vot2013/juice +vot2014/ball +vot2014/bicycle +vot2014/drunk +vot2014/fish1 +vot2014/hand1 +vot2014/polarbear +vot2014/sphere +vot2014/sunshade +vot2014/surfing +vot2014/torus +vot2014/tunnel +vot2015/bag +vot2015/ball1 +vot2015/ball2 +vot2015/birds1 +vot2015/birds2 +vot2015/blanket +vot2015/bmx +vot2015/book +vot2015/butterfly +vot2015/crossing +vot2015/dinosaur +vot2015/fernando +vot2015/fish1 +vot2015/fish2 +vot2015/fish3 +vot2015/fish4 +vot2015/glove +vot2015/godfather +vot2015/graduate +vot2015/gymnastics1 +vot2015/gymnastics2 +vot2015/gymnastics3 +vot2015/gymnastics4 +vot2015/hand +vot2015/handball1 +vot2015/handball2 +vot2015/helicopter +vot2015/iceskater1 +vot2015/leaves +vot2015/marching +vot2015/motocross2 +vot2015/nature +vot2015/octopus +vot2015/rabbit +vot2015/racing +vot2015/road +vot2015/sheep +vot2015/singer3 +vot2015/soccer2 +vot2015/soldier +vot2015/sphere +vot2015/traffic +vot2015/tunnel +vot2015/wiper diff --git a/PyTorch/contrib/cv/video/MDNet/models/.keep b/PyTorch/contrib/cv/video/MDNet/models/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/MDNet/modelzoo_level.txt b/PyTorch/contrib/cv/video/MDNet/modelzoo_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..a829ab59b97a1022dd6fc33b59b7ae0d55009432 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/modelzoo_level.txt @@ -0,0 +1,3 @@ +FuncStatus:OK +PerfStatus:NOK +PrecisionStatus:OK \ No newline at end of file diff --git a/PyTorch/contrib/cv/video/MDNet/modules/__init__.py b/PyTorch/contrib/cv/video/MDNet/modules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/MDNet/modules/model.py b/PyTorch/contrib/cv/video/MDNet/modules/model.py new file mode 100755 index 0000000000000000000000000000000000000000..558381a256bbe33cae9413ddc88215ade7ab0fb0 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/modules/model.py @@ -0,0 +1,198 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
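+
+# MDNet backbone and training utilities. conv1-conv3 (initialized from VGG-M) and fc4-fc5 are
+# shared across domains, while each of the K fc6 branches performs binary target/background
+# classification for one training sequence. LRNCPU runs LocalResponseNorm on the CPU to work
+# around an NPU precision issue; BCELoss, Accuracy and Precision all operate on the
+# (pos_score, neg_score) pairs produced by the network.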
+ +import os +import scipy.io +import numpy as np +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + + +def append_params(params, module, prefix): + for child in module.children(): + for k,p in child._parameters.items(): + if p is None: continue + + if isinstance(child, nn.BatchNorm2d): + name = prefix + '_bn_' + k + else: + name = prefix + '_' + k + + if name not in params: + params[name] = p + else: + raise RuntimeError('Duplicated param name: {:s}'.format(name)) + + +def set_optimizer(model, lr_base, lr_mult, train_all=False, momentum=0.9, w_decay=0.0005): + if train_all: + params = model.get_all_params() + else: + params = model.get_learnable_params() + param_list = [] + for k, p in params.items(): + lr = lr_base + for l, m in lr_mult.items(): + if k.startswith(l): + lr = lr_base * m + param_list.append({'params': [p], 'lr':lr}) + optimizer = optim.SGD(param_list, lr = lr, momentum=momentum, weight_decay=w_decay) + return optimizer + + +class LRNCPU(nn.Module): + # LocalResponseNorm has a precision issue on npu, move to cpu. + # See https://e.gitee.com/HUAWEI-ASCEND/dashboard?issue=I5ST2D. + def forward(self, input): + input = input.cpu() + input = nn.LocalResponseNorm(2)(input) + input = input.npu() + return input + + +class MDNet(nn.Module): + def __init__(self, model_path=None, K=1): + super(MDNet, self).__init__() + self.K = K + self.layers = nn.Sequential(OrderedDict([ + ('conv1', nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2), + nn.ReLU(inplace=True), + LRNCPU(), + # nn.LocalResponseNorm(2), + nn.MaxPool2d(kernel_size=3, stride=2))), + ('conv2', nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=2), + nn.ReLU(inplace=True), + LRNCPU(), + # nn.LocalResponseNorm(2), + nn.MaxPool2d(kernel_size=3, stride=2))), + ('conv3', nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=1), + nn.ReLU(inplace=True))), + ('fc4', nn.Sequential(nn.Linear(512 * 3 * 3, 512), + nn.ReLU(inplace=True))), + ('fc5', nn.Sequential(nn.Dropout(0.5), + nn.Linear(512, 512), + nn.ReLU(inplace=True)))])) + + self.branches = nn.ModuleList([nn.Sequential(nn.Dropout(0.5), + nn.Linear(512, 2)) for _ in range(K)]) + + for m in self.layers.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0.1) + for m in self.branches.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + if model_path is not None: + if os.path.splitext(model_path)[1] == '.pth': + self.load_model(model_path) + elif os.path.splitext(model_path)[1] == '.mat': + self.load_mat_model(model_path) + else: + raise RuntimeError('Unkown model format: {:s}'.format(model_path)) + self.build_param_dict() + + def build_param_dict(self): + self.params = OrderedDict() + for name, module in self.layers.named_children(): + append_params(self.params, module, name) + for k, module in enumerate(self.branches): + append_params(self.params, module, 'fc6_{:d}'.format(k)) + + def set_learnable_params(self, layers): + for k, p in self.params.items(): + if any([k.startswith(l) for l in layers]): + p.requires_grad = True + else: + p.requires_grad = False + + def get_learnable_params(self): + params = OrderedDict() + for k, p in self.params.items(): + if p.requires_grad: + params[k] = p + return params + + def get_all_params(self): + params = OrderedDict() + for k, p in self.params.items(): + params[k] = p + return params + + def forward(self, x, k=0, 
in_layer='conv1', out_layer='fc6'): + # forward model from in_layer to out_layer + run = False + for name, module in self.layers.named_children(): + if name == in_layer: + run = True + if run: + x = module(x) + if name == 'conv3': + x = x.reshape(x.size(0), -1) + if name == out_layer: + return x + + x = self.branches[k](x) + if out_layer=='fc6': + return x + elif out_layer=='fc6_softmax': + return F.softmax(x, dim=1) + + def load_model(self, model_path): + states = torch.load(model_path) + shared_layers = states['shared_layers'] + self.layers.load_state_dict(shared_layers) + + def load_mat_model(self, matfile): + mat = scipy.io.loadmat(matfile) + mat_layers = list(mat['layers'])[0] + + # copy conv weights + for i in range(3): + weight, bias = mat_layers[i * 4]['weights'].item()[0] + self.layers[i][0].weight.data = torch.from_numpy(np.transpose(weight, (3, 2, 0, 1))) + self.layers[i][0].bias.data = torch.from_numpy(bias[:, 0]) + + +class BCELoss(nn.Module): + def forward(self, pos_score, neg_score, average=True): + pos_loss = -F.log_softmax(pos_score, dim=1)[:, 1] + neg_loss = -F.log_softmax(neg_score, dim=1)[:, 0] + + loss = pos_loss.sum() + neg_loss.sum() + if average: + loss /= (pos_loss.size(0) + neg_loss.size(0)) + return loss + + +class Accuracy(): + def __call__(self, pos_score, neg_score): + pos_correct = (pos_score[:, 1] > pos_score[:, 0]).sum().float() + neg_correct = (neg_score[:, 1] < neg_score[:, 0]).sum().float() + acc = (pos_correct + neg_correct) / (pos_score.size(0) + neg_score.size(0) + 1e-8) + return acc.item() + + +class Precision(): + def __call__(self, pos_score, neg_score): + scores = torch.cat((pos_score[:, 1], neg_score[:, 1]), 0) + topk = torch.topk(scores, pos_score.size(0))[1] + prec = (topk < pos_score.size(0)).float().sum() / (pos_score.size(0) + 1e-8) + return prec.item() diff --git a/PyTorch/contrib/cv/video/MDNet/modules/sample_generator.py b/PyTorch/contrib/cv/video/MDNet/modules/sample_generator.py new file mode 100755 index 0000000000000000000000000000000000000000..763269a3f3af04e0bb1125225d8f2f8b2bd893de --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/modules/sample_generator.py @@ -0,0 +1,110 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
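+
+# SampleGenerator draws candidate boxes around a target box using 'gaussian', 'uniform' or
+# 'whole'-image sampling; __call__ can additionally filter candidates by IoU and scale range,
+# which is how positive/negative training examples are selected.
+# Illustrative usage (values taken from tracking/options.yaml):
+#   gen = SampleGenerator('gaussian', image.size, trans=0.6, scale=1.05)
+#   candidates = gen(target_bbox, 256)   # 256 boxes in (min_x, min_y, w, h) format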
+ +import numpy as np +from PIL import Image + +from .utils import overlap_ratio + + +class SampleGenerator(): + def __init__(self, type_, img_size, trans=1, scale=1, aspect=None, valid=False): + self.type = type_ + self.img_size = np.array(img_size) # (w, h) + self.trans = trans + self.scale = scale + self.aspect = aspect + self.valid = valid + + def _gen_samples(self, bb, n): + # + # bb: target bbox (min_x,min_y,w,h) + bb = np.array(bb, dtype='float32') + + # (center_x, center_y, w, h) + sample = np.array([bb[0] + bb[2] / 2, bb[1] + bb[3] / 2, bb[2], bb[3]], dtype='float32') + samples = np.tile(sample[None, :], (n ,1)) + + # vary aspect ratio + if self.aspect is not None: + ratio = np.random.rand(n, 2) * 2 - 1 + samples[:, 2:] *= self.aspect ** ratio + + # sample generation + if self.type == 'gaussian': + samples[:, :2] += self.trans * np.mean(bb[2:]) * np.clip(0.5 * np.random.randn(n, 2), -1, 1) + samples[:, 2:] *= self.scale ** np.clip(0.5 * np.random.randn(n, 1), -1, 1) + + elif self.type == 'uniform': + samples[:, :2] += self.trans * np.mean(bb[2:]) * (np.random.rand(n, 2) * 2 - 1) + samples[:, 2:] *= self.scale ** (np.random.rand(n, 1) * 2 - 1) + + elif self.type == 'whole': + m = int(2 * np.sqrt(n)) + xy = np.dstack(np.meshgrid(np.linspace(0, 1, m), np.linspace(0, 1, m))).reshape(-1, 2) + xy = np.random.permutation(xy)[:n] + samples[:, :2] = bb[2:] / 2 + xy * (self.img_size - bb[2:] / 2 - 1) + samples[:, 2:] *= self.scale ** (np.random.rand(n, 1) * 2 - 1) + + # adjust bbox range + samples[:, 2:] = np.clip(samples[:, 2:], 10, self.img_size - 10) + if self.valid: + samples[:, :2] = np.clip(samples[:, :2], samples[:, 2:] / 2, self.img_size - samples[:, 2:] / 2 - 1) + else: + samples[:, :2] = np.clip(samples[:, :2], 0, self.img_size) + + # (min_x, min_y, w, h) + samples[:, :2] -= samples[:, 2:] / 2 + + return samples + + def __call__(self, bbox, n, overlap_range=None, scale_range=None): + + if overlap_range is None and scale_range is None: + return self._gen_samples(bbox, n) + + else: + samples = None + remain = n + factor = 2 + while remain > 0 and factor < 16: + samples_ = self._gen_samples(bbox, remain * factor) + + idx = np.ones(len(samples_), dtype=bool) + if overlap_range is not None: + r = overlap_ratio(samples_, bbox) + idx *= (r >= overlap_range[0]) * (r <= overlap_range[1]) + if scale_range is not None: + s = np.prod(samples_[:, 2:], axis=1) / np.prod(bbox[2:]) + idx *= (s >= scale_range[0]) * (s <= scale_range[1]) + + samples_ = samples_[idx, :] + samples_ = samples_[:min(remain, len(samples_))] + if samples is None: + samples = samples_ + else: + samples = np.concatenate([samples, samples_]) + remain = n - len(samples) + factor = factor * 2 + + return samples + + def set_type(self, type_): + self.type = type_ + + def set_trans(self, trans): + self.trans = trans + + def expand_trans(self, trans_limit): + self.trans = min(self.trans * 1.1, trans_limit) diff --git a/PyTorch/contrib/cv/video/MDNet/modules/utils.py b/PyTorch/contrib/cv/video/MDNet/modules/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..ed64a413a8cfa3ea8117759dc2273262b0ab96fd --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/modules/utils.py @@ -0,0 +1,159 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PIL import Image +import numpy as np +import cv2 + + +def overlap_ratio(rect1, rect2): + ''' + Compute overlap ratio between two rects + - rect: 1d array of [x,y,w,h] or + 2d array of N x [x,y,w,h] + ''' + + if rect1.ndim == 1: + rect1 = rect1[None, :] + if rect2.ndim == 1: + rect2 = rect2[None, :] + + left = np.maximum(rect1[:, 0], rect2[:, 0]) + right = np.minimum(rect1[:, 0] + rect1[:, 2], rect2[:, 0] + rect2[:, 2]) + top = np.maximum(rect1[:, 1], rect2[:, 1]) + bottom = np.minimum(rect1[:, 1] + rect1[:, 3], rect2[:, 1] + rect2[:, 3]) + + intersect = np.maximum(0, right - left) * np.maximum(0, bottom - top) + union = rect1[:, 2] * rect1[:, 3] + rect2[:, 2] * rect2[:, 3] - intersect + iou = np.clip(intersect / union, 0, 1) + return iou + +def center_dist(rect1, rect2): + if rect1.ndim == 1: + rect1 = rect1[None, :] + if rect2.ndim == 1: + rect2 = rect2[None, :] + + x1 = rect1[:, 0] + rect1[:, 2] / 2 + y1 = rect1[:, 1] + rect1[:, 3] / 2 + x2 = rect2[:, 0] + rect2[:, 2] / 2 + y2 = rect2[:, 1] + rect2[:, 3] / 2 + dist = np.sqrt(np.power(x1 - x2, 2) + np.power(y1 - y2, 2)) + return dist + +def crop_image2(img, bbox, img_size=107, padding=16, flip=False, rotate_limit=0, blur_limit=0): + x, y, w, h = np.array(bbox, dtype='float32') + + cx, cy = x + w/2, y + h/2 + + if padding > 0: + w += 2 * padding * w/img_size + h += 2 * padding * h/img_size + + # List of transformation matrices + matrices = [] + + # Translation matrix to move patch center to origin + translation_matrix = np.asarray([[1, 0, -cx], + [0, 1, -cy], + [0, 0, 1]], dtype=np.float32) + matrices.append(translation_matrix) + + # Scaling matrix according to image size + scaling_matrix = np.asarray([[img_size / w, 0, 0], + [0, img_size / h, 0], + [0, 0, 1]], dtype=np.float32) + matrices.append(scaling_matrix) + + # Define flip matrix + if flip and np.random.binomial(1, 0.5): + flip_matrix = np.eye(3, dtype=np.float32) + flip_matrix[0, 0] = -1 + matrices.append(flip_matrix) + + # Define rotation matrix + if rotate_limit and np.random.binomial(1, 0.5): + angle = np.random.uniform(-rotate_limit, rotate_limit) + alpha = np.cos(np.deg2rad(angle)) + beta = np.sin(np.deg2rad(angle)) + rotation_matrix = np.asarray([[alpha, -beta, 0], + [beta, alpha, 0], + [0, 0, 1]], dtype=np.float32) + matrices.append(rotation_matrix) + + # Translation matrix to move patch center from origin + revert_t_matrix = np.asarray([[1, 0, img_size / 2], + [0, 1, img_size / 2], + [0, 0, 1]], dtype=np.float32) + matrices.append(revert_t_matrix) + + # Aggregate all transformation matrices + matrix = np.eye(3) + for m_ in matrices: + matrix = np.matmul(m_, matrix) + + # Warp image, padded value is set to 128 + patch = cv2.warpPerspective(img, + matrix, + (img_size, img_size), + borderValue=128) + + if blur_limit and np.random.binomial(1, 0.5): + blur_size = np.random.choice(np.arange(1, blur_limit + 1, 2)) + patch = cv2.GaussianBlur(patch, (blur_size, blur_size), 0) + + return patch + + +def crop_image(img, bbox, img_size=107, padding=16, valid=False): + # This function is deprecated in favor of crop_image2 + + x,y,w,h = np.array(bbox, dtype='float32') + 
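+    # (x, y, w, h) is the target box; the code below builds a padded crop window around its
+    # centre, fills any out-of-image area with 128, and resizes the result to img_size x img_size.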
+ half_w, half_h = w / 2, h / 2 + center_x, center_y = x + half_w, y + half_h + + if padding > 0: + pad_w = padding * w / img_size + pad_h = padding * h / img_size + half_w += pad_w + half_h += pad_h + + img_h, img_w, _ = img.shape + min_x = int(center_x - half_w + 0.5) + min_y = int(center_y - half_h + 0.5) + max_x = int(center_x + half_w + 0.5) + max_y = int(center_y + half_h + 0.5) + + if valid: + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(img_w, max_x) + max_y = min(img_h, max_y) + + if min_x >=0 and min_y >= 0 and max_x <= img_w and max_y <= img_h: + cropped = img[min_y:max_y, min_x:max_x, :] + + else: + min_x_val = max(0, min_x) + min_y_val = max(0, min_y) + max_x_val = min(img_w, max_x) + max_y_val = min(img_h, max_y) + + cropped = 128 * np.ones((max_y - min_y, max_x - min_x, 3), dtype='uint8') + cropped[min_y_val - min_y:max_y_val - min_y, min_x_val - min_x:max_x_val - min_x, :] \ + = img[min_y_val:max_y_val, min_x_val:max_x_val, :] + + scaled = np.array(Image.fromarray(cropped).resize((img_size, img_size))) + return scaled diff --git a/PyTorch/contrib/cv/video/MDNet/pretrain/data/.keep b/PyTorch/contrib/cv/video/MDNet/pretrain/data/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/MDNet/pretrain/data_prov.py b/PyTorch/contrib/cv/video/MDNet/pretrain/data_prov.py new file mode 100755 index 0000000000000000000000000000000000000000..1ccabc40831362b1a3ab19001ef0b1c2bae6dc86 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/pretrain/data_prov.py @@ -0,0 +1,92 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
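+
+# RegionDataset serves one training sequence (domain): each call to next() takes a batch of
+# frames in shuffled order and returns positive/negative regions (selected by IoU with the
+# ground truth) as zero-centred float tensors for MDNet.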
+ +import numpy as np +from PIL import Image + +import torch +import torch.utils.data as data + +from modules.sample_generator import SampleGenerator +from modules.utils import crop_image2 + + +class RegionDataset(data.Dataset): + def __init__(self, img_list, gt, opts): + self.img_list = np.asarray(img_list) + self.gt = gt + + self.batch_frames = opts['batch_frames'] + self.batch_pos = opts['batch_pos'] + self.batch_neg = opts['batch_neg'] + + self.overlap_pos = opts['overlap_pos'] + self.overlap_neg = opts['overlap_neg'] + + self.crop_size = opts['img_size'] + self.padding = opts['padding'] + + self.flip = opts.get('flip', False) + self.rotate = opts.get('rotate', 0) + self.blur = opts.get('blur', 0) + + self.index = np.random.permutation(len(self.img_list)) + self.pointer = 0 + + image = Image.open(self.img_list[0]).convert('RGB') + self.pos_generator = SampleGenerator('uniform', image.size, + opts['trans_pos'], opts['scale_pos']) + self.neg_generator = SampleGenerator('uniform', image.size, + opts['trans_neg'], opts['scale_neg']) + + def __iter__(self): + return self + + def __next__(self): + next_pointer = min(self.pointer + self.batch_frames, len(self.img_list)) + idx = self.index[self.pointer:next_pointer] + if len(idx) < self.batch_frames: + self.index = np.random.permutation(len(self.img_list)) + next_pointer = self.batch_frames - len(idx) + idx = np.concatenate((idx, self.index[:next_pointer])) + self.pointer = next_pointer + + pos_regions = np.empty((0, 3, self.crop_size, self.crop_size), dtype='float32') + neg_regions = np.empty((0, 3, self.crop_size, self.crop_size), dtype='float32') + for i, (img_path, bbox) in enumerate(zip(self.img_list[idx], self.gt[idx])): + image = Image.open(img_path).convert('RGB') + image = np.asarray(image) + + n_pos = (self.batch_pos - len(pos_regions)) // (self.batch_frames - i) + n_neg = (self.batch_neg - len(neg_regions)) // (self.batch_frames - i) + pos_examples = self.pos_generator(bbox, n_pos, overlap_range=self.overlap_pos) + neg_examples = self.neg_generator(bbox, n_neg, overlap_range=self.overlap_neg) + + pos_regions = np.concatenate((pos_regions, self.extract_regions(image, pos_examples)), axis=0) + neg_regions = np.concatenate((neg_regions, self.extract_regions(image, neg_examples)), axis=0) + + pos_regions = torch.from_numpy(pos_regions) + neg_regions = torch.from_numpy(neg_regions) + return pos_regions, neg_regions + + next = __next__ + + def extract_regions(self, image, samples): + regions = np.zeros((len(samples), self.crop_size, self.crop_size, 3), dtype='uint8') + for i, sample in enumerate(samples): + regions[i] = crop_image2(image, sample, self.crop_size, self.padding, + self.flip, self.rotate, self.blur) + regions = regions.transpose(0, 3, 1, 2) + regions = regions.astype('float32') - 128. 
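+        # regions: N x 3 x crop_size x crop_size float32 arrays, roughly zero-centred pixels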
+ return regions diff --git a/PyTorch/contrib/cv/video/MDNet/pretrain/options_imagenet.yaml b/PyTorch/contrib/cv/video/MDNet/pretrain/options_imagenet.yaml new file mode 100755 index 0000000000000000000000000000000000000000..b0f550d139caa1592a0e8b173d77fda25d91e64e --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/pretrain/options_imagenet.yaml @@ -0,0 +1,38 @@ +use_gpu: true + +# data path +data_path: "pretrain/data/imagenet_vid.pkl" + +# model path +init_model_path: "models/imagenet-vgg-m.mat" +model_path: "models/mdnet_imagenet_vid.pth" + +# input size +img_size: 107 +padding: 16 + +# batch size +batch_frames: 8 +batch_pos: 32 +batch_neg: 96 +batch_accum: 50 + +# training examples sampling +trans_pos: 0.1 +scale_pos: 1.3 +trans_neg: 2 +scale_neg: 1.6 +overlap_pos: [0.7, 1] +overlap_neg: [0, 0.5] + +# augmentation +flip: True +rotate: 30 +blur: 7 + +# training +lr: 0.0001 +grad_clip: 10 +lr_mult: {"fc": 10} +ft_layers: ["conv", "fc"] +n_cycles: 100 diff --git a/PyTorch/contrib/cv/video/MDNet/pretrain/options_vot.yaml b/PyTorch/contrib/cv/video/MDNet/pretrain/options_vot.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7375ac57fd999dfbec518a460d4ca149a8f005d4 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/pretrain/options_vot.yaml @@ -0,0 +1,32 @@ +use_gpu: true + +# data path +data_path: "pretrain/data/vot-otb.pkl" + +# model path +init_model_path: "models/imagenet-vgg-m.mat" +model_path: "models/mdnet_vot-otb.pth" + +# input size +img_size: 107 +padding: 16 + +# batch size +batch_frames: 8 +batch_pos: 32 +batch_neg: 96 + +# training examples sampling +trans_pos: 0.1 +scale_pos: 1.3 +trans_neg: 2 +scale_neg: 1.6 +overlap_pos: [0.7, 1] +overlap_neg: [0, 0.5] + +# training +lr: 0.0001 +grad_clip: 10 +lr_mult: {"fc": 10} +ft_layers: ["conv", "fc"] +n_cycles: 50 diff --git a/PyTorch/contrib/cv/video/MDNet/pretrain/prepro_imagenet.py b/PyTorch/contrib/cv/video/MDNet/pretrain/prepro_imagenet.py new file mode 100755 index 0000000000000000000000000000000000000000..69a010008fa66f1faef820b127309bdfea5ae138 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/pretrain/prepro_imagenet.py @@ -0,0 +1,94 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
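+
+# Builds pretrain/data/imagenet_vid.pkl from the ILSVRC VID training set: frames are kept only
+# when the object with trackid 0 is annotated and not overly large, and each sequence is stored
+# as its image paths plus (x, y, w, h) ground-truth boxes.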
+ +import os +import numpy as np +import pickle +from collections import OrderedDict + +import xml.etree.ElementTree +import xmltodict + +seq_home = 'datasets/ILSVRC/' +output_path = 'pretrain/data/imagenet_vid.pkl' + +train_list = [p for p in os.listdir(seq_home + 'Data/VID/train')] +seq_list = [] +for num, cur_dir in enumerate(train_list): + seq_list += [os.path.join(cur_dir, p) for p in os.listdir(seq_home + 'Data/VID/train/' + cur_dir)] + +data = {} +completeNum = 0 +for i, seqname in enumerate(seq_list): + print('{}/{}: {}'.format(i, len(seq_list), seqname)) + seq_path = seq_home + 'Data/VID/train/' + seqname + gt_path = seq_home +'Annotations/VID/train/' + seqname + img_list = sorted([p for p in os.listdir(seq_path) if os.path.splitext(p)[1] == '.JPEG']) + + enable_gt = [] + enable_img_list = [] + save_enable = True + gt_list = sorted([os.path.join(gt_path, p) for p in os.listdir(gt_path) if os.path.splitext(p)[1] == '.xml']) + + for gidx in range(0, len(img_list)): + with open(gt_list[gidx]) as fd: + doc = xmltodict.parse(fd.read()) + try: + try: + object_ = doc['annotation']['object'][0] + except: + object_ = doc['annotation']['object'] + except: + ## no object, occlusion and hidden etc. + continue + + if int(object_['trackid']) != 0: + continue + + xmin = float(object_['bndbox']['xmin']) + xmax = float(object_['bndbox']['xmax']) + ymin = float(object_['bndbox']['ymin']) + ymax = float(object_['bndbox']['ymax']) + + ## discard too big object + if (float(doc['annotation']['size']['width']) / 2. < xmax - xmin ) and \ + (float(doc['annotation']['size']['height']) / 2. < ymax - ymin ): + continue + + cur_gt = np.zeros((4)) + cur_gt[0] = xmin + cur_gt[1] = ymin + cur_gt[2] = xmax - xmin + cur_gt[3] = ymax - ymin + + enable_gt.append(cur_gt) + enable_img_list.append(img_list[gidx]) + + if len(enable_img_list) == 0: + save_enable = False + + if save_enable: + assert len(enable_img_list) == len(enable_gt), "Lengths do not match!!" + enable_img_list = [os.path.join(seq_path, p) for p in enable_img_list] + data[seqname] = {'images':enable_img_list, 'gt':np.asarray(enable_gt)} + completeNum += 1 + print('Complete!') + +# Save db +output_dir = os.path.dirname(output_path) +os.makedirs(output_dir, exist_ok=True) +with open(output_path, 'wb') as fp: + pickle.dump(data, fp, -1) + +print('complete {} videos'.format(completeNum)) diff --git a/PyTorch/contrib/cv/video/MDNet/pretrain/prepro_vot.py b/PyTorch/contrib/cv/video/MDNet/pretrain/prepro_vot.py new file mode 100755 index 0000000000000000000000000000000000000000..5be8609f92dbc2ae2f90ddf763ce715303b9ebc8 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/pretrain/prepro_vot.py @@ -0,0 +1,52 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
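+
+# Builds pretrain/data/vot-otb.pkl for the sequences listed in datasets/list/vot-otb.txt;
+# 8-point polygon annotations are converted to axis-aligned (x, y, w, h) boxes.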
+ +import os +import numpy as np +import pickle +from collections import OrderedDict + +seq_home = 'datasets/VOT' +seqlist_path = 'datasets/list/vot-otb.txt' +output_path = 'pretrain/data/vot-otb.pkl' + +with open(seqlist_path,'r') as fp: + seq_list = fp.read().splitlines() + +# Construct db +data = OrderedDict() +for i, seq in enumerate(seq_list): + img_list = sorted([p for p in os.listdir(os.path.join(seq_home, seq, 'color')) if os.path.splitext(p)[1] == '.jpg']) + gt = np.loadtxt(os.path.join(seq_home, seq, 'groundtruth.txt'), delimiter=',') + + # if seq == 'vot2014/ball': + # img_list = img_list[1:] + + assert len(img_list) == len(gt), "Lengths do not match!!" + + if gt.shape[1] == 8: + x_min = np.min(gt[:, [0, 2, 4, 6]], axis=1)[:, None] + y_min = np.min(gt[:, [1, 3, 5, 7]], axis=1)[:, None] + x_max = np.max(gt[:, [0, 2, 4, 6]], axis=1)[:, None] + y_max = np.max(gt[:, [1, 3, 5, 7]], axis=1)[:, None] + gt = np.concatenate((x_min, y_min, x_max - x_min, y_max - y_min), axis=1) + + img_list = [os.path.join(seq_home, seq, 'color', img) for img in img_list] + data[seq] = {'images': img_list, 'gt': gt} + +# Save db +output_dir = os.path.dirname(output_path) +os.makedirs(output_dir, exist_ok=True) +with open(output_path, 'wb') as fp: + pickle.dump(data, fp) diff --git a/PyTorch/contrib/cv/video/MDNet/pretrain/train_mdnet.py b/PyTorch/contrib/cv/video/MDNet/pretrain/train_mdnet.py new file mode 100755 index 0000000000000000000000000000000000000000..9cf837d01fad7909fa1cb76c1ec80345245355a9 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/pretrain/train_mdnet.py @@ -0,0 +1,118 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
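+
+# Offline multi-domain pretraining: one fc6 branch per training sequence. Each cycle visits
+# every domain once in random order; gradients are accumulated over opts['batch_accum'] domains
+# (1 if unset) before the optimizer steps, and apex amp (O1, dynamic loss scale) provides mixed
+# precision on the NPU.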
+ +import os, sys +import pickle +import yaml +import time +import argparse +import numpy as np + +import torch +import torch_npu +from apex import amp + +sys.path.insert(0,'.') +from data_prov import RegionDataset +from modules.model import MDNet, set_optimizer, BCELoss, Precision + + +def train_mdnet(opts): + + # Init dataset + with open(opts['data_path'], 'rb') as fp: + data = pickle.load(fp) + K = len(data) + dataset = [None] * K + for k, seq in enumerate(data.values()): + dataset[k] = RegionDataset(seq['images'], seq['gt'], opts) + + # Init model + model = MDNet(opts['init_model_path'], K) + if opts['use_gpu']: + model = model.npu() + model.set_learnable_params(opts['ft_layers']) + + # Init criterion and optimizer + criterion = BCELoss() + evaluator = Precision() + optimizer = set_optimizer(model, opts['lr'], opts['lr_mult']) + + model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale='dynamic') + + # Main trainig loop + for i in range(opts['n_cycles']): + print('==== Start Cycle {:d}/{:d} ===='.format(i + 1, opts['n_cycles'])) + + if i in opts.get('lr_decay', []): + print('decay learning rate') + for param_group in optimizer.param_groups: + param_group['lr'] *= opts.get('gamma', 0.1) + + # Training + model.train() + prec = np.zeros(K) + k_list = np.random.permutation(K) + for j, k in enumerate(k_list): + tic = time.time() + # training + pos_regions, neg_regions = dataset[k].next() + if opts['use_gpu']: + pos_regions = pos_regions.npu() + neg_regions = neg_regions.npu() + pos_score = model(pos_regions, k) + neg_score = model(neg_regions, k) + + loss = criterion(pos_score, neg_score) + + batch_accum = opts.get('batch_accum', 1) + if j % batch_accum == 0: + model.zero_grad() + # loss.backward() + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + if j % batch_accum == batch_accum - 1 or j == len(k_list) - 1: + if 'grad_clip' in opts: + torch.nn.utils.clip_grad_norm_(model.parameters(), opts['grad_clip']) + optimizer.step() + + prec[k] = evaluator(pos_score, neg_score) + + toc = time.time()-tic + print('Cycle {:2d}/{:2d}, Iter {:2d}/{:2d} (Domain {:2d}), Loss {:.3f}, Precision {:.3f}, Time {:.3f}' + .format(i, opts['n_cycles'], j, len(k_list), k, loss.item(), prec[k], toc)) + + print('Mean Precision: {:.3f}'.format(prec.mean())) + print('Save model to {:s}'.format(opts['model_path'])) + if opts['use_gpu']: + model = model.cpu() + states = {'shared_layers': model.layers.state_dict()} + torch.save(states, opts['model_path']) + if opts['use_gpu']: + model = model.npu() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-d', '--dataset', default='imagenet', help='training dataset {vot, imagenet}') + parser.add_argument('--device_id', default=0) + args = parser.parse_args() + + opts = yaml.safe_load(open('pretrain/options_{}.yaml'.format(args.dataset), 'r')) + + np.random.seed(0) + torch.manual_seed(0) + torch.npu.set_device(int(args.device_id)) + + train_mdnet(opts) diff --git a/PyTorch/contrib/cv/video/MDNet/requirements.txt b/PyTorch/contrib/cv/video/MDNet/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e4adb6677bd6b79bb673e30727ef62ca45ecc17d --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/requirements.txt @@ -0,0 +1,7 @@ +matplotlib +numpy +opencv-python +Pillow +PyYAML +scikit-learn +scipy diff --git a/PyTorch/contrib/cv/video/MDNet/test/eval_1p.sh b/PyTorch/contrib/cv/video/MDNet/test/eval_1p.sh new file mode 100755 index 
0000000000000000000000000000000000000000..4cadd1695812afea3e4982ecceec42a8b9a3b6b7 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/test/eval_1p.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +#集合通信参数,不需要修改 +export RANK_SIZE=1 +#规避环境变量冲突 +if [ -f /usr/local/Ascend/bin/setenv.bash ];then + unset PYTHONPATH + source /usr/local/Ascend/bin/setenv.bash +fi +device_id=0 +#网络名称,同目录名称,需要模型审视修改 +Network="MDNet" + +#参数校验,不需要修改 +for para in $* +do + if [[ $para == --device_id* ]];then + device_id=`echo ${para#*=}` + fi +done + +# 校验是否指定了device_id,分动态分配device_id与手动指定device_id,此处不需要修改 +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" +elif [ ${device_id} ];then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + "[Error] device id must be config" + exit 1 +fi + +###############指定脚本执行路径############### +# cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + +#################创建日志输出目录,不需要修改################# +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +# eval +nohup python3 tracking/eval_otb.py \ + --device_id=$ASCEND_DEVICE_ID > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/eval_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + +#结果打印 +echo "------------------ Final result ------------------" +train_accuracy=`grep 'Mean Precision:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/eval_${ASCEND_DEVICE_ID}.log | tail -1 | awk -F 'Precision: ' '{print $2}'` +echo "Final Train Accuracy : ${train_accuracy}" diff --git a/PyTorch/contrib/cv/video/MDNet/test/train_full_1p.sh b/PyTorch/contrib/cv/video/MDNet/test/train_full_1p.sh new file mode 100755 index 0000000000000000000000000000000000000000..d5666fc7b7a059dc3750eac84907d939a7ab6eb6 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/test/train_full_1p.sh @@ -0,0 +1,109 @@ +#!/bin/bash + +#集合通信参数,不需要修改 +export RANK_SIZE=1 +#规避环境变量冲突 +if [ -f /usr/local/Ascend/bin/setenv.bash ];then + unset PYTHONPATH + source /usr/local/Ascend/bin/setenv.bash +fi +device_id=0 +#网络名称,同目录名称,需要模型审视修改 +Network="MDNet" + +#训练batch_size,需要模型审视修改 +batch_size=8 + +#参数校验,不需要修改 +for para in $* +do + if [[ $para == --device_id* ]];then + device_id=`echo ${para#*=}` + fi +done + +# 校验是否指定了device_id,分动态分配device_id与手动指定device_id,此处不需要修改 +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" +elif [ ${device_id} ];then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + "[Error] device id must be config" + exit 1 +fi + +###############指定训练脚本执行路径############### +# cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd .. 
+ cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + +#训练开始时间,不需要修改 +start_time=$(date +%s) + +#################创建日志输出目录,不需要修改################# +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +#执行训练脚本,以下传参不需要修改,其他需要模型审视修改 +nohup python3 pretrain/train_mdnet.py \ + --dataset vot \ + --device_id=$ASCEND_DEVICE_ID > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + +nohup python3 tracking/eval_otb.py \ + --device_id=$ASCEND_DEVICE_ID > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/eval_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + +#训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +#结果打印,不需要修改 +echo "------------------ Final result ------------------" +#单迭代训练时长 +TrainingTime=`grep 'Iter' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log | tail -n +150 | awk -F 'Time ' '{print $2}' | awk '{sum+=$1} END {print sum/NR}'` +#输出性能FPS,需要模型审视修改 +FPS=`awk 'BEGIN{printf "%.2f\n", '${batch_size}/${TrainingTime}'}'` +#打印,不需要修改 +echo "Final Performance images/sec : $FPS" + +#输出训练精度,需要模型审视修改 +train_accuracy=`grep 'Mean Precision:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/eval_${ASCEND_DEVICE_ID}.log | tail -1 | awk -F 'Precision: ' '{print $2}'` +#打印,不需要修改 +echo "Final Train Accuracy : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +#稳定性精度看护结果汇总 +#训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + +##获取性能数据,不需要修改 +#吞吐量 +ActualFPS=${FPS} + +#关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log diff --git a/PyTorch/contrib/cv/video/MDNet/test/train_performance_1p.sh b/PyTorch/contrib/cv/video/MDNet/test/train_performance_1p.sh new file mode 100755 index 0000000000000000000000000000000000000000..829a17e6ab1a68b38dd10bc53c7f3dc9779a8305 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/test/train_performance_1p.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +#集合通信参数,不需要修改 +export RANK_SIZE=1 +#规避环境变量冲突 +if [ -f /usr/local/Ascend/bin/setenv.bash ];then + unset PYTHONPATH + source /usr/local/Ascend/bin/setenv.bash +fi +device_id=0 +#网络名称,同目录名称,需要模型审视修改 +Network="MDNet" + +#训练batch_size,需要模型审视修改 +batch_size=8 + +#参数校验,不需要修改 +for para in $* +do + if [[ $para == --device_id* ]];then + device_id=`echo ${para#*=}` + fi +done + +# 校验是否指定了device_id,分动态分配device_id与手动指定device_id,此处不需要修改 +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" +elif [ ${device_id} ];then + export 
ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + "[Error] device id must be config" + exit 1 +fi + +###############指定训练脚本执行路径############### +# cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + +#训练开始时间,不需要修改 +start_time=$(date +%s) + +#################创建日志输出目录,不需要修改################# +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +sed -i "s|n_cycles: 50|n_cycles: 5|g" ${cur_path}/pretrain/options_vot.yaml + +#执行训练脚本,以下传参不需要修改,其他需要模型审视修改 +nohup python3 pretrain/train_mdnet.py \ + --dataset vot \ + --device_id=$ASCEND_DEVICE_ID > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + +sed -i "s|n_cycles: 5|n_cycles: 50|g" ${cur_path}/pretrain/options_vot.yaml + +#训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +#结果打印,不需要修改 +echo "------------------ Final result ------------------" +#单迭代训练时长 +TrainingTime=`grep 'Iter' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log | tail -n +150 | awk -F 'Time ' '{print $2}' | awk '{sum+=$1} END {print sum/NR}'` +#输出性能FPS,需要模型审视修改 +FPS=`awk 'BEGIN{printf "%.2f\n", '${batch_size}/${TrainingTime}'}'` +#打印,不需要修改 +echo "Final Performance images/sec : $FPS" + +#打印,不需要修改 +echo "E2E Training Duration sec : $e2e_time" + +#稳定性精度看护结果汇总 +#训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +##获取性能数据,不需要修改 +#吞吐量 +ActualFPS=${FPS} + +#关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log diff --git a/PyTorch/contrib/cv/video/MDNet/tracking/__init__.py b/PyTorch/contrib/cv/video/MDNet/tracking/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/PyTorch/contrib/cv/video/MDNet/tracking/bbreg.py b/PyTorch/contrib/cv/video/MDNet/tracking/bbreg.py new file mode 100755 index 0000000000000000000000000000000000000000..a8fcd7cbd8d82b5c1c85b27104ece699fa7769bd --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/tracking/bbreg.py @@ -0,0 +1,73 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from sklearn.linear_model import Ridge +import numpy as np + +from modules.utils import overlap_ratio + + +class BBRegressor(): + def __init__(self, img_size, alpha=1000, overlap=[0.6, 1], scale=[1, 2]): + self.img_size = img_size + self.alpha = alpha + self.overlap_range = overlap + self.scale_range = scale + self.model = Ridge(alpha=self.alpha) + + def train(self, X, bbox, gt): + X = X.cpu().numpy() + bbox = np.copy(bbox) + gt = np.copy(gt) + + if gt.ndim==1: + gt = gt[None,:] + + r = overlap_ratio(bbox, gt) + s = np.prod(bbox[:,2:], axis=1) / np.prod(gt[0,2:]) + idx = (r >= self.overlap_range[0]) * (r <= self.overlap_range[1]) * \ + (s >= self.scale_range[0]) * (s <= self.scale_range[1]) + + X = X[idx] + bbox = bbox[idx] + + Y = self.get_examples(bbox, gt) + self.model.fit(X, Y) + + def predict(self, X, bbox): + X = X.cpu().numpy() + bbox_ = np.copy(bbox) + + Y = self.model.predict(X) + + bbox_[:,:2] = bbox_[:,:2] + bbox_[:,2:]/2 + bbox_[:,:2] = Y[:,:2] * bbox_[:,2:] + bbox_[:,:2] + bbox_[:,2:] = np.exp(Y[:,2:]) * bbox_[:,2:] + bbox_[:,:2] = bbox_[:,:2] - bbox_[:,2:]/2 + + bbox_[:,:2] = np.maximum(bbox_[:,:2], 0) + bbox_[:,2:] = np.minimum(bbox_[:,2:], self.img_size - bbox[:,:2]) + return bbox_ + + def get_examples(self, bbox, gt): + bbox[:,:2] = bbox[:,:2] + bbox[:,2:]/2 + gt[:,:2] = gt[:,:2] + gt[:,2:]/2 + + dst_xy = (gt[:,:2] - bbox[:,:2]) / bbox[:,2:] + dst_wh = np.log(gt[:,2:] / bbox[:,2:]) + + Y = np.concatenate((dst_xy, dst_wh), axis=1) + return Y + diff --git a/PyTorch/contrib/cv/video/MDNet/tracking/data_prov.py b/PyTorch/contrib/cv/video/MDNet/tracking/data_prov.py new file mode 100755 index 0000000000000000000000000000000000000000..88ad3a17ebbeaa2fab9cf9c6c45b92ea4d63ec9b --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/tracking/data_prov.py @@ -0,0 +1,59 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
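+
+# RegionExtractor iterates over candidate boxes in batches of opts['batch_test'], cropping each
+# candidate from the current frame and returning zero-centred float tensors for scoring.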
+ +import sys +import numpy as np +from PIL import Image + +import torch +import torch.utils.data as data + +from modules.utils import crop_image2 + + +class RegionExtractor(): + def __init__(self, image, samples, opts): + self.image = np.asarray(image) + self.samples = samples + + self.crop_size = opts['img_size'] + self.padding = opts['padding'] + self.batch_size = opts['batch_test'] + + self.index = np.arange(len(samples)) + self.pointer = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.pointer == len(self.samples): + self.pointer = 0 + raise StopIteration + else: + next_pointer = min(self.pointer + self.batch_size, len(self.samples)) + index = self.index[self.pointer:next_pointer] + self.pointer = next_pointer + regions = self.extract_regions(index) + regions = torch.from_numpy(regions) + return regions + next = __next__ + + def extract_regions(self, index): + regions = np.zeros((len(index), self.crop_size, self.crop_size, 3), dtype='uint8') + for i, sample in enumerate(self.samples[index]): + regions[i] = crop_image2(self.image, sample, self.crop_size, self.padding) + regions = regions.transpose(0, 3, 1, 2) + regions = regions.astype('float32') - 128. + return regions diff --git a/PyTorch/contrib/cv/video/MDNet/tracking/eval_otb.py b/PyTorch/contrib/cv/video/MDNet/tracking/eval_otb.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc06433479ab7637132cdeef938728bb0354968 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/tracking/eval_otb.py @@ -0,0 +1,60 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
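+
+# Evaluates the tracker on every sequence under --seq_home and reports OTB-style precision:
+# the fraction of frames whose predicted centre lies within 20 pixels of the ground truth.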
+ +import os +import json +import argparse + +import numpy as np +import torch +import torch_npu + +from gen_config import gen_config +from run_tracker import run_mdnet + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--seq', default='', help='input seq') + parser.add_argument('--seq_home', default='./datasets/OTB50', help='input seq home') + parser.add_argument('-j', '--json', default='', help='input json') + parser.add_argument('-f', '--savefig', default=True, action='store_true') + parser.add_argument('-d', '--display', default=False, action='store_true') + parser.add_argument('--device_id', default=0) + args = parser.parse_args() + + np.random.seed(0) + torch.manual_seed(0) + torch.npu.set_device(int(args.device_id)) + + dist = np.array([]) + + for seq_name in os.listdir(args.seq_home): + args.seq = seq_name + + # Generate sequence config + img_list, init_bbox, gt, savefig_dir, display, result_path = gen_config(args) + + # Run tracker + result, result_bb, fps, result_dist = run_mdnet(img_list, init_bbox, gt=gt, savefig_dir=savefig_dir, display=display) + + # Save result + res = {} + res['res'] = result_bb.round().tolist() + res['type'] = 'rect' + res['fps'] = fps + json.dump(res, open(result_path, 'w'), indent=2) + + dist = np.append(dist, result_dist) + + print(f'Mean Precision: {sum(dist <= 20)/len(dist):.3f}') diff --git a/PyTorch/contrib/cv/video/MDNet/tracking/gen_config.py b/PyTorch/contrib/cv/video/MDNet/tracking/gen_config.py new file mode 100755 index 0000000000000000000000000000000000000000..ce8c55ceed85226ffbe892035a71661e4bf911b7 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/tracking/gen_config.py @@ -0,0 +1,64 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
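+
+# Resolves either a sequence name (-s, relative to --seq_home) or a json file (-j) into
+# (img_list, init_bbox, gt, savefig_dir, display, result_path) for the tracker.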
+ +import os +import json +import numpy as np + + +def gen_config(args): + + if args.seq != '': + # generate config from a sequence name + + seq_home = args.seq_home if args.seq_home else 'datasets/OTB' + result_home = 'results' + + seq_name = args.seq + img_dir = os.path.join(seq_home, seq_name, 'img') + gt_path = os.path.join(seq_home, seq_name, 'groundtruth_rect.txt') + + img_list = os.listdir(img_dir) + img_list.sort() + img_list = [os.path.join(img_dir, x) for x in img_list] + + with open(gt_path) as f: + gt = np.loadtxt((x.replace('\t',',') for x in f), delimiter=',') + init_bbox = gt[0] + img_list = img_list[:len(gt)] + + result_dir = os.path.join(result_home, seq_name) + if not os.path.exists(result_dir): + os.makedirs(result_dir) + savefig_dir = os.path.join(result_dir, 'figs') + result_path = os.path.join(result_dir, 'result.json') + + elif args.json != '': + # load config from a json file + + param = json.load(open(args.json, 'r')) + seq_name = param['seq_name'] + img_list = param['img_list'] + init_bbox = param['init_bbox'] + savefig_dir = param['savefig_dir'] + result_path = param['result_path'] + gt = None + + if args.savefig: + if not os.path.exists(savefig_dir): + os.makedirs(savefig_dir) + else: + savefig_dir = '' + + return img_list, init_bbox, gt, savefig_dir, args.display, result_path diff --git a/PyTorch/contrib/cv/video/MDNet/tracking/options.yaml b/PyTorch/contrib/cv/video/MDNet/tracking/options.yaml new file mode 100755 index 0000000000000000000000000000000000000000..f9767780d58b2f5f1da96cf3a240332c2f5425a6 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/tracking/options.yaml @@ -0,0 +1,61 @@ +use_gpu: true + +# model path +model_path: "models/mdnet_vot-otb.pth" + +# input size +img_size: 107 +padding: 16 + +# batch size +batch_pos: 32 +batch_neg: 96 +batch_neg_cand: 1024 +batch_test: 256 + +# candidates sampling +n_samples: 256 +trans: 0.6 +scale: 1.05 +trans_limit: 1.5 + +# training examples sampling +trans_pos: 0.1 +scale_pos: 1.3 +trans_neg_init: 1 +scale_neg_init: 1.6 +trans_neg: 2 +scale_neg: 1.3 + +# bounding box regression +n_bbreg: 1000 +overlap_bbreg: [0.6, 1] +trans_bbreg: 0.3 +scale_bbreg: 1.6 +aspect_bbreg: 1.1 + +# initial training +lr_init: 0.0005 +maxiter_init: 50 +n_pos_init: 500 +n_neg_init: 5000 +overlap_pos_init: [0.7, 1] +overlap_neg_init: [0, 0.5] + +# online training +lr_update: 0.001 +maxiter_update: 15 +n_pos_update: 50 +n_neg_update: 200 +overlap_pos_update: [0.7, 1] +overlap_neg_update: [0, 0.3] + +# update criteria +long_interval: 10 +n_frames_long: 100 +n_frames_short: 30 + +# training +grad_clip: 10 +lr_mult: {'fc6': 10} +ft_layers: ['fc'] diff --git a/PyTorch/contrib/cv/video/MDNet/tracking/run_tracker.py b/PyTorch/contrib/cv/video/MDNet/tracking/run_tracker.py new file mode 100755 index 0000000000000000000000000000000000000000..c2197667c15df5d2263035c3c0f069a54ecabb64 --- /dev/null +++ b/PyTorch/contrib/cv/video/MDNet/tracking/run_tracker.py @@ -0,0 +1,348 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import os +import sys +import time +import argparse +import yaml, json +from PIL import Image + +import matplotlib.pyplot as plt + +import torch +import torch_npu +import torch.utils.data as data +import torch.optim as optim + +sys.path.insert(0, '.') +from modules.model import MDNet, BCELoss, set_optimizer +from modules.sample_generator import SampleGenerator +from modules.utils import overlap_ratio, center_dist +from data_prov import RegionExtractor +from bbreg import BBRegressor +from gen_config import gen_config + +opts = yaml.safe_load(open('tracking/options.yaml','r')) + + +def forward_samples(model, image, samples, out_layer='conv3'): + model.eval() + extractor = RegionExtractor(image, samples, opts) + for i, regions in enumerate(extractor): + if opts['use_gpu']: + regions = regions.npu() + with torch.no_grad(): + feat = model(regions, out_layer=out_layer) + if i==0: + feats = feat.detach().clone() + else: + feats = torch.cat((feats, feat.detach().clone()), 0) + return feats + + +def train(model, criterion, optimizer, pos_feats, neg_feats, maxiter, in_layer='fc4'): + model.train() + + batch_pos = opts['batch_pos'] + batch_neg = opts['batch_neg'] + batch_test = opts['batch_test'] + batch_neg_cand = max(opts['batch_neg_cand'], batch_neg) + + pos_idx = np.random.permutation(pos_feats.size(0)) + neg_idx = np.random.permutation(neg_feats.size(0)) + while(len(pos_idx) < batch_pos * maxiter): + pos_idx = np.concatenate([pos_idx, np.random.permutation(pos_feats.size(0))]) + while(len(neg_idx) < batch_neg_cand * maxiter): + neg_idx = np.concatenate([neg_idx, np.random.permutation(neg_feats.size(0))]) + pos_pointer = 0 + neg_pointer = 0 + + for i in range(maxiter): + + # select pos idx + pos_next = pos_pointer + batch_pos + pos_cur_idx = pos_idx[pos_pointer:pos_next] + pos_cur_idx = pos_feats.new(pos_cur_idx).long() + pos_pointer = pos_next + + # select neg idx + neg_next = neg_pointer + batch_neg_cand + neg_cur_idx = neg_idx[neg_pointer:neg_next] + neg_cur_idx = neg_feats.new(neg_cur_idx).long() + neg_pointer = neg_next + + # create batch + batch_pos_feats = pos_feats[pos_cur_idx] + batch_neg_feats = neg_feats[neg_cur_idx] + + # hard negative mining + if batch_neg_cand > batch_neg: + model.eval() + for start in range(0, batch_neg_cand, batch_test): + end = min(start + batch_test, batch_neg_cand) + with torch.no_grad(): + score = model(batch_neg_feats[start:end], in_layer=in_layer) + if start==0: + neg_cand_score = score.detach()[:, 1].clone() + else: + neg_cand_score = torch.cat((neg_cand_score, score.detach()[:, 1].clone()), 0) + + _, top_idx = neg_cand_score.topk(batch_neg) + batch_neg_feats = batch_neg_feats[top_idx] + model.train() + + # forward + pos_score = model(batch_pos_feats, in_layer=in_layer) + neg_score = model(batch_neg_feats, in_layer=in_layer) + + # optimize + loss = criterion(pos_score, neg_score) + model.zero_grad() + loss.backward() + if 'grad_clip' in opts: + torch.nn.utils.clip_grad_norm_(model.parameters(), opts['grad_clip']) + optimizer.step() + + +def run_mdnet(img_list, init_bbox, gt=None, savefig_dir='', display=False): + + # Init bbox + target_bbox = np.array(init_bbox) + result = np.zeros((len(img_list), 4)) + result_bb = np.zeros((len(img_list), 4)) + result[0] = target_bbox + result_bb[0] = target_bbox + + if gt is not None: + overlap = np.zeros(len(img_list)) + overlap[0] = 1 + dist = np.zeros(len(img_list)) + dist[0] = 0 + + # Init 
model + model = MDNet(opts['model_path']) + if opts['use_gpu']: + model = model.npu() + + # Init criterion and optimizer + criterion = BCELoss() + model.set_learnable_params(opts['ft_layers']) + optimizer = set_optimizer(model, opts['lr_init'], opts['lr_mult']) + + tic = time.time() + # Load first image + image = Image.open(img_list[0]).convert('RGB') + + # Draw pos/neg samples + pos_examples = SampleGenerator('gaussian', image.size, opts['trans_pos'], opts['scale_pos'])( + target_bbox, opts['n_pos_init'], opts['overlap_pos_init']) + + neg_examples = np.concatenate([ + SampleGenerator('uniform', image.size, opts['trans_neg_init'], opts['scale_neg_init'])( + target_bbox, int(opts['n_neg_init'] * 0.5), opts['overlap_neg_init']), + SampleGenerator('whole', image.size)( + target_bbox, int(opts['n_neg_init'] * 0.5), opts['overlap_neg_init'])]) + neg_examples = np.random.permutation(neg_examples) + + # Extract pos/neg features + pos_feats = forward_samples(model, image, pos_examples) + neg_feats = forward_samples(model, image, neg_examples) + + # Initial training + train(model, criterion, optimizer, pos_feats, neg_feats, opts['maxiter_init']) + lr_ratio = opts['lr_update'] / opts['lr_init'] + for param_group in optimizer.param_groups: + param_group['lr'] *= lr_ratio + + # Train bbox regressor + bbreg_examples = SampleGenerator('uniform', image.size, opts['trans_bbreg'], opts['scale_bbreg'], opts['aspect_bbreg'])( + target_bbox, opts['n_bbreg'], opts['overlap_bbreg']) + bbreg_feats = forward_samples(model, image, bbreg_examples) + bbreg = BBRegressor(image.size) + bbreg.train(bbreg_feats, bbreg_examples, target_bbox) + del bbreg_feats + + # Init sample generators for update + sample_generator = SampleGenerator('gaussian', image.size, opts['trans'], opts['scale']) + pos_generator = SampleGenerator('gaussian', image.size, opts['trans_pos'], opts['scale_pos']) + neg_generator = SampleGenerator('uniform', image.size, opts['trans_neg'], opts['scale_neg']) + + # Init pos/neg features for update + neg_examples = neg_generator(target_bbox, opts['n_neg_update'], opts['overlap_neg_init']) + neg_feats = forward_samples(model, image, neg_examples) + pos_feats_all = [pos_feats] + neg_feats_all = [neg_feats] + + spf_total = time.time() - tic + + # Display + savefig = savefig_dir != '' + if display or savefig: + dpi = 80.0 + figsize = (image.size[0] / dpi, image.size[1] / dpi) + + fig = plt.figure(frameon=False, figsize=figsize, dpi=dpi) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + fig.add_axes(ax) + im = ax.imshow(image, aspect='auto') + + if gt is not None: + gt_rect = plt.Rectangle(tuple(gt[0, :2]), gt[0, 2], gt[0, 3], + linewidth=3, edgecolor="#00ff00", zorder=1, fill=False) + ax.add_patch(gt_rect) + + rect = plt.Rectangle(tuple(result_bb[0, :2]), result_bb[0, 2], result_bb[0, 3], + linewidth=3, edgecolor="#ff0000", zorder=1, fill=False) + ax.add_patch(rect) + + if display: + plt.pause(.01) + plt.draw() + if savefig: + fig.savefig(os.path.join(savefig_dir, '0000.jpg'), dpi=dpi) + + # Main loop + for i in range(1, len(img_list)): + + tic = time.time() + # Load image + image = Image.open(img_list[i]).convert('RGB') + + # Estimate target bbox + samples = sample_generator(target_bbox, opts['n_samples']) + sample_scores = forward_samples(model, image, samples, out_layer='fc6') + + top_scores, top_idx = sample_scores[:, 1].topk(5) + top_idx = top_idx.cpu() + target_score = top_scores.mean() + target_bbox = samples[top_idx] + if top_idx.shape[0] > 1: + target_bbox = target_bbox.mean(axis=0) + 
success = target_score > 0

+        # Expand search area at failure
+        if success:
+            sample_generator.set_trans(opts['trans'])
+        else:
+            sample_generator.expand_trans(opts['trans_limit'])
+
+        # Bbox regression
+        if success:
+            bbreg_samples = samples[top_idx]
+            if top_idx.shape[0] == 1:
+                bbreg_samples = bbreg_samples[None, :]
+            bbreg_feats = forward_samples(model, image, bbreg_samples)
+            bbreg_samples = bbreg.predict(bbreg_feats, bbreg_samples)
+            bbreg_bbox = bbreg_samples.mean(axis=0)
+        else:
+            bbreg_bbox = target_bbox
+
+        # Save result
+        result[i] = target_bbox
+        result_bb[i] = bbreg_bbox
+
+        # Data collect
+        if success:
+            pos_examples = pos_generator(target_bbox, opts['n_pos_update'], opts['overlap_pos_update'])
+            pos_feats = forward_samples(model, image, pos_examples)
+            pos_feats_all.append(pos_feats)
+            if len(pos_feats_all) > opts['n_frames_long']:
+                del pos_feats_all[0]
+
+            neg_examples = neg_generator(target_bbox, opts['n_neg_update'], opts['overlap_neg_update'])
+            neg_feats = forward_samples(model, image, neg_examples)
+            neg_feats_all.append(neg_feats)
+            if len(neg_feats_all) > opts['n_frames_short']:
+                del neg_feats_all[0]
+
+        # Short term update
+        if not success:
+            nframes = min(opts['n_frames_short'], len(pos_feats_all))
+            pos_data = torch.cat(pos_feats_all[-nframes:], 0)
+            neg_data = torch.cat(neg_feats_all, 0)
+            train(model, criterion, optimizer, pos_data, neg_data, opts['maxiter_update'])
+
+        # Long term update
+        elif i % opts['long_interval'] == 0:
+            pos_data = torch.cat(pos_feats_all, 0)
+            neg_data = torch.cat(neg_feats_all, 0)
+            train(model, criterion, optimizer, pos_data, neg_data, opts['maxiter_update'])
+
+        spf = time.time() - tic
+        spf_total += spf
+
+        # Display
+        if display or savefig:
+            im.set_data(image)
+
+            if gt is not None:
+                gt_rect.set_xy(gt[i, :2])
+                gt_rect.set_width(gt[i, 2])
+                gt_rect.set_height(gt[i, 3])
+
+            rect.set_xy(result_bb[i, :2])
+            rect.set_width(result_bb[i, 2])
+            rect.set_height(result_bb[i, 3])
+
+            if display:
+                plt.pause(.01)
+                plt.draw()
+            if savefig:
+                fig.savefig(os.path.join(savefig_dir, '{:04d}.jpg'.format(i)), dpi=dpi)
+
+        if gt is None:
+            print('Frame {:d}/{:d}, Score {:.3f}, Time {:.3f}'
+                  .format(i, len(img_list), target_score, spf))
+        else:
+            overlap[i] = overlap_ratio(gt[i], result_bb[i])[0]
+            dist[i] = center_dist(gt[i], result_bb[i])[0]
+            print('Frame {:d}/{:d}, Overlap {:.3f}, Dist {:.3f}, Score {:.3f}, Time {:.3f}'
+                  .format(i, len(img_list), overlap[i], dist[i], target_score, spf))
+
+    if gt is not None:
+        print('meanIOU: {:.3f}'.format(overlap.mean()))
+        print('Precision: {:.3f}'.format(sum(dist <= 20) / len(img_list)))
+    else:
+        # No ground truth available (e.g. when configured from a json file),
+        # so there are no per-frame distances to return.
+        dist = None
+    fps = len(img_list) / spf_total
+    return result, result_bb, fps, dist
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-s', '--seq', default='', help='input seq')
+    parser.add_argument('--seq_home', default='', help='input seq home')  # gen_config expects this attribute
+    parser.add_argument('-j', '--json', default='', help='input json')
+    parser.add_argument('-f', '--savefig', action='store_true')
+    parser.add_argument('-d', '--display', action='store_true')
+
+    args = parser.parse_args()
+    assert args.seq != '' or args.json != ''
+
+    np.random.seed(0)
+    torch.manual_seed(0)
+
+    # Generate sequence config
+    img_list, init_bbox, gt, savefig_dir, display, result_path = gen_config(args)
+
+    # Run tracker (run_mdnet returns four values, including per-frame center distances)
+    result, result_bb, fps, result_dist = run_mdnet(img_list, init_bbox, gt=gt, savefig_dir=savefig_dir, display=display)
+
+    # Save result
+    res = {}
+    res['res'] = result_bb.round().tolist()
+    res['type'] = 'rect'
+    res['fps'] = fps
+    json.dump(res, open(result_path, 'w'), indent=2)
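+
+# A minimal usage sketch (not part of the reference implementation; the sequence
+# name, path and box values below are hypothetical examples). From the repository
+# root, a single OTB-style sequence can be tracked with
+#
+#     python3 tracking/run_tracker.py -s Basketball --seq_home datasets/OTB50 -f
+#
+# or run_mdnet() can be called directly on an arbitrary frame list:
+#
+#     import glob
+#     img_list = sorted(glob.glob('datasets/OTB50/Basketball/img/*.jpg'))  # hypothetical path
+#     init_bbox = [198, 214, 34, 81]  # x, y, w, h of the target in the first frame (example values)
+#     result, result_bb, fps, dist = run_mdnet(img_list, init_bbox)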