From b612762bbb34f9c1b89da0db7955139322a80ad1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A5=B6=E6=B2=B9=E8=8A=B1=E7=94=9F=E6=9E=9C=E9=85=B1?= <1511077945@qq.com> Date: Mon, 21 Mar 2022 07:17:33 +0000 Subject: [PATCH 1/3] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20CTPN?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PyTorch/contrib/cv/detection/CTPN/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 PyTorch/contrib/cv/detection/CTPN/.keep diff --git a/PyTorch/contrib/cv/detection/CTPN/.keep b/PyTorch/contrib/cv/detection/CTPN/.keep new file mode 100644 index 0000000000..e69de29bb2 -- Gitee From 2bbedeaae8b3c391cd62fc097c6017e6e3d04f91 Mon Sep 17 00:00:00 2001 From: lgq1997 Date: Fri, 1 Jul 2022 10:26:47 +0800 Subject: [PATCH 2/3] merge commits --- PyTorch/contrib/cv/detection/CTPN/LICENSE | 30 + PyTorch/contrib/cv/detection/CTPN/README.md | 48 ++ .../contrib/cv/detection/CTPN/ctpn/config.py | 41 ++ .../contrib/cv/detection/CTPN/ctpn/ctpn.py | 201 +++++++ .../contrib/cv/detection/CTPN/ctpn/dataset.py | 329 +++++++++++ .../contrib/cv/detection/CTPN/ctpn/utils.py | 522 +++++++++++++++++ .../cv/detection/CTPN/modelzoo_level.txt | 5 + .../cv/detection/CTPN/predict_2_txt.py | 152 +++++ .../cv/detection/CTPN/requirements.txt | 4 + .../CTPN/scripts/rrc_evaluation_funcs_1_1.py | 505 ++++++++++++++++ .../cv/detection/CTPN/scripts/script.py | 412 +++++++++++++ PyTorch/contrib/cv/detection/CTPN/test/env.sh | 79 +++ .../cv/detection/CTPN/{ => test/output}/.keep | 0 .../cv/detection/CTPN/test/train_eval.sh | 101 ++++ .../cv/detection/CTPN/test/train_full_1p.sh | 140 +++++ .../cv/detection/CTPN/test/train_full_8p.sh | 142 +++++ .../CTPN/test/train_performance_1p.sh | 140 +++++ .../CTPN/test/train_performance_8p.sh | 141 +++++ PyTorch/contrib/cv/detection/CTPN/train.py | 545 ++++++++++++++++++ 19 files changed, 3537 insertions(+) create mode 100644 PyTorch/contrib/cv/detection/CTPN/LICENSE create mode 100644 PyTorch/contrib/cv/detection/CTPN/README.md create mode 100644 PyTorch/contrib/cv/detection/CTPN/ctpn/config.py create mode 100644 PyTorch/contrib/cv/detection/CTPN/ctpn/ctpn.py create mode 100644 PyTorch/contrib/cv/detection/CTPN/ctpn/dataset.py create mode 100644 PyTorch/contrib/cv/detection/CTPN/ctpn/utils.py create mode 100644 PyTorch/contrib/cv/detection/CTPN/modelzoo_level.txt create mode 100644 PyTorch/contrib/cv/detection/CTPN/predict_2_txt.py create mode 100644 PyTorch/contrib/cv/detection/CTPN/requirements.txt create mode 100644 PyTorch/contrib/cv/detection/CTPN/scripts/rrc_evaluation_funcs_1_1.py create mode 100644 PyTorch/contrib/cv/detection/CTPN/scripts/script.py create mode 100644 PyTorch/contrib/cv/detection/CTPN/test/env.sh rename PyTorch/contrib/cv/detection/CTPN/{ => test/output}/.keep (100%) create mode 100644 PyTorch/contrib/cv/detection/CTPN/test/train_eval.sh create mode 100644 PyTorch/contrib/cv/detection/CTPN/test/train_full_1p.sh create mode 100644 PyTorch/contrib/cv/detection/CTPN/test/train_full_8p.sh create mode 100644 PyTorch/contrib/cv/detection/CTPN/test/train_performance_1p.sh create mode 100644 PyTorch/contrib/cv/detection/CTPN/test/train_performance_8p.sh create mode 100644 PyTorch/contrib/cv/detection/CTPN/train.py diff --git a/PyTorch/contrib/cv/detection/CTPN/LICENSE b/PyTorch/contrib/cv/detection/CTPN/LICENSE new file mode 100644 index 0000000000..6a1fa5ef09 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/LICENSE @@ -0,0 +1,30 @@ +BSD 3-Clause License + +Copyright (c) 2019 the authors +All 
rights reserved.
+Copyright 2022 Huawei Technologies Co., Ltd
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/PyTorch/contrib/cv/detection/CTPN/README.md b/PyTorch/contrib/cv/detection/CTPN/README.md
new file mode 100644
index 0000000000..8c91c1284c
--- /dev/null
+++ b/PyTorch/contrib/cv/detection/CTPN/README.md
@@ -0,0 +1,48 @@
+# ctpn.pytorch
+PyTorch implementation of CTPN (Detecting Text in Natural Image with Connectionist Text Proposal Network)
+
+# Paper
+https://arxiv.org/pdf/1609.03605.pdf
+
+# Environment setup
+
+Install the zip archiver with apt-get install zip or yum install zip.
+
+# Data preparation
+
+Download the icdar13 dataset and extract it to `${ROOT}/data/icdar13/`, then move its `gt.zip` label file to `${ROOT}/`.
+# Training
+
+```bash
+# training 1p accuracy
+bash ./test/train_full_1p.sh --data_path=real_data_path
+
+# training 1p performance
+bash ./test/train_performance_1p.sh --data_path=real_data_path
+
+# training 8p accuracy
+bash ./test/train_full_8p.sh --data_path=real_data_path
+
+# training 8p performance
+bash ./test/train_performance_8p.sh --data_path=real_data_path
+
+```
+
+# Test
+Evaluation uses the checkpoint of the last epoch, i.e. when epoch=200 the weight path is output_models/checkpoint-200.pth.tar.
+```
+# test 8p accuracy
+bash test/train_eval.sh --data_path=real_data_path --pth_path=output_models/checkpoint-200.pth.tar
+```
+The evaluation reports three metrics; hmean is the accuracy used for comparison.
+```
+Calculated!{"precision": 0.7331386861313869, "recall": 0.7094063926940639, "hmean": 0.7210773213359865}
+```
+## CTPN training result

+| Name | Accuracy (hmean) | Performance |
+| :----: | :--: | :------: |
+| NPU-8p | 72.1 | 17.66fps |
+| GPU-8p | 72.4 | 13.25fps |
+| NPU-1p | | 1.695fps |
+| GPU-1p | | 6.295fps |
\ No newline at end of file
diff --git a/PyTorch/contrib/cv/detection/CTPN/ctpn/config.py b/PyTorch/contrib/cv/detection/CTPN/ctpn/config.py
new file mode 100644
index 0000000000..3f92075430
--- /dev/null
+++ b/PyTorch/contrib/cv/detection/CTPN/ctpn/config.py
@@ -0,0 +1,41 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#-*- coding:utf-8 -*- +import os + +#img_dir = '../imagedata/image/' +#label_dir = '../imagedata/xml/' + +img_dir = '/home/dockerHome/ctpn/ctpn_8p/imagedata/Challenge2_Training_Task12_Images/' +label_dir = '/home/dockerHome/ctpn/ctpn_8p/imagedata/Challenge2_Training_Task1_GT/' + +num_workers = 0 +pretrained_weights = '' +#pretrained_weights = './checkpoints/gpu_ctpn_ep98_0.2615_0.0333_0.2948.pth' + + +anchor_scale = 16 +IOU_NEGATIVE = 0.3 +IOU_POSITIVE = 0.7 +IOU_SELECT = 0.7 + +RPN_POSITIVE_NUM = 150 +RPN_TOTAL_NUM = 300 + +IMAGE_MEAN = [123.68, 116.779, 103.939] + +# online hard example mining +OHEM = True +checkpoints_dir = './checkpoints' diff --git a/PyTorch/contrib/cv/detection/CTPN/ctpn/ctpn.py b/PyTorch/contrib/cv/detection/CTPN/ctpn/ctpn.py new file mode 100644 index 0000000000..d67052bca7 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/ctpn/ctpn.py @@ -0,0 +1,201 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
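+# A minimal illustration (hypothetical IoU values, not part of the original
+# model code) of how the thresholds defined in ctpn/config.py are consumed
+# when anchors are labelled in cal_rpn (ctpn/utils.py): every anchor starts
+# as "ignore" (-1), anchors whose best IoU with a ground-truth box exceeds
+# IOU_POSITIVE become positives (1), anchors below IOU_NEGATIVE become
+# negatives (0), and the positive set is later capped at RPN_POSITIVE_NUM.
+#
+#     import numpy as np
+#     max_iou = np.array([0.82, 0.55, 0.12])   # best IoU per anchor (made-up values)
+#     labels = np.full(max_iou.shape, -1.0)    # -1 = ignore
+#     labels[max_iou > 0.7] = 1                # IOU_POSITIVE
+#     labels[max_iou < 0.3] = 0                # IOU_NEGATIVE
+#     # labels -> array([ 1., -1.,  0.])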
+ +# -*- coding:utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +from torch.contrib.npu.optimized_lib import module as nnn +from ctpn import config + +""" +回归损失: smooth L1 Loss +只针对正样本求取回归损失 +L = 0.5*x**2 |x|<1 +L = |x| - 0.5 +sigma: 平滑系数 +1、从预测框p和真值框g中筛选出正样本 +2、|x| = |g - p| +3、求取loss,这里设置了一个平滑系数 1/sigma + (1) |x|>1/sigma: loss = |x| - 0.5/sigma + (2) |x|<1/sigma: loss = 0.5*sigma*|x|**2 +""" + + +class RPN_REGR_Loss(nn.Module): + def __init__(self, device, sigma=9.0): + super(RPN_REGR_Loss, self).__init__() + self.sigma = sigma + self.device = device + + def forward(self, input_data, target): + input_data = input_data.to(self.device).float() + target = target.to(self.device).float() + cls = target[0, :, 0] + regression = target[0, :, 1:3] + regr_keep = (cls == 1).nonzero()[:, 0] + regr_true = regression[regr_keep] + if regr_true.numel() > 0: + regr_pred = input_data[0][regr_keep] + diff = torch.abs(regr_true - regr_pred) + less_one = (diff < 1.0 / self.sigma).float() + loss = less_one * 0.5 * diff ** 2 * self.sigma + torch.abs(1 - less_one) * (diff - 0.5 / self.sigma) + loss = torch.sum(loss, 1) + loss = torch.mean(loss) + else: + loss = input_data.sum() * 0 + return loss.to(self.device) + + +""" +分类损失: softmax loss +1、OHEM模式 + (1) 筛选出正样本,求取softmaxloss + (2) 求取负样本数量N_neg, 指定样本数量N, 求取负样本的topK loss, 其中K = min(N_neg, N - len(pos_num)) + (3) loss = loss1 + loss2 +2、求取NLLLoss,截断在(0, 10)区间 +""" + + +class RPN_CLS_Loss(nn.Module): + def __init__(self, device): + super(RPN_CLS_Loss, self).__init__() + self.device = device + self.L_cls = nn.CrossEntropyLoss(reduction='none').to(self.device) + + def forward(self, input_data, target): + input_data = input_data.to(self.device).float() + target = target.to(self.device).float() + if config.OHEM: + cls_gt = target[0][0] + num_pos = 0 + loss_pos_sum = 0 + + if len((cls_gt == 1).nonzero()) != 0: + cls_pos = (cls_gt == 1).nonzero()[:, 0] + gt_pos = cls_gt[cls_pos].long() + cls_pred_pos = input_data[0][cls_pos] + loss_pos = self.L_cls(cls_pred_pos.view(-1, 2), gt_pos.view(-1)) + loss_pos_sum = loss_pos.sum() + num_pos = len(loss_pos) + + cls_neg = (cls_gt == 0).nonzero()[:, 0] + gt_neg = cls_gt[cls_neg].long() + cls_pred_neg = input_data[0][cls_neg] + + loss_neg = self.L_cls(cls_pred_neg.view(-1, 2), gt_neg.view(-1)) + loss_neg_topK, _ = torch.topk(loss_neg, min(len(loss_neg), config.RPN_TOTAL_NUM - num_pos)) + loss_cls = loss_pos_sum + loss_neg_topK.sum() + loss_cls = loss_cls / config.RPN_TOTAL_NUM + + return loss_cls.to(self.device) + else: + y_true = target[0][0] + cls_keep = (y_true != -1).nonzero()[:, 0] + cls_true = y_true[cls_keep].long() + cls_pred = input_data[0][cls_keep] + loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true) + loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0) + + return loss.to(self.device) + + +class basic_conv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, + bn=True, bias=True): + super(basic_conv, self).__init__() + self.out_channels = out_planes + self.bn = bn + self.relu = relu + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + if self.bn: + self.batchnorm = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) + if self.relu: + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + if self.bn: + x = 
self.batchnorm(x) + if self.relu: + x = self.act(x) + return x + + +""" +image -> feature map -> rpn -> blstm -> fc -> classifier + -> regression +""" + + +class CTPN_Model(nn.Module): + def __init__(self): + super().__init__() + base_model = models.resnet34(pretrained=True) + self.conv1 = base_model.conv1 + self.bn1 = base_model.bn1 + self.relu = base_model.relu + self.layer1 = base_model.layer1 + self.layer2 = base_model.layer2 + self.layer3 = base_model.layer3 + self.layer4 = base_model.layer4 + del base_model.maxpool + del base_model.avgpool + del base_model.fc + self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False) + self.brnn = nnn.BiLSTM(512, 128) + self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False) + self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) + self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + # rpn + x = self.rpn(x) # [b, c, h, w] + + x1 = x.permute(0, 2, 3, 1).contiguous() # channels last [b, h, w, c] + b = x1.size() # b, h, w, c + x1 = x1.view(b[0] * b[1], b[2], b[3]) + + x2 = self.brnn(x1) + + xsz = x.size() + x3 = x2.view(xsz[0], xsz[2], xsz[3], 256) # torch.Size([4, 20, 20, 256]) + + x3 = x3.permute(0, 3, 1, 2).contiguous() # channels first [b, c, h, w] + x3 = self.lstm_fc(x3) + x = x3 + + cls = self.rpn_class(x) + regression = self.rpn_regress(x) + + cls = cls.permute(0, 2, 3, 1).contiguous() + regression = regression.permute(0, 2, 3, 1).contiguous() + + cls = cls.view(cls.size(0), cls.size(1) * cls.size(2) * 10, 2) + regression = regression.view(regression.size(0), regression.size(1) * regression.size(2) * 10, 2) + + return cls, regression + + +if __name__ == '__main__': + CTPN_Model() diff --git a/PyTorch/contrib/cv/detection/CTPN/ctpn/dataset.py b/PyTorch/contrib/cv/detection/CTPN/ctpn/dataset.py new file mode 100644 index 0000000000..6f97e6a17f --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/ctpn/dataset.py @@ -0,0 +1,329 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
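+# A minimal sketch of how the CTPN_Model defined in ctpn/ctpn.py is expected
+# to behave (assuming the Ascend PyTorch build that provides
+# torch.contrib.npu.optimized_lib, which the model's BiLSTM layer requires):
+# the ResNet-34 backbone downsamples by 16, and each feature-map cell predicts
+# 10 anchors, so an input of 1 x 3 x 320 x 480 yields 20 * 30 * 10 = 6000
+# proposals.
+#
+#     import torch
+#     from ctpn.ctpn import CTPN_Model
+#
+#     model = CTPN_Model().eval()
+#     with torch.no_grad():
+#         cls, regr = model(torch.zeros(1, 3, 320, 480))
+#     # cls.shape  -> torch.Size([1, 6000, 2])  (text / non-text logits)
+#     # regr.shape -> torch.Size([1, 6000, 2])  (Vc, Vh vertical regression)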
+ +# -*- coding:utf-8 -*- +import os +import random + +import numpy as np +import cv2 +import torch +from torch.utils.data import Dataset +import xml.etree.ElementTree as ET +from ctpn.utils import cal_rpn + +IMAGE_MEAN = [123.68, 116.779, 103.939] + +''' +从xml文件中读取图像中的真值框 +''' + + +def readxml(path): + gtboxes = [] + xml = ET.parse(path) + for elem in xml.iter(): + if 'object' in elem.tag: + for attr in list(elem): + if 'bndbox' in attr.tag: + xmin = int(round(float(attr.find('xmin').text))) + ymin = int(round(float(attr.find('ymin').text))) + xmax = int(round(float(attr.find('xmax').text))) + ymax = int(round(float(attr.find('ymax').text))) + gtboxes.append((xmin, ymin, xmax, ymax)) + + return np.array(gtboxes) + + +''' +读取VOC格式数据,返回用于训练的图像、anchor目标框、标签 +''' + + +class VOCDataset(Dataset): + def __init__(self, datadir, labelsdir): + if not os.path.isdir(datadir): + raise Exception('[ERROR] {} is not a directory'.format(datadir)) + if not os.path.isdir(labelsdir): + raise Exception('[ERROR] {} is not a directory'.format(labelsdir)) + + self.datadir = datadir + self.img_names = os.listdir(self.datadir) + self.labelsdir = labelsdir + + def __len__(self): + return len(self.img_names) + + def generate_gtboxes(self, xml_path, rescale_fac=1.0): + base_gtboxes = readxml(xml_path) + gtboxes = [] + for base_gtbox in base_gtboxes: + xmin, ymin, xmax, ymax = base_gtbox + if rescale_fac > 1.0: + xmin = int(xmin / rescale_fac) + xmax = int(xmax / rescale_fac) + ymin = int(ymin / rescale_fac) + ymax = int(ymax / rescale_fac) + prev = xmin + for i in range(xmin // 16 + 1, xmax // 16 + 1): + _next = 16 * i - 0.5 + gtboxes.append((prev, ymin, _next, ymax)) + prev = _next + gtboxes.append((prev, ymin, xmax, ymax)) + return np.array(gtboxes) + + def __getitem__(self, idx): + img_name = self.img_names[idx] + img_path = os.path.join(self.datadir, img_name) + img = cv2.imread(img_path) + h, w, c = img.shape + rescale_fac = max(h, w) / 1000 + if rescale_fac > 1.0: + h = int(h / rescale_fac) + w = int(w / rescale_fac) + img = cv2.resize(img, (w, h)) + + xml_path = os.path.join(self.labelsdir, img_name.split('.')[0] + '.xml') + gtbox = self.generate_gtboxes(xml_path, rescale_fac) + + if np.random.randint(2) == 1: + img = img[:, ::-1, :] + newx1 = w - gtbox[:, 2] - 1 + newx2 = w - gtbox[:, 0] - 1 + gtbox[:, 0] = newx1 + gtbox[:, 2] = newx2 + + [cls, regr] = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox) + regr = np.hstack([cls.reshape(cls.shape[0], 1), regr]) + cls = np.expand_dims(cls, axis=0) + + m_img = img - IMAGE_MEAN + m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float() + cls = torch.from_numpy(cls).float() + regr = torch.from_numpy(regr).float() + + return m_img, cls, regr + + +################################################################################ + + +class ICDARDataset(Dataset): + def __init__(self, datadir, labelsdir): + if not os.path.isdir(datadir): + raise Exception('[ERROR] {} is not a directory'.format(datadir)) + if not os.path.isdir(labelsdir): + raise Exception('[ERROR] {} is not a directory'.format(labelsdir)) + + self.datadir = datadir + self.img_names = os.listdir(self.datadir) + self.labelsdir = labelsdir + + def __len__(self): + return len(self.img_names) + + def box_transfer(self, coor_lists, rescale_fac=1.0): + gtboxes = [] + for coor_list in coor_lists: + coors_x = [int(coor_list[2 * i]) for i in range(4)] + coors_y = [int(coor_list[2 * i + 1]) for i in range(4)] + xmin = min(coors_x) + xmax = max(coors_x) + ymin = min(coors_y) + ymax = max(coors_y) + if 
rescale_fac > 1.0: + xmin = int(xmin / rescale_fac) + xmax = int(xmax / rescale_fac) + ymin = int(ymin / rescale_fac) + ymax = int(ymax / rescale_fac) + gtboxes.append((xmin, ymin, xmax, ymax)) + return np.array(gtboxes) + + def box_transfer_v2(self, coor_lists, rescale_fac=1.0): + gtboxes = [] + for coor_list in coor_lists: + coors_x = [int(coor_list[2 * i]) for i in range(4)] + coors_y = [int(coor_list[2 * i + 1]) for i in range(4)] + xmin = min(coors_x) + xmax = max(coors_x) + ymin = min(coors_y) + ymax = max(coors_y) + if rescale_fac > 1.0: + xmin = int(xmin / rescale_fac) + xmax = int(xmax / rescale_fac) + ymin = int(ymin / rescale_fac) + ymax = int(ymax / rescale_fac) + prev = xmin + for i in range(xmin // 16 + 1, xmax // 16 + 1): + _next = 16 * i - 0.5 + gtboxes.append((prev, ymin, _next, ymax)) + prev = _next + gtboxes.append((prev, ymin, xmax, ymax)) + return np.array(gtboxes) + + def parse_gtfile(self, gt_path, rescale_fac=1.0): + coor_lists = list() + with open(gt_path, 'r', encoding="utf-8-sig") as f: + content = f.readlines() + for line in content: + coor_list = line.split(',')[:8] + if len(coor_list) == 8: + coor_lists.append(coor_list) + return self.box_transfer_v2(coor_lists, rescale_fac) + + def draw_boxes(self, img, cls, base_anchors, gt_box): + for i in range(len(cls)): + if cls[i] == 1: + pt1 = (int(base_anchors[i][0]), int(base_anchors[i][1])) + pt2 = (int(base_anchors[i][2]), int(base_anchors[i][3])) + img = cv2.rectangle(img, pt1, pt2, (200, 100, 100)) + for i in range(gt_box.shape[0]): + pt1 = (int(gt_box[i][0]), int(gt_box[i][1])) + pt2 = (int(gt_box[i][2]), int(gt_box[i][3])) + img = cv2.rectangle(img, pt1, pt2, (100, 200, 100)) + return img + + def __getitem__(self, idx): + img_name = self.img_names[idx] + img_path = os.path.join(self.datadir, img_name) + img = cv2.imread(img_path)[:, :, ::-1] + + h, w, c = img.shape + rescale_fac = max(h, w) / 1000 + if rescale_fac > 1.0: + h = int(h / rescale_fac) + w = int(w / rescale_fac) + img = cv2.resize(img, (w, h)) + + # gt_path = os.path.join(self.labelsdir, img_name.split('.')[0]+'.txt') + gt_path = os.path.join(self.labelsdir, "gt_" + img_name.split('.')[0] + '.txt') + gtbox = self.parse_gtfile(gt_path, rescale_fac) + + # random flip image + if np.random.randint(2) == 1: + img = img[:, ::-1, :] + newx1 = w - gtbox[:, 2] - 1 + newx2 = w - gtbox[:, 0] - 1 + gtbox[:, 0] = newx1 + gtbox[:, 2] = newx2 + + print(gtbox) + [cls, regr] = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox) + regr = np.hstack([cls.reshape(cls.shape[0], 1), regr]) + cls = np.expand_dims(cls, axis=0) + + m_img = img - IMAGE_MEAN + m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float() + cls = torch.from_numpy(cls).float() + regr = torch.from_numpy(regr).float() + + return m_img, cls, regr + + +''' +读取ICDAR格式数据,返回用于训练的图像、anchor目标框、标签 +''' + + +class icdarDataset(Dataset): + def __init__(self, datadir, labelsdir): + if not os.path.isdir(datadir): + raise Exception('[ERROR] {} is not a directory'.format(datadir)) + if not os.path.isdir(labelsdir): + raise Exception('[ERROR] {} is not a directory'.format(labelsdir)) + + self.datadir = datadir + self.img_names = os.listdir(self.datadir) + self.labelsdir = labelsdir + self.short_len = 1008 + self.long_len = 1008 + + def __len__(self): + return len(self.img_names) + + def generate_gtboxes(self, txt_path, w_pad, h_pad, rescale_fac): + + coor_lists = list() + with open(txt_path, 'r', encoding="utf-8-sig") as f: + content = f.readlines() + for line in content: + coor_list = line.split(' ')[:4] + 
for i in range(len(coor_list)): + coor_list[i] = int(coor_list[i]) + coor_lists.append(coor_list) + base_gtboxes = np.array(coor_lists) + + gtboxes = [] + for base_gtbox in base_gtboxes: + xmin, ymin, xmax, ymax = base_gtbox + xmin = int(xmin / rescale_fac) + w_pad + xmax = int(xmax / rescale_fac) + w_pad + ymin = int(ymin / rescale_fac) + h_pad + ymax = int(ymax / rescale_fac) + h_pad + prev = xmin + for i in range(xmin // 16 + 1, xmax // 16 + 1): + _next = 16 * i - 0.5 + gtboxes.append((prev, ymin, _next, ymax)) + prev = _next + gtboxes.append((prev, ymin, xmax, ymax)) + return np.array(gtboxes) + + def __getitem__(self, idx): + img_name = self.img_names[idx] + img_path = os.path.join(self.datadir, img_name) + img = cv2.imread(img_path)[:, :, ::-1] + h, w, c = img.shape + scale_range = random.uniform(0.7, 1.5) + h = h * scale_range + w = w * scale_range + rescale_fac = 1.0 + if min(h, w) > self.short_len or max(h, w) > self.long_len: + rescale_fac = max(min(h, w) / self.short_len, max(h, w) / self.long_len) + h = int(round(h / rescale_fac)) + w = int(round(w / rescale_fac)) + rescale_fac = rescale_fac / scale_range + img = cv2.resize(img, (w, h)) + if h > w: + target_w = self.short_len + target_h = self.long_len + else: + target_w = self.long_len + target_h = self.short_len + + w_pad = random.randint(0, target_w - w) + h_pad = random.randint(0, target_h - h) + + w_pad2 = target_w - w - w_pad + h_pad2 = target_h - h - h_pad + img = np.pad(img, ((h_pad, h_pad2), (w_pad, w_pad2), (0, 0))) + + txt_path = os.path.join(self.labelsdir, 'gt_' + img_name.split('.')[0] + '.txt') + gtbox = self.generate_gtboxes(txt_path, w_pad, h_pad, rescale_fac) + h, w = target_h, target_w + if np.random.randint(2) == 1: + img = img[:, ::-1, :] + newx1 = w - gtbox[:, 2] - 1 + newx2 = w - gtbox[:, 0] - 1 + gtbox[:, 0] = newx1 + gtbox[:, 2] = newx2 + [cls, regr] = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox) + regr = np.hstack([cls.reshape(cls.shape[0], 1), regr]) + cls = np.expand_dims(cls, axis=0) + + m_img = img - IMAGE_MEAN + m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float() + cls = torch.from_numpy(cls).float() + regr = torch.from_numpy(regr).float() + + return m_img, cls, regr diff --git a/PyTorch/contrib/cv/detection/CTPN/ctpn/utils.py b/PyTorch/contrib/cv/detection/CTPN/ctpn/utils.py new file mode 100644 index 0000000000..8e503e0c13 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/ctpn/utils.py @@ -0,0 +1,522 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
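+# The datasets above cut every ground-truth box into 16-pixel-wide strips
+# before RPN targets are computed. A small worked example (hypothetical
+# coordinates) of that splitting logic, mirroring box_transfer_v2 and
+# generate_gtboxes in ctpn/dataset.py:
+#
+#     xmin, ymin, xmax, ymax = 10, 5, 70, 25
+#     strips, prev = [], xmin
+#     for i in range(xmin // 16 + 1, xmax // 16 + 1):
+#         strips.append((prev, ymin, 16 * i - 0.5, ymax))
+#         prev = 16 * i - 0.5
+#     strips.append((prev, ymin, xmax, ymax))
+#     # strips -> [(10, 5, 15.5, 25), (15.5, 5, 31.5, 25), (31.5, 5, 47.5, 25),
+#     #            (47.5, 5, 63.5, 25), (63.5, 5, 70, 25)]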
+
+# -*- coding:utf-8 -*-
+import numpy as np
+import cv2
+from ctpn.config import *
+
+"""
+Anchor generation.
+Problem encountered: base_anchor holds the 10 anchors of the initial position; after shifting them
+across every feature-map point with the given stride, anchors has shape [10, h*w, 4].
+Reshaping that directly to [10*h*w, 4] made training fail to converge.
+Analysis: with the direct reshape the anchors are grouped by anchor shape (10 shapes) instead of by
+feature-map position, while CTPN ultimately joins small anchors into full text boxes, so anchors that
+are not ordered by position can seriously disturb training.
+Solution: reorder so that the 10 anchors of each feature-map point stay together, i.e.
+[10, h*w, 4] -> [h*w, 10, 4] -> [10*h*w, 4], which solves the problem.
+"""
+
+
+def gen_anchor(featuresize, scale,
+               heights=[11, 16, 23, 33, 48, 68, 97, 139, 198, 283],
+               widths=[16, 16, 16, 16, 16, 16, 16, 16, 16, 16]):
+    h, w = featuresize
+    shift_x = np.arange(0, w) * scale
+    shift_y = np.arange(0, h) * scale
+    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
+    shift = np.stack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel()), axis=1)
+
+    # base center (x, y) -> (x1, y1, x2, y2)
+    base_anchor = np.array([0, 0, 15, 15])
+    xt = (base_anchor[0] + base_anchor[2]) * 0.5
+    yt = (base_anchor[1] + base_anchor[3]) * 0.5
+
+    heights = np.array(heights).reshape(len(heights), 1)
+    widths = np.array(widths).reshape(len(widths), 1)
+    x1 = xt - widths * 0.5
+    y1 = yt - heights * 0.5
+    x2 = xt + widths * 0.5
+    y2 = yt + heights * 0.5
+    base_anchor = np.hstack((x1, y1, x2, y2))
+
+    anchors = list()
+    for i in range(base_anchor.shape[0]):
+        anchor_x1 = shift[:, 0] + base_anchor[i][0]
+        anchor_y1 = shift[:, 1] + base_anchor[i][1]
+        anchor_x2 = shift[:, 2] + base_anchor[i][2]
+        anchor_y2 = shift[:, 3] + base_anchor[i][3]
+        anchors.append(np.dstack((anchor_x1, anchor_y1, anchor_x2, anchor_y2)))
+
+    # [10, h*w, 4] -> [h*w, 10, 4] -> [10*h*w, 4]
+    return np.squeeze(np.array(anchors)).transpose((1, 0, 2)).reshape((-1, 4))
+
+
+"""
+IoU between anchors and ground-truth boxes:
+iou = inter_area / (bb_area + anchor_area - inter_area)
+"""
+
+
+def compute_iou(anchors, bbox):
+    ious = np.zeros((len(anchors), len(bbox)), dtype=np.float32)
+    anchor_area = (anchors[:, 2] - anchors[:, 0]) * (anchors[:, 3] - anchors[:, 1])
+    for num, _bbox in enumerate(bbox):
+        bb = np.tile(_bbox, (len(anchors), 1))
+        bb_area = (bb[:, 2] - bb[:, 0]) * (bb[:, 3] - bb[:, 1])
+        inter_h = np.maximum(np.minimum(bb[:, 3], anchors[:, 3]) - np.maximum(bb[:, 1], anchors[:, 1]), 0)
+        inter_w = np.maximum(np.minimum(bb[:, 2], anchors[:, 2]) - np.maximum(bb[:, 0], anchors[:, 0]), 0)
+        inter_area = inter_h * inter_w
+        area = bb_area + anchor_area - inter_area
+        for i in range(len(bb_area)):
+            area[i] = abs(bb_area[i]) + abs(anchor_area[i]) - inter_area[i]
+        ious[:, num] = inter_area / area
+
+    return ious
+
+
+"""
+Vertical regression targets regression_factor(Vc, Vh) between anchors and gtboxes:
+1. (x1, y1, x2, y2) -> (ctr_x, ctr_y, w, h)
+2. Vc = (gt_y - anchor_y) / anchor_h
+   Vh = np.log(gt_h / anchor_h)
+"""
+
+
+def bbox_transfrom(anchors, gtboxes):
+    gt_y = (gtboxes[:, 1] + gtboxes[:, 3]) * 0.5
+    gt_h = gtboxes[:, 3] - gtboxes[:, 1] + 1.0
+
+    anchor_y = (anchors[:, 1] + anchors[:, 3]) * 0.5
+    anchor_h = anchors[:, 3] - anchors[:, 1] + 1.0
+
+    Vc = (gt_y - anchor_y) / anchor_h
+    Vh = np.log(gt_h / anchor_h)
+
+    return np.vstack((Vc, Vh)).transpose()
+
+
+"""
+Given anchors and the regression factors (Vc, Vh), recover the predicted bbox.
+"""
+
+
+def transform_bbox(anchors, regression_factor):
+    anchor_y = (anchors[:, 1] + anchors[:, 3]) * 0.5
+    anchor_x = (anchors[:, 0] + anchors[:, 2]) * 0.5
+    anchor_h = anchors[:, 3]
- anchors[:, 1] + 1 + + Vc = regression_factor[0, :, 0] + Vh = regression_factor[0, :, 1] + + bbox_y = Vc * anchor_h + anchor_y + bbox_h = np.exp(Vh) * anchor_h + + x1 = anchor_x - 16 * 0.5 + y1 = bbox_y - bbox_h * 0.5 + x2 = anchor_x + 16 * 0.5 + y2 = bbox_y + bbox_h * 0.5 + bbox = np.vstack((x1, y1, x2, y2)).transpose() + + return bbox + + +""" +bbox 边界裁剪 + x1 >= 0 + y1 >= 0 + x2 < im_shape[1] + y2 < im_shape[0] +""" + + +def clip_bbox(bbox, im_shape): + bbox[:, 0] = np.maximum(np.minimum(bbox[:, 0], im_shape[1] - 1), 0) + bbox[:, 1] = np.maximum(np.minimum(bbox[:, 1], im_shape[0] - 1), 0) + bbox[:, 2] = np.maximum(np.minimum(bbox[:, 2], im_shape[1] - 1), 0) + bbox[:, 3] = np.maximum(np.minimum(bbox[:, 3], im_shape[0] - 1), 0) + + return bbox + + +""" +bbox尺寸过滤,舍弃小于设定最小尺寸的bbox +""" + + +def filter_bbox(bbox, minsize): + ws = bbox[:, 2] - bbox[:, 0] + 1 + hs = bbox[:, 3] - bbox[:, 1] + 1 + keep = np.where((ws >= minsize) & (hs >= minsize))[0] + return keep + + +""" +RPN module +1、生成anchor +2、计算anchor 和真值框 gtboxes的 iou +3、根据 iou,给每个anchor分配标签,0为负样本,1为正样本,-1为舍弃项 + (1) 对每个真值框 bbox,找出与其 iou最大的 anchor,设为正样本 + (2) 对每个anchor,记录其与每个bbox求取的 iou中最大的值 max_overlap + (3) 对max_overlap 大于设定阈值的anchor,将其设为正样本,小于设定阈值,则设定为负样本 +4、过滤超出边界的anchor框,将其标签设定为 -1 +5、选取不超过设定数量的正负样本 +6、求取anchor 取得max_overlap 时的gtbbox之间的真值差异量(Vc, Vh) +""" + + +def cal_rpn(imgsize, featuresize, scale, gtboxes): + base_anchor = gen_anchor(featuresize, scale) + overlaps = compute_iou(base_anchor, gtboxes) + + gt_argmax_overlaps = overlaps.argmax(axis=0) + anchor_argmax_overlaps = overlaps.argmax(axis=1) + anchor_max_overlaps = overlaps[range(overlaps.shape[0]), anchor_argmax_overlaps] + + labels = np.empty(base_anchor.shape[0]) + labels.fill(-1) + labels[gt_argmax_overlaps] = 1 + labels[anchor_max_overlaps > IOU_POSITIVE] = 1 + labels[anchor_max_overlaps < IOU_NEGATIVE] = 0 + + outside_anchor = np.where( + (base_anchor[:, 0] < 0) | + (base_anchor[:, 1] < 0) | + (base_anchor[:, 2] >= imgsize[1]) | + (base_anchor[:, 3] >= imgsize[0]) + )[0] + labels[outside_anchor] = -1 + + fg_index = np.where(labels == 1)[0] + if (len(fg_index) > RPN_POSITIVE_NUM): + labels[np.random.choice(fg_index, len(fg_index) - RPN_POSITIVE_NUM, replace=False)] = -1 + if not OHEM: + bg_index = np.where(labels == 0)[0] + num_bg = RPN_TOTAL_NUM - np.sum(labels == 1) + if (len(bg_index) > num_bg): + labels[np.random.choice(bg_index, len(bg_index) - num_bg, replace=False)] = -1 + + bbox_targets = bbox_transfrom(base_anchor, gtboxes[anchor_argmax_overlaps, :]) + + return [labels, bbox_targets] + + +""" +非极大值抑制,去除重叠框 +""" + + +def nms(dets, thresh): + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + return keep + + +""" +基于图的文本行构造算法 +子图连接规则,根据图中配对的文本框生成文本行 +1、遍历 graph 的行和列,寻找列全为false、行不全为false的行和列,索引号为index +2、找到 graph 的第 index 行中为true的那项的索引号,加入子图中,并将索引号迭代给index +3、重复步骤2,直到 graph 的第 index 行全部为false +4、重复步骤1、2、3,遍历完graph +返回文本行list[文本框索引] +""" + + +class Graph: + def __init__(self, graph): + self.graph = graph + + def 
sub_graphs_connected(self):
+        sub_graphs = []
+        for index in range(self.graph.shape[0]):
+            if not self.graph[:, index].any() and self.graph[index, :].any():
+                v = index
+                sub_graphs.append([v])
+                while self.graph[v, :].any():
+                    v = np.where(self.graph[v, :])[0][0]
+                    sub_graphs[-1].append(v)
+
+        return sub_graphs
+
+
+"""
+Configuration parameters
+MAX_HORIZONTAL_GAP: maximum horizontal distance between proposals inside one text line
+MIN_V_OVERLAPS: minimum vertical IoU between proposals
+MIN_SIZE_SIM: minimum height similarity between proposals
+"""
+
+
+class TextLineCfg:
+    SCALE = 600
+    MAX_SCALE = 1200
+    TEXT_PROPOSALS_WIDTH = 16
+    MIN_NUM_PROPOSALS = 2
+    MIN_RATIO = 0.5
+    LINE_MIN_SCORE = 0.9
+    TEXT_PROPOSALS_MIN_SCORE = 0.7
+    TEXT_PROPOSALS_NMS_THRESH = 0.3
+    MAX_HORIZONTAL_GAP = 60
+    MIN_V_OVERLAPS = 0.6
+    MIN_SIZE_SIM = 0.6
+
+
+class TextProposalGraphBuilder:
+    """
+    Builds pairs of neighbouring text proposals.
+    The proposal, score and size attributes are filled in by build_graph before
+    the pairing methods below are used.
+    """
+
+    def __init__(self):
+        self.text_proposals = None
+        self.scores = None
+        self.im_size = None
+        self.heights = None
+        self.boxes_table = None
+
+    def get_successions(self, index):
+        """
+        Scan [x0, x0 + MAX_HORIZONTAL_GAP] and collect the successor proposals of the given index.
+        """
+        box = self.text_proposals[index]
+        results = []
+        for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
+            adj_box_indices = self.boxes_table[left]
+            for adj_box_index in adj_box_indices:
+                if self.meet_v_iou(adj_box_index, index):
+                    results.append(adj_box_index)
+            if len(results) != 0:
+                return results
+
+        return results
+
+    def get_precursors(self, index):
+        """
+        Scan [x0 - MAX_HORIZONTAL_GAP, x0] and collect the predecessor proposals of the given index.
+        """
+        box = self.text_proposals[index]
+        results = []
+        for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
+            adj_box_indices = self.boxes_table[left]
+            for adj_box_index in adj_box_indices:
+                if self.meet_v_iou(adj_box_index, index):
+                    results.append(adj_box_index)
+            if len(results) != 0:
+                return results
+
+        return results
+
+    def is_succession_node(self, index, succession_index):
+        """
+        Check whether the two proposals form a pair.
+        """
+        precursors = self.get_precursors(succession_index)
+        if self.scores[index] >= np.max(self.scores[precursors]):
+            return True
+
+        return False
+
+    def meet_v_iou(self, index1, index2):
+        """
+        Check whether two proposals satisfy the vertical IoU conditions.
+        overlaps_v: vertical IoU of the two proposals, iou_v = inter_y / min(h1, h2)
+        size_similarity: height similarity of the two proposals, sim = min(h1, h2) / max(h1, h2)
+        """
+
+        def overlaps_v(index1, index2):
+            h1 = self.heights[index1]
+            h2 = self.heights[index2]
+            y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
+            y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
+            return max(0, y1 - y0 + 1) / min(h1, h2)
+
+        def size_similarity(index1, index2):
+            h1 = self.heights[index1]
+            h2 = self.heights[index2]
+            return min(h1, h2) / max(h1, h2)
+
+        return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \
+               size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM
+
+    def build_graph(self, text_proposals, scores, im_size):
+        """
+        Build proposal pairs from the text proposals.
+        self.heights: heights of all proposals
+        self.boxes_table: proposals grouped by the x1 coordinate of their top-left corner
+        graph: boolean [n, n] array marking whether two proposals are paired, n being the number of proposals
+        (1) collect the successors of the current proposal Bi
+        (2) take the highest-scoring successor, call it Bj
+        (3) collect the predecessors of Bj
+        (4) if the highest-scoring predecessor of Bj is exactly Bi, then Bi and Bj form a pair
+        """
+        self.text_proposals = text_proposals
+        self.scores = scores
+        self.im_size = im_size
+        self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1
+
+        boxes_table = [[] for _ in range(self.im_size[1])]
+        for index, box in enumerate(text_proposals):
+            boxes_table[int(box[0])].append(index)
+        self.boxes_table = boxes_table
+
+        graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), dtype=bool)
+
+        for index, box in enumerate(text_proposals):
+            successions = self.get_successions(index)
+            if len(successions) == 0:
continue + succession_index = successions[np.argmax(scores[successions])] + if self.is_succession_node(index, succession_index): + graph[index, succession_index] = True + + return Graph(graph) + + +class TextProposalConnectorOriented: + """ + 连接文本框,构建文本行bbox + """ + + def __init__(self): + self.graph_builder = TextProposalGraphBuilder() + + def group_text_proposals(self, text_proposals, scores, im_size): + """ + 将文本框连接起来,按文本行分组 + """ + graph = self.graph_builder.build_graph(text_proposals, scores, im_size) + + return graph.sub_graphs_connected() + + def fit_y(self, X, Y, x1, x2): + """ + 一元线性函数拟合X,Y,返回y1, y2的坐标值 + """ + if np.sum(X == X[0]) == len(X): + return Y[0], Y[0] + p = np.poly1d(np.polyfit(X, Y, 1)) + return p(x1), p(x2) + + def get_text_lines(self, text_proposals, scores, im_size): + """ + 根据文本框,构建文本行 + 1、将文本框划分成文本行组,每个文本行组内包含符合规则的文本框 + 2、处理每个文本行组,将其串成一个大的文本行 + (1) 获取文本行组内的所有文本框 text_line_boxes + (2) 求取每个组内每个文本框的中心坐标 (X, Y),最小、最大宽度坐标值 (x0 ,x1) + (3) 拟合所有中心点直线 z1 + (4) 设置offset为文本框宽度的一半 + (5) 拟合组内所有文本框的左上角点直线,并返回当x取 (x0+offset, x1-offset)时的极作极右y坐标 (lt_y, rt_y) + (6) 拟合组内所有文本框的左下角点直线,并返回当x取 (x0+offset, x1-offset)时的极作极右y坐标 (lb_y, rb_y) + (7) 取文本行组内所有框的评分的均值,作为该文本行的分数 + (8) 生成文本行基本数据 + 3、生成大文本框 + """ + tp_groups = self.group_text_proposals(text_proposals, scores, im_size) + + text_lines = np.zeros((len(tp_groups), 8), np.float32) + for index, tp_indices in enumerate(tp_groups): + text_line_boxes = text_proposals[list(tp_indices)] + + X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2 + Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2 + x0 = np.min(text_line_boxes[:, 0]) + x1 = np.max(text_line_boxes[:, 2]) + + z1 = np.polyfit(X, Y, 1) + + offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5 + + lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset) + lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset) + + score = scores[list(tp_indices)].sum() / float(len(tp_indices)) + + text_lines[index, 0] = x0 + text_lines[index, 1] = min(lt_y, rt_y) # 文本行上端 线段 的y坐标的小值 + text_lines[index, 2] = x1 + text_lines[index, 3] = max(lb_y, rb_y) # 文本行下端 线段 的y坐标的大值 + text_lines[index, 4] = score # 文本行得分 + text_lines[index, 5] = z1[0] # 根据中心点拟合的直线的k,b + text_lines[index, 6] = z1[1] + height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1])) # 小框平均高度 + text_lines[index, 7] = height + 2.5 + + text_recs = np.zeros((len(text_lines), 9), np.float) + index = 0 + for line in text_lines: + b1 = line[6] - line[7] / 2 # 根据高度和文本行中心线,求取文本行上下两条线的b值 + b2 = line[6] + line[7] / 2 + x1 = line[0] + y1 = line[5] * line[0] + b1 # 左上 + x2 = line[2] + y2 = line[5] * line[2] + b1 # 右上 + x3 = line[0] + y3 = line[5] * line[0] + b2 # 左下 + x4 = line[2] + y4 = line[5] * line[2] + b2 # 右下 + disX = x2 - x1 + disY = y2 - y1 + width = np.sqrt(disX * disX + disY * disY) # 文本行宽度 + + fTmp0 = y3 - y1 # 文本行高度 + fTmp1 = fTmp0 * disY / width + x = np.fabs(fTmp1 * disX / width) # 做补偿 + y = np.fabs(fTmp1 * disY / width) + if line[5] < 0: + x1 -= x + y1 += y + x4 += x + y4 -= y + else: + x2 += x + y2 += y + x3 -= x + y3 -= y + text_recs[index, 0] = x1 + text_recs[index, 1] = y1 + text_recs[index, 2] = x2 + text_recs[index, 3] = y2 + text_recs[index, 4] = x3 + text_recs[index, 5] = y3 + text_recs[index, 6] = x4 + text_recs[index, 7] = y4 + text_recs[index, 8] = line[4] + index = index + 1 + + return text_recs + + +if __name__ == '__main__': + anchor = gen_anchor((10, 15), 16) diff --git 
a/PyTorch/contrib/cv/detection/CTPN/modelzoo_level.txt b/PyTorch/contrib/cv/detection/CTPN/modelzoo_level.txt new file mode 100644 index 0000000000..d3c415da47 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/modelzoo_level.txt @@ -0,0 +1,5 @@ +GPUStatus:OK +NPUMigrationStatus:POK +FuncStatus:OK +PrecisionStatus:POK +PerfStatus:NOK diff --git a/PyTorch/contrib/cv/detection/CTPN/predict_2_txt.py b/PyTorch/contrib/cv/detection/CTPN/predict_2_txt.py new file mode 100644 index 0000000000..d8fe721ae4 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/predict_2_txt.py @@ -0,0 +1,152 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding:utf-8 -*- +import argparse +import copy +import os +import cv2 +import shutil +import math +import numpy as np +import torch +import torch.nn.functional as F +from ctpn import config +from ctpn.ctpn import CTPN_Model +from ctpn.utils import gen_anchor, transform_bbox, clip_bbox, filter_bbox, nms, TextProposalConnectorOriented +from tqdm import tqdm + +parser = argparse.ArgumentParser(description='PyTorch CTPN Testing') +parser.add_argument('--data-path', default='/home/dockerHome/ctpn/ctpn_8p/imagedata/', type=str, + help='number of data loading workers (default: 4)') +parser.add_argument('--model-path', default='./output_models/checkpoint-200.pth.tar', type=str, + help='number of total epochs to run') +args = parser.parse_args() + + +def load_state_dict(state_dicts): + new_state_dicts = {} + for k, v in state_dicts.items(): + if (k[:7] == 'module.'): + name = k[7:] + else: + name = k + new_state_dicts[name] = v + return new_state_dicts + + +long_len = 1008 + +device = 'npu:0' +weights = args.model_path +model = CTPN_Model().to(device) +state_dict = torch.load(weights, map_location='cpu')['state_dict'] +new_state_dict = load_state_dict(state_dict) +model.load_state_dict(new_state_dict) +model.eval() + + +def get_text_boxes(image, img_name=None, display=True, prob_thresh=0.5): + h, w = image.shape[:2] + image_c = copy.deepcopy(image) + image = image[:, :, ::-1] + rescale_fac = max(h, w) / long_len + if rescale_fac > 1.0: + h = int(h / rescale_fac) + w = int(w / rescale_fac) + image = cv2.resize(image, (w, h)) + h, w = image.shape[:2] + target_w = int(math.ceil(w / 16)) * 16 + target_h = int(math.ceil(h / 16)) * 16 + + w_pad = (target_w - w) // 2 + h_pad = (target_h - h) // 2 + + w_pad2 = target_w - w - w_pad + h_pad2 = target_h - h - h_pad + image = np.pad(image, ((h_pad, h_pad2), (w_pad, w_pad2), (0, 0))) + image = image.astype(np.float32) - config.IMAGE_MEAN + image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float().to(device) + h, w = target_h, target_w + with torch.no_grad(): + cls, regr = model(image) + cls_prob = F.softmax(cls, dim=-1).cpu().numpy() + regr = regr.cpu().numpy() + anchor = gen_anchor((int(h / 16), int(w / 16)), 16) + bbox = transform_bbox(anchor, regr) + bbox = clip_bbox(bbox, [h, w]) + + fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0] + select_anchor = bbox[fg, :] + 
select_score = cls_prob[0, fg, 1] + select_anchor = select_anchor.astype(np.int32) + keep_index = filter_bbox(select_anchor, 16) + + select_anchor = select_anchor[keep_index] + select_score = select_score[keep_index] + select_score = np.reshape(select_score, (select_score.shape[0], 1)) + nmsbox = np.hstack((select_anchor, select_score)) + keep = nms(nmsbox, 0.3) + select_anchor = select_anchor[keep] + select_score = select_score[keep] + + textConn = TextProposalConnectorOriented() + text = textConn.get_text_lines(select_anchor, select_score, [h, w]) + text_n = copy.deepcopy(text) + text[:, [0, 2, 4, 6]] = text[:, [0, 2, 4, 6]] - w_pad + text[:, [1, 3, 5, 7]] = text[:, [1, 3, 5, 7]] - h_pad + if rescale_fac > 1.0: + text_ori = text * rescale_fac + else: + text_ori = text + if display: + for i in text_ori: + s = str(round(i[-1] * 100, 2)) + '%' + i = [int(j) for j in i] + cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 255, 0), 2) + cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 255, 0), 2) + cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 255, 0), 2) + cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 255, 0), 2) + cv2.putText(image_c, s, (i[0] + 13, i[1] + 13), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, + cv2.LINE_AA) + cv2.imwrite(f'./display/{img_name}', image_c) + + return text_n, text_ori, image_c # 返回文字坐标,原始文字坐标,原图像 + + +def make_dir(path): + if os.path.exists(path): + shutil.rmtree(path) + os.mkdir(path) + + +def predict_2_txt(): + img_name_list = os.listdir(os.path.join(args.data_path, 'Challenge2_Test_Task12_Images')) + make_dir('./display/') + make_dir('./results/') + make_dir('./results/predict_txt_2/') + for k in tqdm(range(len(img_name_list))): + img_path = os.path.join(args.data_path, 'Challenge2_Test_Task12_Images/{}'.format(img_name_list[k])) + input_img = cv2.imread(img_path) + text, text_ori, out_img = get_text_boxes(input_img, img_name=img_name_list[k], display=True) + with open('./results/predict_txt_2/res_{}.txt'.format(img_name_list[k][:-4]), 'w') as f: + for i in range(text_ori.shape[0]): + x_min, y_min = min(text_ori[i][0:8:2]), min(text_ori[i][1:8:2]) + x_max, y_max = max(text_ori[i][0:8:2]), max(text_ori[i][1:8:2]) + f.write('{},{},{},{}'.format(int(x_min), int(y_min), int(x_max), int(y_max))) + f.write('\n') + + +if __name__ == '__main__': + predict_2_txt() diff --git a/PyTorch/contrib/cv/detection/CTPN/requirements.txt b/PyTorch/contrib/cv/detection/CTPN/requirements.txt new file mode 100644 index 0000000000..52ee8f5523 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/requirements.txt @@ -0,0 +1,4 @@ +-r requirements/runtime.txt +-r requirements/optional.txt +-r requirements/tests.txt +-r requirements/build.txt diff --git a/PyTorch/contrib/cv/detection/CTPN/scripts/rrc_evaluation_funcs_1_1.py b/PyTorch/contrib/cv/detection/CTPN/scripts/rrc_evaluation_funcs_1_1.py new file mode 100644 index 0000000000..a8b0a82f25 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/scripts/rrc_evaluation_funcs_1_1.py @@ -0,0 +1,505 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 +# encoding: UTF-8 + +# File: rrc_evaluation_funcs_1_1.py +# Version: 1.1 +# Version info: changes for Python 3 +# Date: 2019-12-29 +# Description: File with useful functions to use by the evaluation scripts in the RRC website. + +import json +import sys; + +sys.path.append('./') +import zipfile +import re +import os +import importlib + + +def print_help(): + sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' % sys.argv[0]) + sys.exit(2) + + +def load_zip_file_keys(file, fileNameRegExp=''): + """ + Returns an array with the entries of the ZIP file that match with the regular expression. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + """ + try: + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: + raise Exception('Error loading the ZIP archive.') + + pairs = [] + + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append(keyName) + + return pairs + + +def load_zip_file(file, fileNameRegExp='', allEntries=False): + """ + Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. + The key's are the names or the file or the capturing group definied in the fileNameRegExp + allEntries validates that all entries in the ZIP file pass the fileNameRegExp + """ + try: + archive = zipfile.ZipFile(file, mode='r', allowZip64=True) + except: + raise Exception('Error loading the ZIP archive') + + pairs = [] + for name in archive.namelist(): + addFile = True + keyName = name + if fileNameRegExp != "": + m = re.match(fileNameRegExp, name) + if m == None: + addFile = False + else: + if len(m.groups()) > 0: + keyName = m.group(1) + + if addFile: + pairs.append([keyName, archive.read(name)]) + else: + if allEntries: + raise Exception('ZIP entry not valid: %s' % name) + + return dict(pairs) + + +def decode_utf8(raw): + """ + Returns a Unicode object on success, or None on failure + """ + try: + return raw.decode('utf-8-sig', errors='replace') + except Exception as e: + return None + + +def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, + imWidth=0, imHeight=0): + """ + This function validates that all lines of the file calling the Line validation function for each line + """ + utf8File = decode_utf8(file_contents) + if (utf8File is None): + raise Exception("The file %s is not UTF-8" % fileName) + + lines = utf8File.split("\r\n" if CRLF else "\n") + for line in lines: + line = line.replace("\r", "").replace("\n", "") + if (line != ""): + try: + validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight) + except Exception as e: + raise Exception( + ("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode( + 'utf-8', 'replace')) + + +def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 
+ Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + """ + get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight) + + +def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0): + """ + Validate the format of the line. If the line is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values are: + LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] + LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] + Returns values from a textline. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = ""; + points = [] + + numPoints = 4; + + if LTRB: + + numPoints = 4; + + if withTranscription and withConfidence: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line) + if m == None: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") + elif withConfidence: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") + elif withTranscription: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line) + if m == None: + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") + else: + m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line) + if m == None: + raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") + + xmin = int(m.group(1)) + ymin = int(m.group(2)) + xmax = int(m.group(3)) + ymax = int(m.group(4)) + if (xmax < xmin): + raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax)) + if (ymax < ymin): + raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax)) + + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(xmin, ymin, imWidth, imHeight); + validate_point_inside_bounds(xmax, ymax, imWidth, imHeight); + + else: + + numPoints = 8; + + if withTranscription and withConfidence: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', + line) + if m == None: + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") + elif withConfidence: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$', + line) + if m == None: + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") + elif withTranscription: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$', + line) + if m == None: + raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") + else: + m = re.match( + r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$', + line) + if m == None: + raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") + + points = [float(m.group(i)) for i in range(1, (numPoints + 1))] + + validate_clockwise_points(points) + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(points[0], points[1], imWidth, imHeight); + validate_point_inside_bounds(points[2], points[3], imWidth, imHeight); + validate_point_inside_bounds(points[4], points[5], imWidth, imHeight); + validate_point_inside_bounds(points[6], points[7], imWidth, imHeight); + + if withConfidence: + try: + confidence = float(m.group(numPoints + 1)) + except ValueError: + raise Exception("Confidence value must be a float") + + if withTranscription: + posTranscription = numPoints + (2 if withConfidence else 1) + transcription = m.group(posTranscription) + m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription) + if m2 != None: # Transcription with double quotes, we extract the value and replace escaped characters + transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") + + return points, confidence, transcription + + +def get_tl_dict_values(detection, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0, + validNumPoints=[], validate_cw=True): + """ + Validate the format of the dictionary. If the dictionary is not valid an exception will be raised. + If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. + Posible values: + {"points":[[x1,y1],[x2,y2],[x3,x3],..,[xn,yn]]} + {"points":[[x1,y1],[x2,y2],[x3,x3],..,[xn,yn]],"transcription":"###","confidence":0.4,"illegibility":false} + {"points":[[x1,y1],[x2,y2],[x3,x3],..,[xn,yn]],"transcription":"###","confidence":0.4,"dontCare":false} + Returns values from the dictionary. Points , [Confidences], [Transcriptions] + """ + confidence = 0.0 + transcription = ""; + points = [] + + if isinstance(detection, dict) == False: + raise Exception("Incorrect format. Object has to be a dictionary") + + if not 'points' in detection: + raise Exception("Incorrect format. Object has no points key)") + + if isinstance(detection['points'], list) == False: + raise Exception("Incorrect format. Object points key have to be an array)") + + num_points = len(detection['points']) + + if num_points < 3: + raise Exception( + "Incorrect format. Incorrect number of points. At least 3 points are necessary. Found: " + str(num_points)) + + if (len(validNumPoints) > 0 and num_points in validNumPoints == False): + raise Exception("Incorrect format. Incorrect number of points. Only allowed 4,8 or 12 points)") + + for i in range(num_points): + if isinstance(detection['points'][i], list) == False: + raise Exception("Incorrect format. Point #" + str(i + 1) + " has to be an array)") + + if len(detection['points'][i]) != 2: + raise Exception("Incorrect format. Point #" + str(i + 1) + " has to be an array with 2 objects(x,y) )") + + if isinstance(detection['points'][i][0], (int, float)) == False or isinstance(detection['points'][i][1], + (int, float)) == False: + raise Exception("Incorrect format. 
Point #" + str(i + 1) + " childs have to be Integers)") + + if (imWidth > 0 and imHeight > 0): + validate_point_inside_bounds(detection['points'][i][0], detection['points'][i][1], imWidth, imHeight); + + points.append(float(detection['points'][i][0])) + points.append(float(detection['points'][i][1])) + + if validate_cw: + validate_clockwise_points(points) + + if withConfidence: + if not 'confidence' in detection: + raise Exception("Incorrect format. No confidence key)") + + if isinstance(detection['confidence'], (int, float)) == False: + raise Exception("Incorrect format. Confidence key has to be a float)") + + if detection['confidence'] < 0 or detection['confidence'] > 1: + raise Exception("Incorrect format. Confidence key has to be a float between 0.0 and 1.0") + + confidence = detection['confidence'] + + if withTranscription: + if not 'transcription' in detection: + raise Exception("Incorrect format. No transcription key)") + + if isinstance(detection['transcription'], str) == False: + raise Exception("Incorrect format. Transcription has to be a string. Detected: " + type( + detection['transcription']).__name__) + + transcription = detection['transcription'] + + if 'illegibility' in detection: # Ensures that if illegibility atribute is present and is True the transcription is set to ### (don't care) + if detection['illegibility'] == True: + transcription = "###" + + if 'dontCare' in detection: # Ensures that if dontCare atribute is present and is True the transcription is set to ### (don't care) + if detection['dontCare'] == True: + transcription = "###" + + return points, confidence, transcription + + +def validate_point_inside_bounds(x, y, imWidth, imHeight): + if (x < 0 or x > imWidth): + raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight)) + if (y < 0 or y > imHeight): + raise Exception( + "Y value (%s) not valid. Image dimensions: (%s,%s)" % (y, imWidth, imHeight)) + + +def validate_clockwise_points(points): + """ + Validates that the points are in clockwise order. + """ + edge = [] + for i in range(len(points) // 2): + edge.append((int(points[(i + 1) * 2 % len(points)]) - int(points[i * 2])) * ( + int(points[((i + 1) * 2 + 1) % len(points)]) + int(points[i * 2 + 1]))) + if sum(edge) > 0: + raise Exception( + "Points are not clockwise. The coordinates of bounding points have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") + + +def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False, + imWidth=0, imHeight=0, sort_by_confidences=True): + """ + Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats: + xmin,ymin,xmax,ymax,[confidence],[transcription] + x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + lines = content.split("\r\n" if CRLF else "\n") + for line in lines: + line = line.replace("\r", "").replace("\n", "") + if (line != ""): + points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence, + imWidth, imHeight); + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList) > 0 and sort_by_confidences: + import numpy as np + sorted_ind = np.argsort(-np.array(confidencesList)) + confidencesList = [confidencesList[i] for i in sorted_ind] + pointsList = [pointsList[i] for i in sorted_ind] + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList, confidencesList, transcriptionsList + + +def get_tl_dict_values_from_array(array, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0, + sort_by_confidences=True, validNumPoints=[], validate_cw=True): + """ + Returns all points, confindences and transcriptions of a file in lists. Valid dict formats: + {"points":[[x1,y1],[x2,y2],[x3,x3],..,[xn,yn]],"transcription":"###","confidence":0.4} + """ + pointsList = [] + transcriptionsList = [] + confidencesList = [] + + for n in range(len(array)): + objectDict = array[n] + points, confidence, transcription = get_tl_dict_values(objectDict, withTranscription, withConfidence, imWidth, + imHeight, validNumPoints, validate_cw); + pointsList.append(points) + transcriptionsList.append(transcription) + confidencesList.append(confidence) + + if withConfidence and len(confidencesList) > 0 and sort_by_confidences: + import numpy as np + sorted_ind = np.argsort(-np.array(confidencesList)) + confidencesList = [confidencesList[i] for i in sorted_ind] + pointsList = [pointsList[i] for i in sorted_ind] + transcriptionsList = [transcriptionsList[i] for i in sorted_ind] + + return pointsList, confidencesList, transcriptionsList + + +def main_evaluation(p, default_evaluation_params_fn, validate_data_fn, evaluate_method_fn, show_result=True, + per_sample=True): + """ + This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. + Params: + p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 
+ default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results + """ + + if (p == None): + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + if (len(sys.argv) < 3): + print_help() + + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'])) + + resDict = {'calculated': True, 'Message': '', 'method': '{}', 'per_sample': '{}'} + try: + validate_data_fn(p['g'], p['s'], evalParams) + evalData = evaluate_method_fn(p['g'], p['s'], evalParams) + resDict.update(evalData) + + except Exception as e: + resDict['Message'] = str(e) + resDict['calculated'] = False + + if 'o' in p: + if not os.path.exists(p['o']): + os.makedirs(p['o']) + + resultsOutputname = p['o'] + '/results.zip' + outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) + + del resDict['per_sample'] + if 'output_items' in resDict.keys(): + del resDict['output_items'] + + outZip.writestr('method.json', json.dumps(resDict)) + + if not resDict['calculated']: + if show_result: + sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n') + if 'o' in p: + outZip.close() + return resDict + + if 'o' in p: + if per_sample == True: + for k, v in evalData['per_sample'].items(): + outZip.writestr(k + '.json', json.dumps(v)) + + if 'output_items' in evalData.keys(): + for k, v in evalData['output_items'].items(): + outZip.writestr(k, v) + + outZip.close() + + if show_result: + sys.stdout.write("Calculated!") + sys.stdout.write(json.dumps(resDict['method'])) + + return resDict + + +def main_validation(default_evaluation_params_fn, validate_data_fn): + """ + This process validates a method + Params: + default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation + validate_data_fn: points to a method that validates the corrct format of the submission + """ + try: + p = dict([s[1:].split('=') for s in sys.argv[1:]]) + evalParams = default_evaluation_params_fn() + if 'p' in p.keys(): + evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'])) + + validate_data_fn(p['g'], p['s'], evalParams) + print('SUCCESS') + sys.exit(0) + except Exception as e: + print(str(e)) + sys.exit(101) diff --git a/PyTorch/contrib/cv/detection/CTPN/scripts/script.py b/PyTorch/contrib/cv/detection/CTPN/scripts/script.py new file mode 100644 index 0000000000..d10e8e1b26 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/scripts/script.py @@ -0,0 +1,412 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# File: TL2p_deteval_1_1.py +# Version: 1.1 +# Version info: changes for Python 3 +# Date: 2019-12-29 +# Description: Evaluation script that computes Text Localization following the Deteval implementation + +from collections import namedtuple +import scripts.rrc_evaluation_funcs_1_1 as rrc_evaluation_funcs +import importlib +import numpy as np +import math + + +def evaluation_imports(): + """ + evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. + """ + return { + 'math': 'math', + 'numpy': 'np' + } + + +def default_evaluation_params(): + """ + default_evaluation_params: Default parameters to use for the validation and evaluation. + """ + return { + 'AREA_RECALL_CONSTRAINT': 0.8, + 'AREA_PRECISION_CONSTRAINT': 0.4, + 'EV_PARAM_IND_CENTER_DIFF_THR': 1, + 'MTYPE_OO_O': 1., + 'MTYPE_OM_O': 0.8, + 'MTYPE_OM_M': 1., + 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', + 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', + 'CRLF': False # Lines are delimited by Windows CRLF format + } + + +def validate_data(gtFilePath, submFilePath, evaluationParams): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + # Validate format of GroundTruth + for k in gt: + rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], True, True) + + # Validate format of results + for k in subm: + if (k in gt) == False: + raise Exception("The sample %s not present in GT" % k) + + rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], True, False) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + + for module, alias in evaluation_imports().items(): + globals()[alias] = importlib.import_module(module) + + def one_to_one_match(row, col): + cont = 0 + for j in range(len(recallMat[0])): + if recallMat[row, j] >= evaluationParams['AREA_RECALL_CONSTRAINT'] and precisionMat[row, j] >= \ + evaluationParams['AREA_PRECISION_CONSTRAINT']: + cont = cont + 1 + if (cont != 1): + return False + cont = 0 + for i in range(len(recallMat)): + if recallMat[i, col] >= evaluationParams['AREA_RECALL_CONSTRAINT'] and precisionMat[i, col] >= \ + evaluationParams['AREA_PRECISION_CONSTRAINT']: + cont = cont + 1 + if (cont != 1): + return False + + if recallMat[row, col] >= evaluationParams['AREA_RECALL_CONSTRAINT'] and precisionMat[row, col] >= \ + evaluationParams['AREA_PRECISION_CONSTRAINT']: + return True + return False + + def num_overlaps_gt(gtNum): + cont = 0 + for detNum in range(len(detRects)): + if detNum not in detDontCareRectsNum: + if recallMat[gtNum, detNum] > 0: + cont = cont + 1 + return cont + + def num_overlaps_det(detNum): + cont = 0 + for gtNum in range(len(recallMat)): + if gtNum not in gtDontCareRectsNum: + if recallMat[gtNum, detNum] > 0: + cont = cont + 1 + return cont + + def is_single_overlap(row, col): + if num_overlaps_gt(row) == 1 and num_overlaps_det(col) == 1: + return True + else: + return False + + def one_to_many_match(gtNum): + many_sum = 0 + detRects = [] + for detNum in range(len(recallMat[0])): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and detNum not in detDontCareRectsNum: + if precisionMat[gtNum, detNum] >= evaluationParams['AREA_PRECISION_CONSTRAINT']: + many_sum += recallMat[gtNum, detNum] + detRects.append(detNum) + if round(many_sum, 4) >= evaluationParams['AREA_RECALL_CONSTRAINT']: + return True, detRects + else: + return False, [] + + def many_to_one_match(detNum): + many_sum = 0 + gtRects = [] + for gtNum in range(len(recallMat)): + if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCareRectsNum: + if recallMat[gtNum, detNum] >= evaluationParams['AREA_RECALL_CONSTRAINT']: + many_sum += precisionMat[gtNum, detNum] + gtRects.append(gtNum) + if round(many_sum, 4) >= evaluationParams['AREA_PRECISION_CONSTRAINT']: + return True, gtRects + else: + return False, [] + + def area(a, b): + dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin) + 1 + dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin) + 1 + if (dx >= 0) and (dy >= 0): + return dx * dy + else: + return 0. 
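+
+    # The +1 terms above treat coordinates as inclusive pixel indices: for two
+    # identical boxes Rectangle(0, 0, 9, 9) the intersection area is
+    # (9 - 0 + 1) * (9 - 0 + 1) = 100, so the corresponding recallMat and
+    # precisionMat entries computed below are both 1.0.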
+ + def center(r): + x = float(r.xmin) + float(r.xmax - r.xmin + 1) / 2.; + y = float(r.ymin) + float(r.ymax - r.ymin + 1) / 2.; + return Point(x, y) + + def point_distance(r1, r2): + distx = math.fabs(r1.x - r2.x) + disty = math.fabs(r1.y - r2.y) + return math.sqrt(distx * distx + disty * disty) + + def center_distance(r1, r2): + return point_distance(center(r1), center(r2)) + + def diag(r): + w = (r.xmax - r.xmin + 1) + h = (r.ymax - r.ymin + 1) + return math.sqrt(h * h + w * w) + + def rectangle_to_points(rect): + points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), + int(rect.xmin), int(rect.ymin)] + return points + + perSampleMetrics = {} + + methodRecallSum = 0 + methodPrecisionSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + Point = namedtuple('Point', 'x y') + + gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) + subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) + + numGt = 0; + numDet = 0; + + for resFile in gt: + + gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) + recall = 0 + precision = 0 + hmean = 0 + recallAccum = 0. + precisionAccum = 0. + gtRects = [] + detRects = [] + gtPolPoints = [] + detPolPoints = [] + gtDontCareRectsNum = [] # Array of Ground Truth Rectangles' keys marked as don't Care + detDontCareRectsNum = [] # Array of Detected Rectangles' matched with a don't Care GT + pairs = [] + evaluationLog = "" + + recallMat = np.empty([1, 1]) + precisionMat = np.empty([1, 1]) + + pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, + evaluationParams[ + 'CRLF'], + True, True, + False) + for n in range(len(pointsList)): + points = pointsList[n] + transcription = transcriptionsList[n] + dontCare = transcription == "###" + gtRect = Rectangle(*points) + gtRects.append(gtRect) + gtPolPoints.append(points) + if dontCare: + gtDontCareRectsNum.append(len(gtRects) - 1) + + evaluationLog += "GT rectangles: " + str(len(gtRects)) + ( + " (" + str(len(gtDontCareRectsNum)) + " don't care)\n" if len(gtDontCareRectsNum) > 0 else "\n") + + if resFile in subm: + detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) + pointsList, _, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, + evaluationParams['CRLF'], + True, False, False) + for n in range(len(pointsList)): + points = pointsList[n] + detRect = Rectangle(*points) + detRects.append(detRect) + detPolPoints.append(points) + if len(gtDontCareRectsNum) > 0: + for dontCareRectNum in gtDontCareRectsNum: + dontCareRect = gtRects[dontCareRectNum] + intersected_area = area(dontCareRect, detRect) + rdDimensions = ((detRect.xmax - detRect.xmin + 1) * (detRect.ymax - detRect.ymin + 1)); + if (rdDimensions == 0): + precision = 0 + else: + precision = intersected_area / rdDimensions + if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']): + detDontCareRectsNum.append(len(detRects) - 1) + break + + evaluationLog += "DET rectangles: " + str(len(detRects)) + ( + " (" + str(len(detDontCareRectsNum)) + " don't care)\n" if len(detDontCareRectsNum) > 0 else "\n") + + if len(gtRects) == 0: + recall = 1 + precision = 0 if len(detRects) > 0 else 1 + + if len(detRects) > 0: + # Calculate recall and precision matrixs + outputShape = [len(gtRects), len(detRects)] + recallMat = np.empty(outputShape) + precisionMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtRects), np.int8) + detRectMat = 
np.zeros(len(detRects), np.int8) + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + rG = gtRects[gtNum] + rD = detRects[detNum] + intersected_area = area(rG, rD) + rgDimensions = ((rG.xmax - rG.xmin + 1) * (rG.ymax - rG.ymin + 1)); + rdDimensions = ((rD.xmax - rD.xmin + 1) * (rD.ymax - rD.ymin + 1)); + recallMat[gtNum, detNum] = 0 if rgDimensions == 0 else intersected_area / rgDimensions + precisionMat[gtNum, detNum] = 0 if rdDimensions == 0 else intersected_area / rdDimensions + + # Find one-to-one matches + evaluationLog += "Find one-to-one matches\n" + for gtNum in range(len(gtRects)): + for detNum in range(len(detRects)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCareRectsNum and detNum not in detDontCareRectsNum: + match = one_to_one_match(gtNum, detNum) + if match is True: + # in deteval we have to make other validation before mark as one-to-one + if is_single_overlap(gtNum, detNum) is True: + rG = gtRects[gtNum] + rD = detRects[detNum] + normDist = center_distance(rG, rD); + normDist /= diag(rG) + diag(rD); + normDist *= 2.0; + if normDist < evaluationParams['EV_PARAM_IND_CENTER_DIFF_THR']: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + recallAccum += evaluationParams['MTYPE_OO_O'] + precisionAccum += evaluationParams['MTYPE_OO_O'] + pairs.append({'gt': gtNum, 'det': detNum, 'type': 'OO'}) + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str(gtNum) + " with Det #" + str( + detNum) + " normDist: " + str(normDist) + " \n" + else: + evaluationLog += "Match Discarded GT #" + str(gtNum) + " with Det #" + str( + detNum) + " not single overlap\n" + # Find one-to-many matches + evaluationLog += "Find one-to-many matches\n" + for gtNum in range(len(gtRects)): + if gtNum not in gtDontCareRectsNum: + match, matchesDet = one_to_many_match(gtNum) + if match is True: + evaluationLog += "num_overlaps_gt=" + str(num_overlaps_gt(gtNum)) + # in deteval we have to make other validation before mark as one-to-one + if num_overlaps_gt(gtNum) >= 2: + gtRectMat[gtNum] = 1 + recallAccum += ( + evaluationParams['MTYPE_OO_O'] if len(matchesDet) == 1 else evaluationParams[ + 'MTYPE_OM_O']) + precisionAccum += ( + evaluationParams['MTYPE_OO_O'] if len(matchesDet) == 1 else evaluationParams[ + 'MTYPE_OM_O'] * len( + matchesDet)) + pairs.append( + {'gt': gtNum, 'det': matchesDet, 'type': 'OO' if len(matchesDet) == 1 else 'OM'}) + for detNum in matchesDet: + detRectMat[detNum] = 1 + evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(matchesDet) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str(gtNum) + " with Det #" + str( + matchesDet) + " not single overlap\n" + + # Find many-to-one matches + evaluationLog += "Find many-to-one matches\n" + for detNum in range(len(detRects)): + if detNum not in detDontCareRectsNum: + match, matchesGt = many_to_one_match(detNum) + if match is True: + # in deteval we have to make other validation before mark as one-to-one + if num_overlaps_det(detNum) >= 2: + detRectMat[detNum] = 1 + recallAccum += ( + evaluationParams['MTYPE_OO_O'] if len(matchesGt) == 1 else evaluationParams[ + 'MTYPE_OM_M'] * len( + matchesGt)) + precisionAccum += ( + evaluationParams['MTYPE_OO_O'] if len(matchesGt) == 1 else evaluationParams[ + 'MTYPE_OM_M']) + pairs.append( + {'gt': matchesGt, 'det': detNum, 'type': 'OO' if len(matchesGt) == 1 else 'MO'}) + for gtNum in matchesGt: + gtRectMat[gtNum] = 1 + evaluationLog += "Match GT #" 
+ str(matchesGt) + " with Det #" + str(detNum) + "\n" + else: + evaluationLog += "Match Discarded GT #" + str(matchesGt) + " with Det #" + str( + detNum) + " not single overlap\n" + + numGtCare = (len(gtRects) - len(gtDontCareRectsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if len(detRects) > 0 else float(1) + else: + recall = float(recallAccum) / numGtCare + precision = float(0) if (len(detRects) - len(detDontCareRectsNum)) == 0 else float( + precisionAccum) / (len(detRects) - len(detDontCareRectsNum)) + hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) + + methodRecallSum += recallAccum + methodPrecisionSum += precisionAccum + numGt += len(gtRects) - len(gtDontCareRectsNum) + numDet += len(detRects) - len(detDontCareRectsNum) + + perSampleMetrics[resFile] = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'recallMat': [] if len(detRects) > 100 else recallMat.tolist(), + 'precisionMat': [] if len(detRects) > 100 else precisionMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtDontCare': gtDontCareRectsNum, + 'detDontCare': detDontCareRectsNum, + 'evaluationParams': evaluationParams, + 'evaluationLog': evaluationLog + } + + methodRecall = 0 if numGt == 0 else methodRecallSum / numGt + methodPrecision = 0 if numDet == 0 else methodPrecisionSum / numDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + + methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean} + + resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics} + + return resDict; + + +if __name__ == '__main__': + rrc_evaluation_funcs.main_evaluation(None, default_evaluation_params, validate_data, evaluate_method) diff --git a/PyTorch/contrib/cv/detection/CTPN/test/env.sh b/PyTorch/contrib/cv/detection/CTPN/test/env.sh new file mode 100644 index 0000000000..a975f7978c --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/test/env.sh @@ -0,0 +1,79 @@ +#!/bin/bash +export install_path=/usr/local/Ascend + +if [ -d ${install_path}/toolkit ]; then + export LD_LIBRARY_PATH=${install_path}/fwkacllib/lib64/:/usr/include/hdf5/lib/:/usr/local/:/usr/local/lib/:/usr/lib/:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons:${path_lib}:${LD_LIBRARY_PATH} + export PATH=${install_path}/fwkacllib/ccec_compiler/bin:${install_path}/fwkacllib/bin:$PATH + export PYTHONPATH=${install_path}/fwkacllib/python/site-packages:${install_path}/tfplugin/python/site-packages:${install_path}/toolkit/python/site-packages:$PYTHONPATH + export PYTHONPATH=/usr/local/python3.7.5/lib/python3.7/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=${install_path}/opp +else + if [ -d ${install_path}/nnae/latest ];then + export LD_LIBRARY_PATH=${install_path}/nnae/latest/fwkacllib/lib64/:/usr/local/:/usr/local/python3.7.5/lib/:/usr/local/openblas/lib:/usr/local/lib/:/usr/lib64/:/usr/lib/:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons/:/usr/lib/aarch64_64-linux-gnu:$LD_LIBRARY_PATH + export PATH=$PATH:${install_path}/nnae/latest/fwkacllib/ccec_compiler/bin/:${install_path}/nnae/latest/toolkit/tools/ide_daemon/bin/ + export ASCEND_OPP_PATH=${install_path}/nnae/latest/opp/ + export 
OPTION_EXEC_EXTERN_PLUGIN_PATH=${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libfe.so:${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libaicpu_engine.so:${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libge_local_engine.so + export PYTHONPATH=${install_path}/nnae/latest/fwkacllib/python/site-packages/:${install_path}/nnae/latest/fwkacllib/python/site-packages/auto_tune.egg/auto_tune:${install_path}/nnae/latest/fwkacllib/python/site-packages/schedule_search.egg:$PYTHONPATH + export ASCEND_AICPU_PATH=${install_path}/nnae/latest + else + export LD_LIBRARY_PATH=${install_path}/ascend-toolkit/latest/fwkacllib/lib64/:/usr/local/:/usr/local/lib/:/usr/lib64/:/usr/lib/:/usr/local/python3.7.5/lib/:/usr/local/openblas/lib:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons/:/usr/lib/aarch64-linux-gnu:$LD_LIBRARY_PATH + export PATH=$PATH:${install_path}/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin/:${install_path}/ascend-toolkit/latest/toolkit/tools/ide_daemon/bin/ + export ASCEND_OPP_PATH=${install_path}/ascend-toolkit/latest/opp/ + export OPTION_EXEC_EXTERN_PLUGIN_PATH=${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libfe.so:${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libaicpu_engine.so:${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libge_local_engine.so + export PYTHONPATH=${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/:${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/auto_tune.egg/auto_tune:${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/schedule_search.egg:$PYTHONPATH + export ASCEND_AICPU_PATH=${install_path}/ascend-toolkit/latest + fi +fi + +${install_path}/driver/tools/msnpureport -g error -d 0 +${install_path}/driver/tools/msnpureport -g error -d 1 +${install_path}/driver/tools/msnpureport -g error -d 2 +${install_path}/driver/tools/msnpureport -g error -d 3 +${install_path}/driver/tools/msnpureport -g error -d 4 +${install_path}/driver/tools/msnpureport -g error -d 5 +${install_path}/driver/tools/msnpureport -g error -d 6 +${install_path}/driver/tools/msnpureport -g error -d 7 + +#将Host日志输出到串口,0-关闭/1-开启 +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +#设置默认日志级别,0-debug/1-info/2-warning/3-error +export ASCEND_GLOBAL_LOG_LEVEL=3 +#设置Event日志开启标志,0-关闭/1-开启 +export ASCEND_GLOBAL_EVENT_ENABLE=0 +#设置是否开启taskque,0-关闭/1-开启 +export TASK_QUEUE_ENABLE=1 +#设置是否开启PTCopy,0-关闭/1-开启 +export PTCOPY_ENABLE=1 +#设置是否开启2个非连续combined标志,0-关闭/1-开启 +export COMBINED_ENABLE=1 +#设置是否开启3个非连续combined标志,0-关闭/1-开启 +export TRI_COMBINED_ENABLE=1 +#设置特殊场景是否需要重新编译,不需要修改 +export DYNAMIC_OP="ADD#MUL" +# HCCL白名单开关,1-关闭/0-开启 +export HCCL_WHITELIST_DISABLE=1 +# HCCL默认超时时间120s较少,修改为1800s对齐PyTorch默认设置 +export HCCL_CONNECT_TIMEOUT=1800 + +ulimit -SHn 512000 + +path_lib=$(python3.7 -c """ +import sys +import re +result='' +for index in range(len(sys.path)): + match_sit = re.search('-packages', sys.path[index]) + if match_sit is not None: + match_lib = re.search('lib', sys.path[index]) + + if match_lib is not None: + end=match_lib.span()[1] + result += sys.path[index][0:end] + ':' + + result+=sys.path[index] + '/torch/lib:' +print(result)""" +) + +echo ${path_lib} + +export LD_LIBRARY_PATH=/usr/local/python3.7.5/lib/:${path_lib}:$LD_LIBRARY_PATH \ No newline at end of file diff --git a/PyTorch/contrib/cv/detection/CTPN/.keep b/PyTorch/contrib/cv/detection/CTPN/test/output/.keep similarity index 100% rename from 
PyTorch/contrib/cv/detection/CTPN/.keep rename to PyTorch/contrib/cv/detection/CTPN/test/output/.keep diff --git a/PyTorch/contrib/cv/detection/CTPN/test/train_eval.sh b/PyTorch/contrib/cv/detection/CTPN/test/train_eval.sh new file mode 100644 index 0000000000..7cc181afed --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/test/train_eval.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +################基础配置参数,需要模型审视修改################## +# 必选字段(必须在此处定义的参数): Network batch_size RANK_SIZE +# 网络名称,同目录名称 +Network="CTPN" +# 训练batch_size +batch_size=1 +# 训练使用的npu卡数 +export RANK_SIZE=1 +# 数据集路径,保持为空,不需要修改 +data_path="" +checkpoint="" + +# 训练epoch 210 +train_epochs=200 + +# 参数校验,data_path为必传参数,其他参数的增删由模型自身决定;此处新增参数需在上面有定义并赋值 +for para in $* +do + if [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --pth_path* ]];then + checkpoint=`echo ${para#*=}` + fi +done + +# 校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be confing" + exit 1 +fi + +###############指定训练脚本执行路径############### +# cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + + +#################创建日志输出目录,不需要修改################# +ASCEND_DEVICE_ID=0 +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +#################启动训练脚本################# +# 训练开始时间,不需要修改 +start_time=$(date +%s) +# 非平台场景时source 环境变量 +check_etp_flag=`env | grep etp_running_flag` +etp_flag=`echo ${check_etp_flag#*=}` +if [ x"${etp_flag}" != x"true" ];then + source ${test_path_dir}/env.sh +fi + +python3.7 predict_2_txt.py --data-path=$data_path\ + --model-path=$checkpoint\ + > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + +zip -pj ./predict.zip ./results/predict_txt_2/* > zip.log 2>&1 & +wait + +# 输出训练精度,需要模型审视修改 +echo "Final Train Accuracy : " +python3.7 scripts/script.py -g=gt.zip -s=predict.zip +wait +echo \ +##################获取训练数据################ +# 训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# 训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'eval' + +# 打印,不需要修改 + +echo "E2E Training Duration sec : $e2e_time" + + +# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log diff --git a/PyTorch/contrib/cv/detection/CTPN/test/train_full_1p.sh b/PyTorch/contrib/cv/detection/CTPN/test/train_full_1p.sh new file mode 100644 index 0000000000..ee5181af99 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/test/train_full_1p.sh @@ -0,0 +1,140 @@ +#!/bin/bash + 
+##################基础配置参数,需要模型审视修改################## +# 必选字段(必须在此处定义的参数): Network batch_size RANK_SIZE +# 网络名称,同目录名称 +Network="CTPN" +# 训练batch_size +batch_size=1 +# 训练使用的npu卡数 +export RANK_SIZE=1 +# 数据集路径,保持为空,不需要修改 +data_path="" + +# 训练epoch +train_epochs=200 +# 指定训练所使用的npu device卡id +device_id=0 +# 学习率 +learning_rate=0.0003 +# 加载数据进程数 +workers=128 + + +# 参数校验,data_path为必传参数, 其他参数的增删由模型自身决定;此处若新增参数需在上面有定义并赋值 +for para in $* +do + if [[ $para == --world_size* ]];then + world_size=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +# 校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be confing" + exit 1 +fi + +# 校验单卡训练是否指定了device id,分动态分配device id 与手动指定device id,此处不需要修改 +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" + ln -s source dest +elif [ ${device_id} ]; then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + echo "[Error] device id must be confing" + exit 1 +fi + +#################指定训练脚本执行路径################## +# cd到与test文件同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ]; then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + +##################创建日志输出目录,不需要修改################## +ASCEND_DEVICE_ID=${device_id} +if [ -d ${test_path_dir}/output/$ASCEND_DEVICE_ID ];then + rm -rf ${test_path_dir}/output/$ASCEND_DEVICE_ID + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +##################启动训练脚本################## +# 训练开始时间,不需要修改 +start_time=$(date +%s) +# source 环境变量 +source ${test_path_dir}/env.sh +python3.7 train.py \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --workers=128 \ + --learning-rate=3e-4 \ + --mom=0.9 \ + --weight-decay=1.0e-04 \ + --print-freq=1 \ + --device_list='0' \ + --gpu=0 \ + --device='npu' \ + --epochs=200\ + --amp \ + --world-size 1 \ + --data-path=${data_path} \ + --batch-size=1 >${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & +wait + + +##################获取训练数据################## +# 训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# 终端结果打印,不需要修改 +echo "------------------ Final result ------------------" +# 单迭代训练时长 +OverallFPS=`grep -a 'overallFPS after training:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk 'END {print}' | awk -F "overallFPS after training:" '{print $NF}'| awk -F " " '{print $1}'` +# 输出性能FPS,需要模型审视修改 +FPS=`awk 'BEGIN{printf "%.2f\n", '${OverallFPS}'}'` +# 打印,不需要修改 +echo "Final Performance FPS : ${OverallFPS}" +echo "E2E Training Duration sec : $e2e_time" + +# 性能看护结果汇总 +# 训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +# 获取性能数据,不需要修改 +# 吞吐量 +ActualFPS=${FPS} +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# 从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep -a 'Stage0-heatmaps:' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F "Stage0-heatmaps:" '{print $NF}'|awk -F " " '{print $1}'>> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# 最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + + +##################将训练数据存入文件################## 
+# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/detection/CTPN/test/train_full_8p.sh b/PyTorch/contrib/cv/detection/CTPN/test/train_full_8p.sh new file mode 100644 index 0000000000..ac08fb4708 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/test/train_full_8p.sh @@ -0,0 +1,142 @@ +#!/bin/bash + +##################基础配置参数,需要模型审视修改################## +# 必选字段(必须在此处定义的参数): Network batch_size RANK_SIZE +# 网络名称,同目录名称 +Network="CTPN" +# 训练batch_size +batch_size=8 +# 训练使用的npu卡数 +export RANK_SIZE=8 +# 数据集路径,保持为空,不需要修改 +data_path="" + +# 训练epoch +train_epochs=200 +# 指定训练所使用的npu device卡id +device_id=0 +# 学习率 +learning_rate=0.0003 +# 加载数据进程数 +workers=128 + + +# 参数校验,data_path为必传参数, 其他参数的增删由模型自身决定;此处若新增参数需在上面有定义并赋值 +for para in $* +do + if [[ $para == --world_size* ]];then + world_size=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +# 校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be confing" + exit 1 +fi + +# 校验单卡训练是否指定了device id,分动态分配device id 与手动指定device id,此处不需要修改 +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" + ln -s source dest +elif [ ${device_id} ]; then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + echo "[Error] device id must be confing" + exit 1 +fi + +#################指定训练脚本执行路径################## +# cd到与test文件同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ]; then + test_path_dir=${cur_path} + cd .. 
+ cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + +##################创建日志输出目录,不需要修改################## +ASCEND_DEVICE_ID=${device_id} +if [ -d ${test_path_dir}/output/$ASCEND_DEVICE_ID ];then + rm -rf ${test_path_dir}/output/$ASCEND_DEVICE_ID + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +##################启动训练脚本################## +# 训练开始时间,不需要修改 +start_time=$(date +%s) +# source 环境变量 +source ${test_path_dir}/env.sh +python3.7 train.py \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --workers=128 \ + --learning-rate=3e-4 \ + --mom=0.9 \ + --weight-decay=1.0e-04 \ + --print-freq=1 \ + --dist-url='tcp://127.0.0.1:50001' \ + --dist-backend 'hccl' \ + --multiprocessing-distributed \ + --world-size=1 \ + --rank=0 \ + --device='npu' \ + --epochs=200\ + --amp \ + --data-path=${data_path} \ + --batch-size=8 >${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & +wait + + +##################获取训练数据################## +# 训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# 终端结果打印,不需要修改 +echo "------------------ Final result ------------------" +# 单迭代训练时长 +OverallFPS=`grep -a 'overallFPS after training:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk 'END {print}' | awk -F "overallFPS after training:" '{print $NF}'| awk -F " " '{print $1}'` +# 输出性能FPS,需要模型审视修改 +FPS=`awk 'BEGIN{printf "%.2f\n", '${OverallFPS}'}'` +# 打印,不需要修改 +echo "Final Performance FPS : ${OverallFPS}" +echo "E2E Training Duration sec : $e2e_time" + +# 性能看护结果汇总 +# 训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +# 获取性能数据,不需要修改 +# 吞吐量 +ActualFPS=${FPS} +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# 从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep -a 'Stage0-heatmaps:' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F "Stage0-heatmaps:" '{print $NF}'|awk -F " " '{print $1}'>> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# 最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + + +##################将训练数据存入文件################## +# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/detection/CTPN/test/train_performance_1p.sh b/PyTorch/contrib/cv/detection/CTPN/test/train_performance_1p.sh new file mode 100644 index 0000000000..00ece72256 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/test/train_performance_1p.sh @@ -0,0 +1,140 @@ 
+#!/bin/bash + +##################基础配置参数,需要模型审视修改################## +# 必选字段(必须在此处定义的参数): Network batch_size RANK_SIZE +# 网络名称,同目录名称 +Network="CTPN" +# 训练batch_size +batch_size=1 +# 训练使用的npu卡数 +export RANK_SIZE=1 +# 数据集路径,保持为空,不需要修改 +data_path="" + +# 训练epoch +train_epochs=20 +# 指定训练所使用的npu device卡id +device_id=0 +# 学习率 +learning_rate=0.0003 +# 加载数据进程数 +workers=128 + + +# 参数校验,data_path为必传参数, 其他参数的增删由模型自身决定;此处若新增参数需在上面有定义并赋值 +for para in $* +do + if [[ $para == --world_size* ]];then + world_size=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +# 校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be confing" + exit 1 +fi + +# 校验单卡训练是否指定了device id,分动态分配device id 与手动指定device id,此处不需要修改 +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" + ln -s source dest +elif [ ${device_id} ]; then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + echo "[Error] device id must be confing" + exit 1 +fi + +#################指定训练脚本执行路径################## +# cd到与test文件同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ]; then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + +##################创建日志输出目录,不需要修改################## +ASCEND_DEVICE_ID=${device_id} +if [ -d ${test_path_dir}/output/$ASCEND_DEVICE_ID ];then + rm -rf ${test_path_dir}/output/$ASCEND_DEVICE_ID + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +##################启动训练脚本################## +# 训练开始时间,不需要修改 +start_time=$(date +%s) +# source 环境变量 +source ${test_path_dir}/env.sh +python3.7 train.py \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --workers=128 \ + --learning-rate=3e-4 \ + --mom=0.9 \ + --weight-decay=1.0e-04 \ + --print-freq=1 \ + --device_list='0' \ + --gpu=0 \ + --device='npu' \ + --epochs=20\ + --amp \ + --world-size 1 \ + --data-path=${data_path} \ + --batch-size=1 >${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & +wait + + +##################获取训练数据################## +# 训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# 终端结果打印,不需要修改 +echo "------------------ Final result ------------------" +# 单迭代训练时长 +OverallFPS=`grep -a 'overallFPS after training:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk 'END {print}' | awk -F "overallFPS after training:" '{print $NF}'| awk -F " " '{print $1}'` +# 输出性能FPS,需要模型审视修改 +FPS=`awk 'BEGIN{printf "%.2f\n", '${OverallFPS}'}'` +# 打印,不需要修改 +echo "Final Performance FPS : ${OverallFPS}" +echo "E2E Training Duration sec : $e2e_time" + +# 性能看护结果汇总 +# 训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +# 获取性能数据,不需要修改 +# 吞吐量 +ActualFPS=${FPS} +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# 从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep -a 'Stage0-heatmaps:' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F "Stage0-heatmaps:" '{print $NF}'|awk -F " " '{print $1}'>> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# 最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + + 
+##################将训练数据存入文件################## +# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/detection/CTPN/test/train_performance_8p.sh b/PyTorch/contrib/cv/detection/CTPN/test/train_performance_8p.sh new file mode 100644 index 0000000000..610c4f2a52 --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/test/train_performance_8p.sh @@ -0,0 +1,141 @@ +#!/bin/bash + +##################基础配置参数,需要模型审视修改################## +# 必选字段(必须在此处定义的参数): Network batch_size RANK_SIZE +# 网络名称,同目录名称 +Network="CTPN" +# 训练batch_size +batch_size=8 +# 训练使用的npu卡数 +export RANK_SIZE=8 +# 数据集路径,保持为空,不需要修改 +data_path="" + +# 训练epoch +train_epochs=200 +# 指定训练所使用的npu device卡id +device_id=0 +# 学习率 +learning_rate=0.0003 +# 加载数据进程数 +workers=128 + + +# 参数校验,data_path为必传参数, 其他参数的增删由模型自身决定;此处若新增参数需在上面有定义并赋值 +for para in $* +do + if [[ $para == --world_size* ]];then + world_size=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + +# 校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be confing" + exit 1 +fi + +# 校验单卡训练是否指定了device id,分动态分配device id 与手动指定device id,此处不需要修改 +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" + ln -s source dest +elif [ ${device_id} ]; then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + echo "[Error] device id must be confing" + exit 1 +fi + +#################指定训练脚本执行路径################## +# cd到与test文件同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ]; then + test_path_dir=${cur_path} + cd .. 
+ cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + +##################创建日志输出目录,不需要修改################## +ASCEND_DEVICE_ID=${device_id} +if [ -d ${test_path_dir}/output/$ASCEND_DEVICE_ID ];then + rm -rf ${test_path_dir}/output/$ASCEND_DEVICE_ID + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + +##################启动训练脚本################## +# 训练开始时间,不需要修改 +start_time=$(date +%s) +# source 环境变量 +source ${test_path_dir}/env.sh +python3.7 train.py \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --workers=128 \ + --learning-rate=3e-4 \ + --mom=0.9 \ + --weight-decay=1.0e-04 \ + --print-freq=1 \ + --dist-url='tcp://127.0.0.1:50001' \ + --dist-backend 'hccl' \ + --multiprocessing-distributed \ + --world-size=1 \ + --rank=0 \ + --device='npu' \ + --epochs=20\ + --amp \ + --data-path=${data_path} \ + --batch-size=8 >${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & +wait + +##################获取训练数据################## +# 训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# 终端结果打印,不需要修改 +echo "------------------ Final result ------------------" +# 单迭代训练时长 +OverallFPS=`grep -a 'overallFPS after training:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk 'END {print}' | awk -F "overallFPS after training:" '{print $NF}'| awk -F " " '{print $1}'` +# 输出性能FPS,需要模型审视修改 +FPS=`awk 'BEGIN{printf "%.2f\n", '${OverallFPS}'}'` +# 打印,不需要修改 +echo "Final Performance FPS : ${OverallFPS}" +echo "E2E Training Duration sec : $e2e_time" + +# 性能看护结果汇总 +# 训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +# 获取性能数据,不需要修改 +# 吞吐量 +ActualFPS=${FPS} +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# 从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep -a 'Stage0-heatmaps:' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F "Stage0-heatmaps:" '{print $NF}'|awk -F " " '{print $1}'>> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# 最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + + +##################将训练数据存入文件################## +# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/detection/CTPN/train.py b/PyTorch/contrib/cv/detection/CTPN/train.py new file mode 100644 index 0000000000..89f37584ad --- /dev/null +++ b/PyTorch/contrib/cv/detection/CTPN/train.py @@ -0,0 +1,545 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Huawei Technologies 
Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import warnings + +warnings.filterwarnings('ignore') +import argparse +import os +import random +import shutil +import time +import torch +import numpy as np +import apex +from apex import amp +import torch.nn as nn +import torch.nn.parallel +import torch.npu +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from ctpn.ctpn import CTPN_Model, RPN_CLS_Loss, RPN_REGR_Loss +from ctpn.dataset import VOCDataset +from ctpn.dataset import ICDARDataset +from ctpn.dataset import icdarDataset +from ctpn import config + +# torch.multiprocessing.set_sharing_strategy('file_system') +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--data-path', default='./dataset/icdar13', type=str, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://127.0.0.1:50001', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') 
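+# Note: the 8p launch scripts under test/ (train_full_8p.sh, train_performance_8p.sh)
+# override this flag with --dist-backend 'hccl' for NPU runs; 'nccl' here is only the
+# GPU default.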
+parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') +## for ascend 910 +parser.add_argument('--device', default='npu', type=str, help='npu or gpu') +parser.add_argument('--addr', default='10.136.181.115', + type=str, help='master addr') +parser.add_argument('--device_list', default='0,1,2,3,4,5,6,7', + type=str, help='device id list') +parser.add_argument('--amp', default=False, action='store_true', + help='use amp to train the model') +parser.add_argument('--loss-scale', default=64., type=float, + help='loss scale using in amp, default -1 means dynamic') +parser.add_argument('--opt-level', default='O2', type=str, + help='loss scale using in amp, default -1 means dynamic') +parser.add_argument('--prof', default=False, action='store_true', + help='use profiling to evaluate the performance of model') +parser.add_argument('--warm_up_epochs', default=5, type=int, + help='warm up') + + +def device_id_to_process_device_map(device_list): + devices = device_list.split(",") + devices = [int(x) for x in devices] + devices.sort() + + process_device_map = dict() + for process_id, device_id in enumerate(devices): + process_device_map[process_id] = device_id + + return process_device_map + + +def main(): + args = parser.parse_args() + print(args.device_list) + + os.environ['MASTER_ADDR'] = args.addr + os.environ['MASTER_PORT'] = '29688' + make_dir('./output_models/') + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. 
This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + args.process_device_map = device_id_to_process_device_map(args.device_list) + + if args.device == 'npu': + ngpus_per_node = len(args.process_device_map) + else: + if args.distributed: + ngpus_per_node = torch.cuda.device_count() + else: + ngpus_per_node = 1 + print('ngpus_per_node:', ngpus_per_node) + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, + args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + args.gpu = args.process_device_map[gpu] + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + + if args.device == 'npu': + dist.init_process_group(backend=args.dist_backend, # init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + else: + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + print("=> creating model") + model = CTPN_Model() + critetion_cls = RPN_CLS_Loss('cpu') + critetion_regr = RPN_REGR_Loss('cpu') + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single deviceF scope, otherwise, + # DistributedDataParallel will use all available devices. 
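+        # Each spawned worker pins one NPU/GPU, then divides batch_size by the global
+        # world size and workers by the number of devices on this node, so every
+        # process trains on its own shard of the configured global batch.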
+ if args.gpu is not None: + if args.device == 'npu': + loc = 'npu:{}'.format(args.gpu) + torch.npu.set_device(loc) + model = model.to(loc) + else: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / args.world_size) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + else: + if args.device == 'npu': + loc = 'npu:{}'.format(args.gpu) + model = model.to(loc) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + elif args.gpu is not None: + if args.device == 'npu': + loc = 'npu:{}'.format(args.gpu) + torch.npu.set_device(args.gpu) + model = model.to(loc) + else: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.device == 'npu': + loc = 'npu:{}'.format(args.gpu) + else: + print("before : model = torch.nn.DataParallel(model).cuda()") + + # define loss function (criterion) and optimizer + optimizer = apex.optimizers.NpuFusedAdamW(model.parameters(), args.lr, weight_decay=args.weight_decay) + if args.amp: + model, optimizer = amp.initialize( + model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale) + + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + if args.pretrained: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False, + find_unused_parameters=True) + else: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False) + else: + model = torch.nn.parallel.DistributedDataParallel(model) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.device == 'npu': + if args.gpu is not None: + loc = 'npu:{}'.format(args.gpu) + model = torch.nn.DataParallel(model).to(loc) + else: + model = torch.nn.DataParallel(model).cuda() + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. 
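+                # torch.load(..., map_location=loc) remaps the checkpoint's tensors
+                # onto this process's device instead of the device they were saved from.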
+                if args.device == 'npu':
+                    loc = 'npu:{}'.format(args.gpu)
+                else:
+                    loc = 'cuda:{}'.format(args.gpu)
+                checkpoint = torch.load(args.resume, map_location=loc)
+            args.start_epoch = checkpoint['epoch']
+            model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            if args.amp:
+                amp.load_state_dict(checkpoint['amp'])
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    cudnn.benchmark = True
+
+    # Data loading code: both directories are resolved from --data_path
+    img_dir = os.path.join(args.data_path, 'Challenge2_Training_Task12_Images/')
+    label_dir = os.path.join(args.data_path, 'Challenge2_Training_Task1_GT/')
+    train_dataset = icdarDataset(img_dir, label_dir)
+
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_dataset)
+    else:
+        train_sampler = None
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=(
+            train_sampler is None),
+        num_workers=args.workers, pin_memory=False, sampler=train_sampler, drop_last=True)
+
+    if args.prof:
+        profiling(train_loader, model, critetion_regr, critetion_cls, optimizer, args)
+        return
+
+    start_time = time.time()
+    all_fps = []
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+
+        adjust_learning_rate(optimizer, epoch, args)
+
+        # train for one epoch
+        cur_fps = train(train_loader, model, critetion_regr, critetion_cls, optimizer, epoch, args, ngpus_per_node)
+        if cur_fps is not None:
+            all_fps.append(cur_fps)
+
+        if args.device == 'npu' and args.gpu == 0 and epoch == 199:
+            print("Complete 200 epoch training, take time:{}h".format(round((time.time() - start_time) / 3600.0, 2)))
+
+        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
+                                                    and args.rank % ngpus_per_node == 0):
+
+            ############## npu modify begin #############
+            if args.amp:
+                if (epoch + 1) % 5 == 0:
+                    save_checkpoint({
+                        'epoch': epoch + 1,
+                        'arch': 'ctpn',
+                        'state_dict': model.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'amp': amp.state_dict(),
+                    }, filename=f'output_models/checkpoint-{epoch + 1}.pth.tar')
+            else:
+                if (epoch + 1) % 5 == 0:
+                    save_checkpoint({
+                        'epoch': epoch + 1,
+                        'arch': 'ctpn',
+                        'state_dict': model.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                    }, filename=f'output_models/checkpoint-{epoch + 1}.pth.tar')
+    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
+                                                and args.rank % ngpus_per_node == 0):
+        if all_fps:
+            print('overallFPS after training:', np.mean(all_fps))
+    ############## npu modify end #############
+
+
+def profiling(data_loader, model, critetion_regr, critetion_cls, optimizer, args):
+    # switch to train mode
+    model.train()
+
+    def update(model, images, clss, regrs, optimizer):
+        out_cls, out_regr = model(images)
+        loss_regr = critetion_regr(out_regr, regrs)
+        loss_cls = critetion_cls(out_cls, clss)
+        loss = loss_cls.to(loc, non_blocking=True) + loss_regr.to(loc, non_blocking=True)
+        # clear stale gradients before the backward pass, then step
+        optimizer.zero_grad()
+        if args.amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+        optimizer.step()
+
+    for step, (images, clss, regrs) in enumerate(data_loader):
+        if args.device == 'npu':
+            loc = 'npu:{}'.format(args.gpu)
+            images = images.to(loc, non_blocking=True)
+            clss = clss.to(loc, non_blocking=True)
+            regrs = regrs.to(loc, non_blocking=True)
+        else:
+            # define loc on the GPU path as well, since update() moves the loss to loc
+            loc = 'cuda:{}'.format(args.gpu)
+            images = images.cuda(args.gpu, non_blocking=True)
+            clss = clss.cuda(args.gpu, non_blocking=True)
+            regrs = regrs.cuda(args.gpu, non_blocking=True)
+
+        if step < 5:
+            update(model, images, clss, regrs, optimizer)
+        else:
+            if args.device == 'npu':
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    update(model, images, clss, regrs, optimizer)
+            else:
+                with torch.autograd.profiler.profile(use_cuda=True) as prof:
+                    update(model, images, clss, regrs, optimizer)
+            break
+    if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank == 0):
+        prof.export_chrome_trace("output.prof")
+
+
+def train(train_loader, model, critetion_regr, critetion_cls, optimizer, epoch, args, ngpus_per_node):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses_cls = AverageMeter('LossCls', ':.4e')
+    losses_regr = AverageMeter('LossRegr', ':.4e')
+    progress = ProgressMeter(
+        len(train_loader),
+        [batch_time, data_time, losses_cls, losses_regr],
+        prefix="Epoch: [{}]".format(epoch))
+
+    # switch to train mode
+    model.train()
+
+    end = time.time()
+    for i, (images, clss, regrs) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        if args.device == 'npu':
+            # torch.npu.global_step_inc()
+            loc = 'npu:{}'.format(args.gpu)
+            images = images.to(loc, non_blocking=True)
+            clss = clss.to(loc, non_blocking=True)
+            regrs = regrs.to(loc, non_blocking=True)
+        else:
+            # mirror the NPU branch so the GPU path also has its tensors (and loc) on device
+            loc = 'cuda:{}'.format(args.gpu)
+            images = images.cuda(args.gpu, non_blocking=True)
+            clss = clss.cuda(args.gpu, non_blocking=True)
+            regrs = regrs.cuda(args.gpu, non_blocking=True)
+
+        # compute output
+        out_cls, out_regr = model(images)
+        loss_regr = critetion_regr(out_regr, regrs)
+        loss_cls = critetion_cls(out_cls, clss)
+        loss = loss_cls.to(loc, non_blocking=True) + loss_regr.to(loc, non_blocking=True)
+
+        # measure accuracy and record loss
+        losses_regr.update(loss_regr.item(), images.size(0))
+        losses_cls.update(loss_cls.item(), images.size(0))
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        if args.amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+
+        optimizer.step()
+        if args.device == 'npu':
+            torch.npu.synchronize()
+
+        # measure elapsed time
+        cost_time = time.time() - end
+        batch_time.update(cost_time)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
+                                                        and args.rank % ngpus_per_node == 0):
+                progress.display(i)
+
+    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
+                                                and args.rank % ngpus_per_node == 0):
+        print("[npu id:", args.gpu, "]", "batch_size:", args.world_size * args.batch_size,
+              'Time: {:.3f}'.format(batch_time.avg), '* FPS@all {:.3f}'.format(
+                  args.batch_size * args.world_size / batch_time.avg))
+        if i >= 10:
+            cur_fps = args.batch_size * args.world_size / batch_time.avg
+            return cur_fps
+        else:
+            return None
+    else:
+        return None
+
+
+def save_checkpoint(state, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=':f', start_count_index=2):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+        self.start_count_index = start_count_index
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        if self.count == 0:
+            # remember the first batch size so the warm-up skip below is counted in samples
+            self.N = n
+
+        self.val = val
+        self.count += n
+        if self.count > (self.start_count_index * self.N):
+            self.sum += val * n
+            self.avg = self.sum / (self.count - self.start_count_index * self.N)
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
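+# Note: with start_count_index=2, AverageMeter leaves the first 2*N samples out of
+# sum/avg, so batch_time.avg (and the FPS@all value derived from it in train())
+# reflects steady-state throughput rather than the slower warm-up iterations.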
+class ProgressMeter(object):
+
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Sets the learning rate: linear warm-up for warm_up_epochs, then cosine decay"""
+    # lr = args.lr * (0.1 ** (epoch // (args.epochs//3 - 3)))
+
+    if args.warm_up_epochs > 0 and epoch < args.warm_up_epochs:
+        lr = args.lr * ((epoch + 1) / (args.warm_up_epochs + 1))
+    else:
+        alpha = 0
+        cosine_decay = 0.5 * (
+            1 + np.cos(np.pi * (epoch - args.warm_up_epochs) / (args.epochs - args.warm_up_epochs)))
+        decayed = (1 - alpha) * cosine_decay + alpha
+        lr = args.lr * decayed
+
+    print("=> Epoch[%d] Setting lr: %.8f" % (epoch, lr))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def make_dir(path):
+    if os.path.exists(path):
+        shutil.rmtree(path)
+    os.mkdir(path)
+
+
+if __name__ == '__main__':
+    main()
--
Gitee

From e6fa505e65c22296e3bafb324e8074f4dc393634 Mon Sep 17 00:00:00 2001
From: lgq1997
Date: Fri, 1 Jul 2022 10:27:40 +0800
Subject: [PATCH 3/3] merge commits

---
 PyTorch/contrib/cv/detection/CTPN/test/output/.keep | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 PyTorch/contrib/cv/detection/CTPN/test/output/.keep

diff --git a/PyTorch/contrib/cv/detection/CTPN/test/output/.keep b/PyTorch/contrib/cv/detection/CTPN/test/output/.keep
deleted file mode 100644
index e69de29bb2..0000000000
--
Gitee