diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..46e2bac2a3f49c526ae327a71ac65896cbed1283
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md
@@ -0,0 +1,277 @@
+# Faster R-CNN ResNet50 Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Get the Source Code](#section4622531142816)
+  - [Prepare the Dataset](#section183221994411)
+  - [Model Inference](#section741711594517)
+
+- [Model Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+- [Supporting Environment](#ZH-CN_TOPIC_0000001126121892)
+
+ ******
+
+
+
+# Overview
+
+Faster R-CNN, proposed by Ren et al. in 2015, integrates feature extraction, proposal generation, bounding box regression (rect refine), and classification into a single network, which substantially improves overall performance, most notably detection speed.
+
+- Reference implementation:
+
+ ```
+ url=https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn
+ branch=master
+ commit_id=a21eb25535f31634cef332b09fc27d28956fb24b
+ model_name=faster_rcnn_r50_fpn
+ ```
+
+
+
+
+## Input and Output Data
+
+- Input data
+
+  | Input  | Data Type | Shape                       | Format |
+  | :----: | :-------: | :-------------------------: | :----: |
+  | input  | RGB_FP32  | batchsize x 3 x 1216 x 1216 | NCHW   |
+
+
+- Output data
+
+  | Output | Shape | Data Type | Format |
+  | :----: | :---: | :-------: | :----: |
+  | boxes  | 100x5 | FLOAT32   | ND     |
+  | labels | 100   | INT32     | ND     |
+
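+Each row of ```boxes``` is one detection in the form (x1, y1, x2, y2, score). As a minimal, hedged sketch (assuming the per-image bin dumps written by sample.py in the "Model Inference" section below; the file names here are hypothetical), the outputs can be read back like this:
+
+```python
+import numpy as np
+
+# Hypothetical result files for one image, as dumped by sample.py.
+boxes = np.fromfile("results/000000515350_0.bin", dtype=np.float32).reshape(100, 5)
+# Depending on how the dump was written, labels may be int64 rather than int32.
+labels = np.fromfile("results/000000515350_1.bin", dtype=np.int64).reshape(100)
+
+# Each box row is (x1, y1, x2, y2, score); keep confident detections only.
+keep = boxes[:, 4] > 0.05
+print(boxes[keep], labels[keep])
+```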
+
+# Inference Environment Setup
+
+- This model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+| Dependency            | Version     |
+|-----------------------|-------------|
+| CANN                  | 7.0.RC1.3   |
+| Python                | 3.9         |
+| PyTorch               | 2.0.1       |
+| torchvision           | 0.15.2      |
+| Ascend-cann-torch-aie | >= 7.0.0    |
+| Ascend-cann-aie       | >= 7.0.0    |
+| Chip                  | Ascend310P3 |
+
+# Quick Start
+
+## Get the Source Code
+
+1. Get the code of this repository.
+    ```bash
+    git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+    cd ./ModelZoo-PyTorch/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x
+    ```
+ ```
+
+
+
+    File description
+    ```
+    FasterRcnn
+    ├── README.md          # This document
+    ├── sample.py          # Model inference script
+    ├── mmdetection.patch  # Patch file for the mmdetection repository
+    ├── requirements.txt   # Dependency list
+    ├── trace.py           # Model tracing script
+    └── mmdet_ops          # Custom operator registration directory
+    ```
+
+
+
+2. Install dependencies.
+ ```bash
+ pip3 install -r requirements.txt
+
+    # Install mmpycocotools
+ pip3 install mmpycocotools==12.0.3
+ ```
+
+
+
+3. Get the model source code and install its dependencies.
+
+ ```bash
+ git clone https://github.com/open-mmlab/mmdetection.git
+ cd mmdetection
+ git reset --hard a21eb25535f31634cef332b09fc27d28956fb24b
+ pip3 install -v -e .
+ ```
+
+
+
+4. Modify the mmdetection source code.
+
+    Before tracing with mmdetection (v2.8.0), the source code needs a few changes to adapt it to the Ascend NPU. From inside the mmdetection directory:
+
+    ```bash
+    patch -p1 < ../mmdetection.patch
+    ```
+
+5. Get the pre- and post-processing scripts.
+
+    ```bash
+    cd ..
+    cp ModelZoo-PyTorch/ACL_PyTorch/contrib/cv/detection/Faster_R-CNN_ResNet50/* ./
+    ```
+## Prepare the Dataset
+
+1. Get the original dataset and validation set.
+
+    This model is tested on the 5,000-image coco2017 validation set from the [COCO website](https://cocodataset.org/#download). The images and annotations are stored in ```val2017/``` and ```annotations/instances_val2017.json``` respectively.
+
+ ```bash
+ wget http://images.cocodataset.org/zips/val2017.zip --no-check-certificate
+ unzip -qo val2017.zip
+
+ wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip --no-check-certificate
+ unzip -qo annotations_trainval2017.zip
+ ```
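+
+    After extraction, a quick sanity check (a hedged example; the paths assume the commands above were run in the current directory):
+
+    ```python
+    import json, os
+
+    print(len(os.listdir("val2017")))  # expect 5000 images
+    with open("annotations/instances_val2017.json") as f:
+        print(len(json.load(f)["images"]))  # expect 5000 entries
+    ```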
+
+2. Data preprocessing.
+    Convert the original dataset into the input data of the model.
+
+    The original images (.jpeg) are converted to binary files (.bin). The conversion follows the mmdetection preprocessing pipeline to obtain the best accuracy. Taking the coco_2017 dataset as an example, each image is resized and normalized with mean/std values, then written out as a binary file. A hedged sketch of this per-image transform is shown after the parameter list below.
+
+    Run the mmdetection_coco_preprocess.py script to complete the preprocessing.
+
+    ```bash
+    python3 mmdetection_coco_preprocess.py --image_folder_path val2017/ --bin_folder_path val2017_bin
+    ```
+
+    Parameter description:
+    - --image_folder_path: directory of the image dataset.
+    - --bin_folder_path: output directory for the binary files.
+
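+    As referenced above, here is a minimal sketch of what the per-image preprocessing does, assuming the standard mmdetection pipeline; the mean/std values are the defaults also used by trace.py, and 1216 matches the model input shape. mmdetection_coco_preprocess.py remains the authoritative implementation.
+
+    ```python
+    import numpy as np
+    import cv2
+
+    def preprocess_to_bin(img_path, bin_path, size=1216):
+        img = cv2.imread(img_path)                       # BGR, HWC, uint8
+        h, w = img.shape[:2]
+        scale = min(size / h, size / w)                  # keep aspect ratio
+        img = cv2.resize(img, (int(w * scale), int(h * scale)))
+        img = img[..., ::-1].astype(np.float32)          # BGR -> RGB
+        mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+        std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+        img = (img - mean) / std                         # per-channel normalization
+        padded = np.zeros((size, size, 3), dtype=np.float32)
+        padded[:img.shape[0], :img.shape[1], :] = img    # pad to size x size
+        padded.transpose(2, 0, 1).tofile(bin_path)       # HWC -> CHW, dump as .bin
+
+    preprocess_to_bin("val2017/000000515350.jpg", "val2017_bin/000000515350.bin")
+    ```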
+
+3. Generate the info file for the JPG images.
+
+
+    Post-processing requires an info file of the dataset's .jpg images. The get_info.py script takes the image files obtained above and generates the info file of the image dataset.
+
+    Run the get_info.py script.
+
+    ```bash
+    python3 get_info.py jpg ./val2017/ coco2017_jpg.info
+    ```
+    Parameter description:
+
+    - The first argument is the file format of the dataset.
+    - The second argument is the **relative path** of the coco image data.
+    - The third argument is the save path of the generated dataset info file.
+
+
+    After it succeeds, ```coco2017_jpg.info``` is generated in the current directory.
+
+## Model Inference
+
+1. Load the model.
+
+
+    a. Get the weight file.
+
+
+ ```
+ wget http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --no-check-certificate
+ ```
+
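+    Note: the trace command in step b loads the custom operator library ```libmmdet_ops.so``` via ```--mmdet-ops-path```. Build it first by running ```bash build.sh``` inside the ```mmdet_ops``` directory, after pointing ```CMAKE_PREFIX_PATH``` in ```mmdet_ops/CMakeLists.txt``` at the torch package of your environment (e.g. the path printed by ```python3 -c "import torch; print(torch.__path__[0])"```).
+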
+
+    b. Load and trace the model.
+    Run the trace.py script to load the checkpoint and trace the model into a TorchScript (.pt) file.
+
+
+    ```
+    python3 trace.py --config ./mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --checkpoint ./faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --input-img ./val2017/000000515350.jpg --shape 1216 --output-file ./faster_rcnn_torch_aie.pt --mmdet-ops-path ./mmdet_ops/build/libmmdet_ops.so
+    ```
+
+
+    Parameter description:
+
+    - --config: configuration file of the model.
+    - --checkpoint: weight file of the model.
+    - --input-img: a jpg image from the coco dataset.
+    - --shape: input shape of the model.
+    - --output-file: save path of the traced .pt model file.
+    - --mmdet-ops-path: path to the custom operator library libmmdet_ops.so.
+
+    This produces the ```faster_rcnn_torch_aie.pt``` file.
+
+
+
+2. Run inference and validation.
+
+    a. Run inference.
+
+    ```
+    python3 sample.py --traced_model_path ./faster_rcnn_torch_aie.pt --bin_path path/to/coco/val_2017_bin --ts_save_path compiled/model/save/path --results_save_path detection/results/save/path
+    ```
+    - Parameter description:
+        - --traced_model_path: path to the traced .pt file.
+        - --bin_path: path to the bin files of the coco test dataset.
+        - --ts_save_path: path where the compiled TorchScript model is saved.
+        - --results_save_path: directory where the detection result bin files are saved.
+
+    The inference outputs are written to the directory given by --results_save_path (./results by default).
+
+    b. Accuracy validation.
+
+    A post-processing script is provided to convert the binary outputs into txt files. Run the script.
+
+    ```
+    python3 mmdetection_coco_postprocess.py --bin_data_path=result/${infer_result_dir} --prob_thres=0.05 --det_results_path=detection-results --test_annotation=coco2017_jpg.info
+    ```
+
+    - Parameter description:
+
+        - bin_data_path: inference output directory (replace with the actual directory, e.g. ```2022_12_16-18_01_01/```).
+
+        - prob_thres: confidence threshold for the boxes.
+
+        - det_results_path: output directory of the post-processing.
+
+    Computing the mAP of the evaluation results requires the official pycocotools. First convert the txt files produced by post-processing into the standard json format used for coco accuracy evaluation.
+
+    Run the conversion script.
+
+    ```
+    python3 txt_to_json.py --npu_txt_path detection-results --json_output_file coco_detection_result
+    ```
+    - Parameter description:
+
+        - --npu_txt_path: directory of the input txt files.
+
+        - --json_output_file: path of the output json file.
+
+
+    After it succeeds, the ```coco_detection_result.json``` file is generated.
+    Run the coco_eval.py script to output a detailed evaluation report of the inference results.
+
+    ```
+    python3 coco_eval.py --detection_result coco_detection_result.json --ground_truth=annotations/instances_val2017.json
+    ```
+    - Parameter description:
+        - --detection_result: json file of the inference results.
+
+        - --ground_truth: path of ```instances_val2017.json```.
+
+
+
+# Model Inference Performance & Accuracy
+
+Inference is performed with Torch_AIE; refer to the following data for performance and accuracy.
+
+| Chip       | Batch Size | Dataset  | Accuracy (mAP) | Performance (FPS) |
+| :--------: | :--------: | :------: | :------------: | :---------------: |
+| Ascend310P | 1          | coco2017 | 37.2           | 13.11             |
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/CMakeLists.txt b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..44192fbdab3560b769062e7c17591536a6aed722
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/CMakeLists.txt
@@ -0,0 +1,13 @@
+cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
+project(mmdet_ops)
+# Point CMAKE_PREFIX_PATH at the torch package of your Python environment,
+# e.g. the path printed by `python3 -c "import torch; print(torch.__path__[0])"`.
+set(CMAKE_PREFIX_PATH "/path/to/site-packages/torch")
+find_package(Torch REQUIRED)
+
+add_library(mmdet_ops SHARED mmdet_ops.cpp)
+
+target_compile_features(mmdet_ops PRIVATE cxx_std_14)
+
+target_link_libraries(mmdet_ops PUBLIC
+ c10
+ torch
+ torch_cpu)
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4a5dbcf6079fe9ca2e6d64dd06c198b5003ecf70
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+set -e
+
+rm -rf build
+mkdir build
+cd build
+
+cmake ..
+make -j 32
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/mmdet_ops.cpp b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/mmdet_ops.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..81ba9fae8694294c4fe49826d68dfed3d48e8a7c
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/mmdet_ops.cpp
@@ -0,0 +1,66 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/script.h>
+#include <torch/torch.h>
+
+#include <vector>
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> batch_nms(
+ at::Tensor bbox,
+ at::Tensor scores,
+ double score_threshold,
+ double iou_threshold,
+ int64_t max_size_per_class,
+ int64_t max_total_size
+)
+{
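+    // Shape-only stub: returns dummy tensors with the shapes/dtypes the patched
+    // mmdetection expects, so that torch.jit.trace can record the aie::batch_nms op;
+    // the actual NMS runs on the NPU after torch_aie compilation.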
+ auto boxBatch = bbox.sizes()[0];
+ auto boxFeat = bbox.sizes()[3];
+ auto scoreBatch = scores.sizes()[0];
+
+ auto outBox = torch::ones({boxBatch, max_total_size, boxFeat}).to(torch::kFloat16);
+ auto outScore = torch::ones({scoreBatch, max_total_size}).to(torch::kFloat16);
+ auto outClass = torch::ones({max_total_size, }).to(torch::kInt64);
+ auto outNum = torch::ones({1, }).to(torch::kFloat32);
+
+ return std::make_tuple(outBox, outScore, outClass, outNum);
+}
+
+at::Tensor roi_extractor(
+    std::vector<at::Tensor> feats,
+ at::Tensor rois,
+ bool aligned,
+ int64_t finest_scale,
+ int64_t pooled_height,
+ int64_t pooled_width,
+ c10::string_view pool_mode,
+ int64_t roi_scale_factor,
+ int64_t sample_num,
+    std::vector<double> spatial_scale
+)
+{
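+    // Shape-only stub, analogous to batch_nms above: produces a placeholder of the
+    // expected output shape so tracing can record the aie::roi_extractor op.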
+ auto k = rois.sizes()[0];
+ auto c = feats[0].sizes()[1];
+ auto roi_feats = torch::ones({k, c, pooled_height, pooled_width}).to(torch::kFloat32);
+
+ return roi_feats;
+}
+
+TORCH_LIBRARY(aie, m) {
+ m.def("batch_nms", batch_nms);
+ m.def("roi_extractor", roi_extractor);
+}
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection.patch b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection.patch
new file mode 100644
index 0000000000000000000000000000000000000000..04bb741368dc8163c2309e6c138378e23e50f5a2
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection.patch
@@ -0,0 +1,519 @@
+From a098c07a11670cac3352493c5146540d13d10d19 Mon Sep 17 00:00:00 2001
+From: hekuikui
+Date: Fri, 15 Dec 2023 11:41:37 +0800
+Subject: [PATCH 1215/1215] mmdetection.patch
+
+---
+ .../core/bbox/coder/delta_xywh_bbox_coder.py | 32 +++++--
+ mmdet/core/post_processing/bbox_nms.py | 92 ++++++++++++++++++-
+ mmdet/models/dense_heads/rpn_head.py | 81 +++++++++++++++-
+ mmdet/models/detectors/base.py | 5 +
+ mmdet/models/detectors/single_stage.py | 5 +-
+ mmdet/models/roi_heads/cascade_roi_head.py | 5 +-
+ .../roi_heads/mask_heads/fcn_mask_head.py | 14 ++-
+ .../single_level_roi_extractor.py | 48 ++++++++--
+ mmdet/models/roi_heads/standard_roi_head.py | 14 +--
+ mmdet/models/roi_heads/test_mixins.py | 25 +++--
+ 10 files changed, 270 insertions(+), 51 deletions(-)
+
+diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+index e9eb3579..62562e75 100644
+--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
++++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+@@ -168,8 +168,13 @@ def delta2bbox(rois,
+ [0.0000, 0.3161, 4.1945, 0.6839],
+ [5.0000, 5.0000, 5.0000, 5.0000]])
+ """
+- means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
+- stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
++ # fix shape for means and stds for onnx
++ if torch.onnx.is_in_onnx_export():
++ means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1).numpy() // 4)
++ stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1).numpy() // 4)
++ else:
++ means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
++ stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
+ denorm_deltas = deltas * stds + means
+ dx = denorm_deltas[:, 0::4]
+ dy = denorm_deltas[:, 1::4]
+@@ -178,12 +183,23 @@ def delta2bbox(rois,
+ max_ratio = np.abs(np.log(wh_ratio_clip))
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
+- # Compute center of each roi
+- px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
+- py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
+- # Compute width/height of each roi
+- pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw)
+- ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh)
++ # improve gather performance on NPU
++ if torch.onnx.is_in_onnx_export():
++ rois_perf = rois.permute(1, 0)
++ # Compute center of each roi
++ px = ((rois_perf[0, :] + rois_perf[2, :]) * 0.5).unsqueeze(1).expand_as(dx)
++ py = ((rois_perf[1, :] + rois_perf[3, :]) * 0.5).unsqueeze(1).expand_as(dy)
++ # Compute width/height of each roi
++ pw = (rois_perf[2, :] - rois_perf[0, :]).unsqueeze(1).expand_as(dw)
++ ph = (rois_perf[3, :] - rois_perf[1, :]).unsqueeze(1).expand_as(dh)
++ else:
++ rois_perf = rois.permute(1, 0)
++ # Compute center of each roi
++ px = ((rois_perf[0, :] + rois_perf[2, :]) * 0.5).unsqueeze(1).expand_as(dx)
++ py = ((rois_perf[1, :] + rois_perf[3, :]) * 0.5).unsqueeze(1).expand_as(dy)
++ # Compute width/height of each roi
++ pw = (rois_perf[2, :] - rois_perf[0, :]).unsqueeze(1).expand_as(dw)
++ ph = (rois_perf[3, :] - rois_perf[1, :]).unsqueeze(1).expand_as(dh)
+ # Use exp(network energy) to enlarge/shrink each roi
+ gw = pw * dw.exp()
+ gh = ph * dh.exp()
+diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py
+index 463fe2e4..1363e9c9 100644
+--- a/mmdet/core/post_processing/bbox_nms.py
++++ b/mmdet/core/post_processing/bbox_nms.py
+@@ -4,6 +4,68 @@ from mmcv.ops.nms import batched_nms
+ from mmdet.core.bbox.iou_calculators import bbox_overlaps
+
+
++class BatchNMSOp(torch.autograd.Function):
++ @staticmethod
++ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
++ """
++ boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
++ scores (torch.Tensor): scores in shape (batch, N, C).
++ return:
++ nmsed_boxes: (1, N, 4)
++ nmsed_scores: (1, N)
++ nmsed_classes: (1, N)
++ nmsed_num: (1,)
++ """
++
++ # Phony implementation for onnx export
++ nmsed_boxes = bboxes[:, :max_total_size, 0, :]
++ nmsed_scores = scores[:, :max_total_size, 0]
++ nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
++ nmsed_num = torch.Tensor([max_total_size])
++
++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
++
++ @staticmethod
++ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
++ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
++ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
++
++def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
++ """
++ boxes (torch.Tensor): boxes in shape (N, 4).
++ scores (torch.Tensor): scores in shape (N, ).
++ """
++ if torch.onnx.is_in_onnx_export():
++ if bboxes.dtype == torch.float32:
++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
++ scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
++ else:
++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
++ scores = scores.reshape(1, scores.shape[0].numpy(), -1)
++ else:
++ if bboxes.dtype == torch.float32:
++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4).half()
++ scores = scores.reshape(1, scores.shape[0], -1).half()
++ else:
++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4)
++ scores = scores.reshape(1, scores.shape[0], -1)
++
++ batch_nms = torch.ops.aie.batch_nms
++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = batch_nms(bboxes, scores,
++ score_threshold, iou_threshold,
++ max_size_per_class, max_total_size)
++ # nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
++ # score_threshold, iou_threshold, max_size_per_class, max_total_size)
++ nmsed_boxes = nmsed_boxes.float()
++ nmsed_scores = nmsed_scores.float()
++ nmsed_classes = nmsed_classes.long()
++ dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
++ labels = nmsed_classes.reshape((max_total_size, ))
++ return dets, labels
++
++
+ def multiclass_nms(multi_bboxes,
+ multi_scores,
+ score_thr,
+@@ -36,13 +98,30 @@ def multiclass_nms(multi_bboxes,
+ if multi_bboxes.shape[1] > 4:
+ bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
+ else:
+- bboxes = multi_bboxes[:, None].expand(
+- multi_scores.size(0), num_classes, 4)
++ # export expand operator to onnx more nicely
++        if torch.onnx.is_in_onnx_export():
++ bbox_shape_tensor = torch.ones(multi_scores.size(0), num_classes, 4)
++ bboxes = multi_bboxes[:, None].expand_as(bbox_shape_tensor)
++ else:
++ # bboxes = multi_bboxes[:, None].expand(
++ # multi_scores.size(0), num_classes, 4)
++ bbox_shape_tensor = torch.ones(multi_scores.size(0), num_classes, 4)
++ bboxes = multi_bboxes[:, None].expand_as(bbox_shape_tensor)
++
+
+ scores = multi_scores[:, :-1]
+ if score_factors is not None:
+ scores = scores * score_factors[:, None]
+
++ # npu
++ if torch.onnx.is_in_onnx_export():
++ dets, labels = batch_nms_op(bboxes, scores, score_thr, nms_cfg.get("iou_threshold"), max_num, max_num)
++ return dets, labels
++ else:
++ dets, labels = batch_nms_op(bboxes, scores, score_thr, nms_cfg.get("iou_threshold"), max_num, max_num)
++ return dets, labels
++
++ # cpu and gpu
+ labels = torch.arange(num_classes, dtype=torch.long)
+ labels = labels.view(1, -1).expand_as(scores)
+
+@@ -53,11 +132,13 @@ def multiclass_nms(multi_bboxes,
+ # remove low scoring boxes
+ valid_mask = scores > score_thr
+ inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
++ # vals, inds = torch.topk(scores, 1000)
++
+ bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
+ if inds.numel() == 0:
+- if torch.onnx.is_in_onnx_export():
+- raise RuntimeError('[ONNX Error] Can not record NMS '
+- 'as it has not been executed this time')
++ # if torch.onnx.is_in_onnx_export():
++ raise RuntimeError('[ONNX Error] Can not record NMS '
++ 'as it has not been executed this time')
+ if return_inds:
+ return bboxes, labels, inds
+ else:
+@@ -76,6 +157,7 @@ def multiclass_nms(multi_bboxes,
+ return dets, labels[keep]
+
+
++
+ def fast_nms(multi_bboxes,
+ multi_scores,
+ multi_coeffs,
+diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py
+index f565d1a4..d95765c4 100644
+--- a/mmdet/models/dense_heads/rpn_head.py
++++ b/mmdet/models/dense_heads/rpn_head.py
+@@ -9,6 +9,67 @@ from .anchor_head import AnchorHead
+ from .rpn_test_mixin import RPNTestMixin
+
+
++class BatchNMSOp(torch.autograd.Function):
++ @staticmethod
++ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
++ """
++ boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
++ scores (torch.Tensor): scores in shape (batch, N, C).
++ return:
++ nmsed_boxes: (1, N, 4)
++ nmsed_scores: (1, N)
++ nmsed_classes: (1, N)
++ nmsed_num: (1,)
++ """
++
++ # Phony implementation for onnx export
++ nmsed_boxes = bboxes[:, :max_total_size, 0, :]
++ nmsed_scores = scores[:, :max_total_size, 0]
++ nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
++ nmsed_num = torch.Tensor([max_total_size])
++
++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
++
++ @staticmethod
++ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
++ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
++ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
++
++def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
++ """
++ boxes (torch.Tensor): boxes in shape (N, 4).
++ scores (torch.Tensor): scores in shape (N, ).
++ """
++ if torch.onnx.is_in_onnx_export():
++ if bboxes.dtype == torch.float32:
++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
++ scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
++ else:
++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
++ scores = scores.reshape(1, scores.shape[0].numpy(), -1)
++ else:
++ if bboxes.dtype == torch.float32:
++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4).half()
++ scores = scores.reshape(1, scores.shape[0], -1).half()
++ else:
++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4)
++ scores = scores.reshape(1, scores.shape[0], -1)
++
++ batch_nms = torch.ops.aie.batch_nms
++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = batch_nms(bboxes, scores,
++ score_threshold, iou_threshold,
++ max_size_per_class, max_total_size)
++ # nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
++ # score_threshold, iou_threshold, max_size_per_class, max_total_size) # max_total_size num_bbox
++ nmsed_boxes = nmsed_boxes.float()
++ nmsed_scores = nmsed_scores.float()
++ nmsed_classes = nmsed_classes.long()
++ dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
++ labels = nmsed_classes.reshape((max_total_size, ))
++ return dets, labels
++
+ @HEADS.register_module()
+ class RPNHead(RPNTestMixin, AnchorHead):
+ """RPN head.
+@@ -132,9 +193,12 @@ class RPNHead(RPNTestMixin, AnchorHead):
+ if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
+ # sort is faster than topk
+ # _, topk_inds = scores.topk(cfg.nms_pre)
+- ranked_scores, rank_inds = scores.sort(descending=True)
+- topk_inds = rank_inds[:cfg.nms_pre]
+- scores = ranked_scores[:cfg.nms_pre]
++ # onnx uses topk to sort, this is simpler for onnx export
++ if torch.onnx.is_in_onnx_export():
++ scores, topk_inds = torch.topk(scores, cfg.nms_pre)
++ else:
++ scores, topk_inds = torch.topk(scores, cfg.nms_pre)
++
+ rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+ anchors = anchors[topk_inds, :]
+ mlvl_scores.append(scores)
+@@ -164,5 +228,12 @@ class RPNHead(RPNTestMixin, AnchorHead):
+
+ # TODO: remove the hard coded nms type
+ nms_cfg = dict(type='nms', iou_threshold=cfg.nms_thr)
+- dets, keep = batched_nms(proposals, scores, ids, nms_cfg)
+- return dets[:cfg.nms_post]
++ # npu return
++ if torch.onnx.is_in_onnx_export():
++ dets, labels = batch_nms_op(proposals, scores, 0.0, nms_cfg.get("iou_threshold"), cfg.nms_post, cfg.nms_post)
++ return dets
++ # cpu and gpu return
++ else:
++ dets, labels = batch_nms_op(proposals, scores, 0.0, nms_cfg.get("iou_threshold"), cfg.nms_post,
++ cfg.nms_post)
++ return dets
+diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
+index 7c6d5e96..7e053ae7 100644
+--- a/mmdet/models/detectors/base.py
++++ b/mmdet/models/detectors/base.py
+@@ -131,6 +131,11 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
++ if not isinstance(imgs, list):
++ imgs = [imgs]
++ if not isinstance(img_metas, list):
++ img_metas = [img_metas]
++
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
+index 96c4acac..929bf2a0 100644
+--- a/mmdet/models/detectors/single_stage.py
++++ b/mmdet/models/detectors/single_stage.py
+@@ -114,8 +114,9 @@ class SingleStageDetector(BaseDetector):
+ bbox_list = self.bbox_head.get_bboxes(
+ *outs, img_metas, rescale=rescale)
+ # skip post-processing when exporting to ONNX
+- if torch.onnx.is_in_onnx_export():
+- return bbox_list
++ # if torch.onnx.is_in_onnx_export():
++ # return bbox_list
++ return bbox_list
+
+ bbox_results = [
+ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
+index 45b6f36a..6f8b9245 100644
+--- a/mmdet/models/roi_heads/cascade_roi_head.py
++++ b/mmdet/models/roi_heads/cascade_roi_head.py
+@@ -349,8 +349,9 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+
+- if torch.onnx.is_in_onnx_export():
+- return det_bboxes, det_labels
++ # if torch.onnx.is_in_onnx_export():
++ # return det_bboxes, det_labels
++ return det_bboxes, det_labels
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head[-1].num_classes)
+diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
+index 0cba3cda..38726b0c 100644
+--- a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
++++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
+@@ -204,6 +204,14 @@ class FCNMaskHead(nn.Module):
+ if thr > 0:
+ masks = masks >= thr
+ return masks
++ else:
++ from torchvision.models.detection.roi_heads \
++ import paste_masks_in_image
++ masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
++ thr = rcnn_test_cfg.get('mask_thr_binary', 0)
++ if thr > 0:
++ masks = masks >= thr
++ return masks
+
+ N = len(mask_pred)
+ # The actual implementation split the input into chunks,
+@@ -316,9 +324,9 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
+ gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
+ grid = torch.stack([gx, gy], dim=3)
+
+- if torch.onnx.is_in_onnx_export():
+- raise RuntimeError(
+- 'Exporting F.grid_sample from Pytorch to ONNX is not supported.')
++ # if torch.onnx.is_in_onnx_export():
++ raise RuntimeError(
++ 'Exporting F.grid_sample from Pytorch to ONNX is not supported.')
+ img_masks = F.grid_sample(
+ masks.to(dtype=torch.float32), grid, align_corners=False)
+
+diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
+index c0eebc4a..d3beb385 100644
+--- a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
++++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
+@@ -4,6 +4,30 @@ from mmcv.runner import force_fp32
+ from mmdet.models.builder import ROI_EXTRACTORS
+ from .base_roi_extractor import BaseRoIExtractor
+
++import torch.onnx.symbolic_helper as sym_help
++
++class RoiExtractor(torch.autograd.Function):
++ @staticmethod
++ def forward(self, f0, f1, f2, f3, rois, aligned=1, finest_scale=56, pooled_height=7, pooled_width=7,
++ pool_mode='avg', roi_scale_factor=0, sample_num=0, spatial_scale=[0.25, 0.125, 0.0625, 0.03125]):
++ """
++ feats (torch.Tensor): feats in shape (batch, 256, H, W).
++ rois (torch.Tensor): rois in shape (k, 5).
++ return:
++ roi_feats (torch.Tensor): (k, 256, pooled_width, pooled_width)
++ """
++
++ # phony implementation for shape inference
++ k = rois.size()[0]
++ roi_feats = torch.ones(k, 256, pooled_height, pooled_width)
++ return roi_feats
++
++ @staticmethod
++ def symbolic(g, f0, f1, f2, f3, rois):
++ # TODO: support tensor list type for feats
++ roi_feats = g.op('RoiExtractor', f0, f1, f2, f3, rois, aligned_i=1, finest_scale_i=56, pooled_height_i=7, pooled_width_i=7,
++ pool_mode_s='avg', roi_scale_factor_i=0, sample_num_i=0, spatial_scale_f=[0.25, 0.125, 0.0625, 0.03125], outputs=1)
++ return roi_feats
+
+ @ROI_EXTRACTORS.register_module()
+ class SingleRoIExtractor(BaseRoIExtractor):
+@@ -52,6 +76,18 @@ class SingleRoIExtractor(BaseRoIExtractor):
+
+ @force_fp32(apply_to=('feats', ), out_fp16=True)
+ def forward(self, feats, rois, roi_scale_factor=None):
++ # Work around to export onnx for npu
++ if torch.onnx.is_in_onnx_export():
++ roi_feats = RoiExtractor.apply(feats[0], feats[1], feats[2], feats[3], rois)
++ # roi_feats = RoiExtractor.apply(list(feats), rois)
++ return roi_feats
++ else:
++ # roi_feats = RoiExtractor.apply(feats[0], feats[1], feats[2], feats[3], rois)
++ # roi_feats = RoiExtractor.apply(list(feats), rois)
++ roi_extractor = torch.ops.aie.roi_extractor
++ roi_feats = roi_extractor([feats[0], feats[1], feats[2], feats[3]], rois, 1, 56, 7, 7, 'avg', 0, 0, [0.25, 0.125, 0.0625, 0.03125])
++ return roi_feats
++
+ """Forward function."""
+ out_size = self.roi_layers[0].output_size
+ num_levels = len(feats)
+@@ -82,12 +118,12 @@ class SingleRoIExtractor(BaseRoIExtractor):
+ mask = target_lvls == i
+ inds = mask.nonzero(as_tuple=False).squeeze(1)
+ # TODO: make it nicer when exporting to onnx
+- if torch.onnx.is_in_onnx_export():
+- # To keep all roi_align nodes exported to onnx
+- rois_ = rois[inds]
+- roi_feats_t = self.roi_layers[i](feats[i], rois_)
+- roi_feats[inds] = roi_feats_t
+- continue
++ # if torch.onnx.is_in_onnx_export():
++ # To keep all roi_align nodes exported to onnx
++ rois_ = rois[inds]
++ roi_feats_t = self.roi_layers[i](feats[i], rois_)
++ roi_feats[inds] = roi_feats_t
++ continue
+ if inds.numel() > 0:
+ rois_ = rois[inds]
+ roi_feats_t = self.roi_layers[i](feats[i], rois_)
+diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py
+index c530f2a5..bacba384 100644
+--- a/mmdet/models/roi_heads/standard_roi_head.py
++++ b/mmdet/models/roi_heads/standard_roi_head.py
+@@ -246,13 +246,13 @@ class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+
+ det_bboxes, det_labels = self.simple_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+- if torch.onnx.is_in_onnx_export():
+- if self.with_mask:
+- segm_results = self.simple_test_mask(
+- x, img_metas, det_bboxes, det_labels, rescale=rescale)
+- return det_bboxes, det_labels, segm_results
+- else:
+- return det_bboxes, det_labels
++ # if torch.onnx.is_in_onnx_export():
++ if self.with_mask:
++ segm_results = self.simple_test_mask(
++ x, img_metas, det_bboxes, det_labels, rescale=rescale)
++ return det_bboxes, det_labels, segm_results
++ else:
++ return det_bboxes, det_labels
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
+index 0e675d6e..171a21c3 100644
+--- a/mmdet/models/roi_heads/test_mixins.py
++++ b/mmdet/models/roi_heads/test_mixins.py
+@@ -211,19 +211,18 @@ class MaskTestMixin(object):
+ mask_result = self._mask_forward(x, mask_rois)
+ mask_preds.append(mask_result['mask_pred'])
+ else:
+- _bboxes = [
+- det_bboxes[i][:, :4] *
+- scale_factors[i] if rescale else det_bboxes[i][:, :4]
+- for i in range(len(det_bboxes))
+- ]
+- mask_rois = bbox2roi(_bboxes)
+- mask_results = self._mask_forward(x, mask_rois)
+- mask_pred = mask_results['mask_pred']
+- # split batch mask prediction back to each image
+- num_mask_roi_per_img = [
+- det_bbox.shape[0] for det_bbox in det_bboxes
+- ]
+- mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
++ # avoid mask_pred.split with static number of prediction
++ mask_preds = []
++ _bboxes = []
++ for i, boxes in enumerate(det_bboxes):
++ boxes = boxes[:, :4]
++ if rescale:
++ boxes *= scale_factors[i]
++ _bboxes.append(boxes)
++ img_inds = boxes[:, :1].clone() * 0 + i
++ mask_rois = torch.cat([img_inds, boxes], dim=-1)
++ mask_result = self._mask_forward(x, mask_rois)
++ mask_preds.append(mask_result['mask_pred'])
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+--
+2.25.1
+
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/requirements.txt b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fa06f3328159e6d837eec3845f627ae21c9f01bc
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/requirements.txt
@@ -0,0 +1,16 @@
+numpy
+decorator
+attrs
+psutil
+tqdm
+onnx==1.7.0
+Pillow==9.2.0
+opencv-python==4.6.0.66
+torch==1.10.0
+torchvision==0.11.1
+protobuf==3.20.0
+mmcv-full==1.2.4
+onnxruntime==1.12.1
+onnxoptimizer==0.2.7
+terminaltables==3.1.10
+
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/sample.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe98e3a40300cc4e6e67d348b29f937970e3568f
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/sample.py
@@ -0,0 +1,120 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import argparse
+import numpy as np
+from tqdm import tqdm
+import torch
+
+import torch_aie
+
+def generate_options():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--traced_model_path", type=str, default="./faster_rcnn_r50_fpn_trace_20231130.pt", help="path to traced-model")
+ parser.add_argument("--bin_path", type=str, default="./coco_val_bin", help="path to bin files")
+ parser.add_argument("--batch_size", type=int, default=1, help="set batch size, default is 1")
+ parser.add_argument("--image_size", type=int, default=1216, help="set image size, default is 1216")
+ parser.add_argument("--ts_save_path", type=str, default="./faster_rcnn_r50_fpn_aie.pt", help="path to save results")
+ parser.add_argument("--results_save_path", type=str, default="./results", help="path to save results")
+ parser.add_argument("--device_id", type=int, default=0, help="set device, default is 0")
+ return parser.parse_args()
+
+def compile(model_path, save_path, batch_size, image_size):
+ model = torch.jit.load(model_path)
+ model.eval()
+ print('Model loaded.')
+
+ input_info = [torch_aie.Input((batch_size, 3, image_size, image_size))]
+
+ compiled_model = torch_aie.compile(
+ model,
+ inputs=input_info,
+ precision_policy=torch_aie._enums.PrecisionPolicy.FP16,
+ soc_version="Ascend310P3")
+ print('Model compiled successfully.')
+ compiled_model.save(save_path)
+ return compiled_model
+
+def inference(val_bin_path, save_path, model, device_id, image_size):
+ torch_aie.set_device(device_id)
+ device = f'npu:{device_id}'
+ stream = torch_aie.npu.Stream(device)
+    model.eval()
+    os.makedirs(save_path, exist_ok=True)
+    file_list = sorted(os.listdir(val_bin_path))
+ pbar = tqdm(file_list)
+ for file_name in pbar:
+ # generate input
+ bin_file_path = os.path.join(val_bin_path, file_name)
+ data = np.fromfile(bin_file_path, dtype=np.float32).reshape(1, 3, image_size, image_size)
+ image = torch.from_numpy(data).to(device)
+
+ # infer
+ with torch_aie.npu.stream(stream):
+ aie_results = model(image)
+ stream.synchronize()
+
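+        # aie_results[0]: boxes of shape (batch, 100, 5) as (x1, y1, x2, y2, score);
+        # aie_results[1]: per-box class labels of shape (batch, 100).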
+        boxes = aie_results[0][0].to("cpu").numpy()
+        labels = aie_results[1][0].to("cpu").numpy()
+        base_name = file_name.split(".")[0]
+        result_0_save_path = os.path.join(save_path, base_name + "_0.bin")
+        result_1_save_path = os.path.join(save_path, base_name + "_1.bin")
+        boxes.tofile(result_0_save_path)
+        labels.tofile(result_1_save_path)
+ pbar.set_description_str("Process " + file_name + " done.")
+
+def performance_test(model, device_id, batch_size, image_size):
+ torch_aie.set_device(device_id)
+
+ random_input = torch.rand(batch_size, 3, image_size, image_size)
+ device = f'npu:{device_id}'
+ random_input = random_input.to(device)
+ stream = torch_aie.npu.Stream(device)
+
+ model.eval()
+
+ # warm up
+ num_warmup = 50
+ for _ in range(num_warmup):
+ with torch_aie.npu.stream(stream):
+ model(random_input)
+ stream.synchronize()
+ print('warmup done.')
+
+ # performance test
+ print('Start performance test.')
+ num_infer = 500
+ start = time.time()
+ for _ in range(num_infer):
+ with torch_aie.npu.stream(stream):
+ model(random_input)
+ stream.synchronize()
+ avg_time = (time.time() - start) / num_infer
+ fps = batch_size / avg_time
+ print(f'FPS: {fps:.4f}')
+
+
+if __name__ == "__main__":
+ opts = generate_options()
+
+ # compile
+ compiled_model = compile(opts.traced_model_path, opts.ts_save_path, opts.batch_size, opts.image_size)
+
+ # infer
+ inference(opts.bin_path, opts.results_save_path, compiled_model, opts.device_id, opts.image_size)
+
+ # performance test
+ performance_test(compiled_model, opts.device_id, opts.batch_size, opts.image_size)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/trace.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb07fcc8ba3d0734c73764bba84075a5b72b81af
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/trace.py
@@ -0,0 +1,86 @@
+import argparse
+
+import torch
+
+from mmdet.core import generate_inputs_and_wrap_model
+
+def trace(
+ config_path,
+ checkpoint_path,
+ input_img,
+ input_shape,
+ output_file='tmp.pt',
+ normalize_cfg=None):
+
+ input_config = {
+ 'input_shape': input_shape,
+ 'input_path': input_img,
+ 'normalize_cfg': normalize_cfg
+ }
+
+ model, tensor_data = generate_inputs_and_wrap_model(
+ config_path, checkpoint_path, input_config)
+
+ model.eval()
+ torch.jit.trace(model, tensor_data).save(output_file)
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Trace MMDetection models')
+ parser.add_argument('--config', help='test config file path')
+ parser.add_argument('--checkpoint', help='checkpoint file')
+ parser.add_argument('--input-img', type=str, help='Images for input')
+    parser.add_argument('--output-file', type=str, default='tmp.pt', help='save path of the traced TorchScript model')
+    parser.add_argument('--mmdet-ops-path', type=str, default="", help='path to the custom operator library libmmdet_ops.so')
+ parser.add_argument(
+ '--shape',
+ type=int,
+ nargs='+',
+ default=[800, 1216],
+ help='input image size')
+ parser.add_argument(
+ '--mean',
+ type=float,
+ nargs='+',
+ default=[123.675, 116.28, 103.53],
+ help='mean value used for preprocess input data')
+ parser.add_argument(
+ '--std',
+ type=float,
+ nargs='+',
+ default=[58.395, 57.12, 57.375],
+        help='standard deviation used to preprocess input data')
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
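+    # Register the custom aie::batch_nms / aie::roi_extractor ops used by the patched mmdetection.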
+ torch.ops.load_library(args.mmdet_ops_path)
+
+ if len(args.shape) == 1:
+ input_shape = (1, 3, args.shape[0], args.shape[0])
+ elif len(args.shape) == 2:
+ input_shape = (1, 3) + tuple(args.shape)
+ else:
+ raise ValueError('invalid input shape')
+
+ assert len(args.mean) == 3
+ assert len(args.std) == 3
+
+ normalize_cfg = {'mean': args.mean, 'std': args.std}
+
+ trace(
+ args.config,
+ args.checkpoint,
+ args.input_img,
+ input_shape,
+ args.output_file,
+ normalize_cfg=normalize_cfg,)