diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/LICENSE b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..cc87e8683f8accf92fb441738e981d6ab8ce7536 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 Megvii, Base Detection + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX.patch b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX.patch new file mode 100644 index 0000000000000000000000000000000000000000..2b817bb72e4dd35ce246b0343ef41b637add0713 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX.patch @@ -0,0 +1,152 @@ +diff --git a/configs/yolox/yolox_s_8x8_300e_coco.py b/configs/yolox/yolox_s_8x8_300e_coco.py +index cc73051..db1551f 100644 +--- a/configs/yolox/yolox_s_8x8_300e_coco.py ++++ b/configs/yolox/yolox_s_8x8_300e_coco.py +@@ -19,7 +19,7 @@ model = dict( + train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), + # In order to align the source code, the threshold of the val phase is + # 0.01, and the threshold of the test phase is 0.001. +- test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) ++ test_cfg=dict(score_thr=0.001, nms=dict(type='nms', iou_threshold=0.65))) + + # dataset settings + data_root = 'data/coco/' +diff --git a/mmdet/models/dense_heads/yolox_head.py b/mmdet/models/dense_heads/yolox_head.py +index a1811c9..5adcf4f 100644 +--- a/mmdet/models/dense_heads/yolox_head.py ++++ b/mmdet/models/dense_heads/yolox_head.py +@@ -17,6 +17,54 @@ from .base_dense_head import BaseDenseHead + from .dense_test_mixins import BBoxTestMixin + + ++class BatchNMSOp(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size): ++ """ ++ boxes (torch.Tensor): boxes in shape (batch, N, C, 4). ++ scores (torch.Tensor): scores in shape (batch, N, C). ++ return: ++ nmsed_boxes: (batch, N, 4) ++ nmsed_scores: (batch, N) ++ nmsed_classes: (batch, N) ++ nmsed_num: (batch,) ++ """ ++ ++ # Phony implementation for onnx export ++ nmsed_boxes = bboxes[:, :max_total_size, 0, :] ++ nmsed_scores = scores[:, :max_total_size, 0] ++ nmsed_classes = torch.arange(max_total_size, dtype=torch.long) ++ nmsed_num = torch.Tensor([max_total_size]) ++ ++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num ++ ++ @staticmethod ++ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size): ++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS', ++ bboxes, scores, score_threshold_f=score_thr, ++ iou_threshold_f=iou_thr, ++ max_size_per_class_i=max_size_p_class, ++ max_total_size_i=max_t_size, outputs=4) ++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num ++ ++ ++def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size): ++ if bboxes.dtype == torch.float32: ++ bboxes = bboxes.half() ++ scores = scores.half() ++ bboxes = bboxes.reshape(-1, bboxes.shape[1].numpy(), 1, 4) ++ scores = scores.reshape(-1, scores.shape[1].numpy(), 80) ++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores, ++ score_threshold, iou_threshold, ++ max_size_per_class, max_total_size) ++ nmsed_boxes = nmsed_boxes.float() ++ nmsed_scores = nmsed_scores.float() ++ nmsed_classes = nmsed_classes.long() ++ dets = torch.cat((nmsed_boxes, nmsed_scores.unsqueeze(-1)), -1) ++ labels = nmsed_classes.reshape((-1, max_total_size)) ++ return dets, labels ++ ++ + @HEADS.register_module() + class YOLOXHead(BaseDenseHead, BBoxTestMixin): + """YOLOXHead head used in `YOLOX `_. +@@ -248,9 +296,8 @@ class YOLOXHead(BaseDenseHead, BBoxTestMixin): + """ + assert len(cls_scores) == len(bbox_preds) == len(objectnesses) + cfg = self.test_cfg if cfg is None else cfg +- scale_factors = [img_meta['scale_factor'] for img_meta in img_metas] + +- num_imgs = len(img_metas) ++ num_imgs = cls_scores[0].shape[0] + featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, +@@ -280,20 +327,11 @@ class YOLOXHead(BaseDenseHead, BBoxTestMixin): + + flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds) + +- if rescale: +- flatten_bboxes[..., :4] /= flatten_bboxes.new_tensor( +- scale_factors).unsqueeze(1) +- +- result_list = [] +- for img_id in range(len(img_metas)): +- cls_scores = flatten_cls_scores[img_id] +- score_factor = flatten_objectness[img_id] +- bboxes = flatten_bboxes[img_id] +- +- result_list.append( +- self._bboxes_nms(cls_scores, bboxes, score_factor, cfg)) +- +- return result_list ++ score_factors = flatten_objectness.unsqueeze(2).expand(-1, flatten_cls_scores.shape[1], ++ flatten_cls_scores.shape[2]) ++ scores = torch.mul(score_factors, flatten_cls_scores) ++ max_size = 200 ++ return batch_nms_op(flatten_bboxes, scores, cfg.score_thr, cfg.nms.iou_threshold, max_size, max_size) + + def _bbox_decode(self, priors, bbox_preds): + xys = (bbox_preds[..., :2] * priors[:, 2:]) + priors[:, :2] +diff --git a/mmdet/models/detectors/yolox.py b/mmdet/models/detectors/yolox.py +index 2aba93f..4b35710 100644 +--- a/mmdet/models/detectors/yolox.py ++++ b/mmdet/models/detectors/yolox.py +@@ -132,3 +132,13 @@ class YOLOX(SingleStageDetector): + + input_size = (tensor[0].item(), tensor[1].item()) + return input_size ++ ++ def onnx_export(self, img, img_metas=None, **kwargs): ++ return self.simple_test(img, img_metas, **kwargs) ++ ++ def simple_test(self, img, img_metas, **kwargs): ++ if torch.onnx.is_in_onnx_export(): ++ feat = self.extract_feat(img) ++ return self.bbox_head.simple_test(feat, img_metas, **kwargs) ++ else: ++ return super().simple_test(img, img_metas, **kwargs) +diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py +index 5c786f8..2d0b89d 100644 +--- a/tools/deployment/pytorch2onnx.py ++++ b/tools/deployment/pytorch2onnx.py +@@ -97,7 +97,8 @@ def pytorch2onnx(model, + do_constant_folding=True, + verbose=show, + opset_version=opset_version, +- dynamic_axes=dynamic_axes) ++ dynamic_axes=dynamic_axes, ++ enable_onnx_checker=False) + + model.forward = origin_forward + +@@ -215,8 +216,8 @@ def parse_normalize_cfg(test_pipeline): + break + assert transforms is not None, 'Failed to find `transforms`' + norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize'] +- assert len(norm_config_li) == 1, '`norm_config` should only have one' +- norm_config = norm_config_li[0] ++ assert len(norm_config_li) <= 1, '`norm_config` should less than or equal to one' ++ norm_config = norm_config_li[0] if len(norm_config_li) > 0 else dict(mean=0.0, std=1.0) + return norm_config + + diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX_postprocess.py b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..29003ed851832ef6eb707b4258022e994daef334 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX_postprocess.py @@ -0,0 +1,54 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mmcv +import numpy as np +import argparse +from mmdet.core import bbox2result +from mmdet.datasets import build_dataset + +ann_file = '/annotations/instances_val2017.json' +img_prefix = '/val2017/' + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_path', default="/opt/npu/coco") + parser.add_argument('--model_config', default="mmdetection/configs/yolox/yolox_s_8x8_300e_coco.py") + parser.add_argument('--bin_data_path', default="result/dumpOutput_device0/") + parser.add_argument('--meta_info_path', default="yolox_meta.info") + parser.add_argument('--num_classes', default=81) + + args = parser.parse_args() + + cfg = mmcv.Config.fromfile(args.model_config) + cfg.data.test.test_mode = True + cfg.data.test.ann_file = args.dataset_path + ann_file + cfg.data.test.img_prefix = args.dataset_path + img_prefix + dataset = build_dataset(cfg.data.test) + + num_classes = int(args.num_classes) + outputs = [] + with open(args.meta_info_path, "r") as fp: + for line in fp: + _, file_path, scalar = line.split() + scalar = float(scalar) + file_name = file_path.split("/")[1].replace(".bin", "") + result_list = [ + np.fromfile("{0}{1}_{2}.bin".format(args.bin_data_path, file_name, 1), dtype=np.float32).reshape(-1, 5), + np.fromfile("{0}{1}_{2}.bin".format(args.bin_data_path, file_name, 2), dtype=np.int64)] + result_list[0][..., :4] /= scalar + bbox_result = bbox2result(result_list[0], result_list[1], num_classes) + outputs.append(bbox_result) + eval_kwargs = {'metric': ['bbox']} + dataset.evaluate(outputs, **eval_kwargs) diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX_preprocess.py b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0d3d8715972c6f4d2284ad9703045968aae0b9 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/YOLOX_preprocess.py @@ -0,0 +1,72 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import numpy as np +import cv2 +import mmcv +import torch +import pickle as pk +import multiprocessing + +flags = None + + +def gen_input_bin(file_batches, batch): + i = 0 + for file in file_batches[batch]: + i = i + 1 + print("batch", batch, file, "===", i) + + image = mmcv.imread(os.path.join(flags.image_src_path, file)) + # ori_shape = image.shape + image, scalar = mmcv.imrescale(image, (flags.model_input_height, flags.model_input_width), return_scale=True) + # img_shape = image.shape + image = mmcv.impad(image, shape=(flags.model_input_height, flags.model_input_width), + pad_val=(flags.model_pad_val, flags.model_pad_val, flags.model_pad_val)) + + image = image.transpose(2, 0, 1) + image = image.astype(np.float32) + image.tofile(os.path.join(flags.bin_file_path, file.split('.')[0] + ".bin")) + image_meta = {'scalar': scalar} + with open(os.path.join(flags.meta_file_path, file.split('.')[0] + ".pk"), "wb") as fp: + pk.dump(image_meta, fp) + + +def preprocess(): + files = os.listdir(flags.image_src_path) + file_batches = [files[i:i + 100] for i in range(0, 5000, 100) if files[i:i + 100] != []] + thread_pool = multiprocessing.Pool(len(file_batches)) + for batch in range(len(file_batches)): + thread_pool.apply_async(gen_input_bin, args=(file_batches, batch)) + thread_pool.close() + thread_pool.join() + print("in thread, except will not report! please ensure bin files generated.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='preprocess of YOLOX PyTorch model') + parser.add_argument("--image_src_path", default="/opt/npu/coco/val2017", help='image of dataset') + parser.add_argument("--bin_file_path", default="val2017_bin", help='Preprocessed image buffer') + parser.add_argument("--meta_file_path", default="val2017_bin_meta", help='Get image meta') + parser.add_argument("--model_input_height", default=640, type=int, help='input tensor height') + parser.add_argument("--model_input_width", default=640, type=int, help='input tensor width') + parser.add_argument("--model_pad_val", default=114, type=int, help='image pad value') + flags = parser.parse_args() + if not os.path.exists(flags.bin_file_path): + os.makedirs(flags.bin_file_path) + if not os.path.exists(flags.meta_file_path): + os.makedirs(flags.meta_file_path) + preprocess() diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/gen_dataset_info.py b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/gen_dataset_info.py new file mode 100644 index 0000000000000000000000000000000000000000..dee00755558872b2746c53c9e96f24394dbbd417 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/gen_dataset_info.py @@ -0,0 +1,53 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import mmcv +from mmdet.datasets import build_dataset +import pickle as pk + +ann_file = '/annotations/instances_val2017.json' +img_prefix = '/val2017/' + +if __name__ == '__main__': + image_src_path = sys.argv[1] + config_path = sys.argv[2] + bin_path = sys.argv[3] + meta_path = sys.argv[4] + info_name = sys.argv[5] + info_meta_name = sys.argv[6] + width = int(sys.argv[7]) + height = int(sys.argv[8]) + + cfg = mmcv.Config.fromfile(config_path) + cfg.data.test.ann_file = image_src_path + ann_file + cfg.data.test.img_prefix = image_src_path + img_prefix + cfg.data.test.test_mode = True + + dataset = build_dataset(cfg.data.test) + + with open(info_name, "w") as fp1, open(info_meta_name, "w") as fp2: + for idx in range(5000): + img_id = dataset.img_ids[idx] + fp1.write("{} {}/{:0>12d}.bin {} {}\n".format(idx, bin_path, img_id, width, height)) + fp_meta = open("%s/%012d.pk" % (meta_path, img_id), "rb") + meta = pk.load(fp_meta) + fp_meta.close() + fp2.write("{} {}/{:0>12d}.bin {}\n".format( + idx, + meta_path, + img_id, + meta['scalar'] + )) + print("Get info done!") diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/modelzoo_level.txt b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/modelzoo_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..3901da7fbaa158ca7d805621545eabdadc18999f --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/modelzoo_level.txt @@ -0,0 +1,4 @@ +FuncStatus:OK +PrecisionStatus:OK +ModelConvert:OK +PerfStatus=OK diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/readme.md b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..645cc073d17cc9c6290570fb32fdfdc09ecb5f4f --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/readme.md @@ -0,0 +1,51 @@ +### YOLOX模型PyTorch离线推理指导 + +### 1. 环境准备 + +1. 安装依赖 + +```bash +pip install -r requirements.txt +``` + +2. 获取,修改与安装开源模型代码 + +``` +git clone -b master https://github.com/open-mmlab/mmdetection.git +cd mmdetection +git reset 6b87ac22b8d9dea8cc28b9ce84909e6c311e6268 --hard + +pip install -v -e . # or python3 setup.py develop +pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.7.0/index.html +patch -p1 < ../YOLOX.patch +cd .. +``` + +3. 将权重文件[yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth)放到当前工作目录。 + +4. 数据集 + + 获取COCO数据集,并重命名为COCO,放到/root/datasets目录 + +5. [获取benchmark工具](https://gitee.com/ascend/cann-benchmark/tree/master/infer) + + 将benchmark.x86_64或benchmark.aarch64放到当前工作目录 + +### 2. 离线推理 + +710上执行,执行时使npu-smi info查看设备状态,确保device空闲 + +```bash +bash test/pth2om.sh --batch_size=1 +bash test/eval_acc_perf.sh --datasets_path=/root/datasets --batch_size=1 +``` + +**评测结果:** + +| 模型 | pth精度 | 710离线推理精度 | 性能基准 | 710性能 | +| ----------- | --------- | --------------- | --------- | ------- | +| YOLOX bs1 | box AP:50.9 | box AP:51.0 | fps 11.828 | fps 27.697 | +| YOLOX bs16 | box AP:50.9 | box AP:51.0 | fps 14.480 | fps 38.069 | + + + diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/requirements.txt b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..417d5e02a3b432ff1a034314f6975c4072f6f479 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/requirements.txt @@ -0,0 +1,7 @@ +torch==1.7.0 +torchvision==0.8.0 +onnx +opencv-python +sympy +cython +numpy \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/eval_acc_perf.sh b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/eval_acc_perf.sh new file mode 100644 index 0000000000000000000000000000000000000000..97f3b0e46b4de415c4430def9cc38a490a7ac257 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/eval_acc_perf.sh @@ -0,0 +1,68 @@ +#!/bin/bash + +set -eu + +datasets_path="/root/datasets" +batch_size=1 + +for para in $* +do + if [[ $para == --datasets_path* ]]; then + datasets_path=`echo ${para#*=}` + fi + if [[ $para == --batch_size* ]]; then + batch_size=`echo ${para#*=}` + fi +done + + +arch=`uname -m` + +rm -rf val2017_bin +rm -rf val2017_bin_meta +python YOLOX_preprocess.py --image_src_path ${datasets_path}/coco/val2017 + +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +python gen_dataset_info.py \ +${datasets_path}/coco \ +mmdetection/configs/yolox/yolox_s_8x8_300e_coco.py \ +val2017_bin val2017_bin_meta \ +yolox.info yolox_meta.info \ +640 640 + +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +rm -rf result + +./benchmark.${arch} -model_type=vision -om_path=yolox.om -device_id=0 -batch_size=${batch_size} \ +-input_text_path=yolox.info -input_width=640 -input_height=640 -useDvpp=false -output_binary=true + +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +python YOLOX_postprocess.py --dataset_path ${datasets_path}/coco +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + + +echo "====performance data====" +python test/parse.py result/perf_vision_batchsize_${batch_size}_device_0.txt + +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi +echo "success" + diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/parse.py b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..4e51797a87ad12ff1d7dc6c400e6a3d5dee540eb --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/parse.py @@ -0,0 +1,32 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import json +import re + +if __name__ == '__main__': + if sys.argv[1].endswith('.json'): + result_json = sys.argv[1] + with open(result_json, 'r') as f: + content = f.read() + tops = [i.get('value') for i in json.loads(content).get('value') if 'Top' in i.get('key')] + print('om {} top1:{} top5:{}'.format(result_json.split('_')[1].split('.')[0], tops[0], tops[4])) + elif sys.argv[1].endswith('.txt'): + result_txt = sys.argv[1] + with open(result_txt, 'r') as f: + content = f.read() + txt_data_list = [i.strip() for i in re.findall(r':(.*?),', content.replace('\n', ',') + ',')] + fps = float(txt_data_list[7]) + print('710 bs{} fps:{}'.format(result_txt.split('_')[3], fps)) diff --git a/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/pth2om.sh b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/pth2om.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7eece1b2736e18bcb5d1426afadb5bc606596d4 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/detection/YOLOX-mmdetection/test/pth2om.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +set -eu +batch_size=1 + +for para in $* +do + if [[ $para == --batch_size* ]]; then + batch_size=`echo ${para#*=}` + fi +done + + +cd mmdetection +rm -f ../yolox.onnx + +python tools/deployment/pytorch2onnx.py \ +configs/yolox/yolox_x_8x8_300e_coco.py \ +../yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth \ +--output-file \ +../yolox.onnx \ +--shape 640 640 --dynamic-export + +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +cd .. +rm -f yolox.om +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +atc --framework=5 --model=yolox.onnx --output=yolox --input_format=NCHW \ +--input_shape="input:$batch_size,3,640,640" --log=error --soc_version=Ascend710 + +if [ -f "yolox.om" ]; then + echo "success" +else + echo "fail!" +fi