From aa1971bc770ff56a4359b9ea5e9bf26d191e72b8 Mon Sep 17 00:00:00 2001 From: hekuikui Date: Mon, 11 Dec 2023 18:02:04 +0800 Subject: [PATCH 1/5] add faster_rcnn_res50_fpn for torch_aie --- .../cv/detection/FasterRcnn/coco_eval.py | 78 +++++++++ .../cv/detection/FasterRcnn/get_info.py | 46 ++++++ .../mmdetection_coco_postprocess.py | 149 ++++++++++++++++++ .../FasterRcnn/mmdetection_coco_preprocess.py | 70 ++++++++ .../cv/detection/FasterRcnn/sample.py | 122 ++++++++++++++ .../cv/detection/FasterRcnn/txt_to_json.py | 114 ++++++++++++++ 6 files changed, 579 insertions(+) create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/coco_eval.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/get_info.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_preprocess.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/txt_to_json.py diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/coco_eval.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/coco_eval.py new file mode 100644 index 0000000000..69d0adcf83 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/coco_eval.py @@ -0,0 +1,78 @@ +import argparse +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') + +def coco_evaluation(annotation_json, result_json): + cocoGt = COCO(annotation_json) + cocoDt = cocoGt.loadRes(result_json) + iou_thrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + iou_type = 'bbox' + + cocoEval = COCOeval(cocoGt, cocoDt, iou_type) + cocoEval.params.catIds = cocoGt.getCatIds(catNms=CLASSES) + cocoEval.params.imgIds = cocoGt.getImgIds() + cocoEval.params.maxDets = [100, 300, 1000] # proposal number for evaluating recalls/mAPs. 
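+    # iou_thrs above evaluates to [0.5, 0.55, ..., 0.95]: the ten IoU
+    # thresholds of the standard COCO mAP@[.5:.95] protocol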
+    cocoEval.params.iouThrs = iou_thrs
+
+    cocoEval.evaluate()
+    cocoEval.accumulate()
+    cocoEval.summarize()
+
+    # mapping of cocoEval.stats
+    coco_metric_names = {
+        'mAP': 0,
+        'mAP_50': 1,
+        'mAP_75': 2,
+        'mAP_s': 3,
+        'mAP_m': 4,
+        'mAP_l': 5,
+        'AR@100': 6,
+        'AR@300': 7,
+        'AR@1000': 8,
+        'AR_s@1000': 9,
+        'AR_m@1000': 10,
+        'AR_l@1000': 11
+    }
+
+    metric_items = ['mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l']
+    eval_results = {}
+
+    for metric_item in metric_items:
+        key = f'bbox_{metric_item}'
+        val = float(
+            f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
+        )
+        eval_results[key] = val
+    ap = cocoEval.stats[:6]
+    eval_results['bbox_mAP_copypaste'] = (
+        f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+        f'{ap[4]:.3f} {ap[5]:.3f}')
+
+    return eval_results
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ground_truth", default="instances_val2017.json")
+    parser.add_argument("--detection_result", default="coco_detection_result.json")
+    args = parser.parse_args()
+    result = coco_evaluation(args.ground_truth, args.detection_result)
+    print(result)
+    with open('./coco_detection_result.txt', 'w') as f:
+        for key, value in result.items():
+            f.write(key + ': ' + str(value) + '\n')
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/get_info.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/get_info.py
new file mode 100644
index 0000000000..7b14c54b90
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/get_info.py
@@ -0,0 +1,46 @@
+import os
+import sys
+import cv2
+from glob import glob
+
+
+def get_bin_info(file_path, info_name, width, height):
+    bin_images = glob(os.path.join(file_path, '*.bin'))
+    with open(info_name, 'w') as file:
+        for index, img in enumerate(bin_images):
+            content = ' '.join([str(index), img, width, height])
+            file.write(content)
+            file.write('\n')
+
+
+def get_jpg_info(file_path, info_name):
+    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
+    image_names = []
+    for extension in extensions:
+        image_names.append(glob(os.path.join(file_path, '*.' + extension)))
+    with open(info_name, 'w') as file:
+        for image_name in image_names:
+            if len(image_name) == 0:
+                continue
+            else:
+                for index, img in enumerate(image_name):
+                    img_cv = cv2.imread(img)
+                    shape = img_cv.shape
+                    width, height = shape[1], shape[0]
+                    content = ' '.join([str(index), img, str(width), str(height)])
+                    file.write(content)
+                    file.write('\n')
+
+
+if __name__ == '__main__':
+    file_type = sys.argv[1]
+    file_path = sys.argv[2]
+    info_name = sys.argv[3]
+    if file_type == 'bin':
+        assert len(sys.argv) == 6, 'bin mode expects 5 arguments: bin <file_path> <info_name> <width> <height>'
+        width = sys.argv[4]
+        height = sys.argv[5]
+        get_bin_info(file_path, info_name, width, height)
+    elif file_type == 'jpg':
+        assert len(sys.argv) == 4, 'jpg mode expects 3 arguments: jpg <file_path> <info_name>'
+        get_jpg_info(file_path, info_name)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py
new file mode 100644
index 0000000000..054aad4bfd
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py
@@ -0,0 +1,149 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import argparse
+import cv2
+
+CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+           'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+           'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
+
+def coco_postprocess(bbox: np.ndarray, image_size,
+                     net_input_width, net_input_height):
+    """
+    Post-process the FasterRCNN output.
+
+    Before calling this function, reshape the raw output of FasterRCNN to
+    the following form
+        numpy.ndarray:
+            [x1, y1, x2, y2, confidence, class index]
+        shape: (100, 6)
+    The post-processing maps the bounding boxes predicted on the padded
+    network input back to the original image scale.
+
+    :param bbox: a numpy array of the FasterRCNN output
+    :param image_size: (width, height) of the original input image
+    :return: a numpy array of the boxes rescaled to the original image
+    """
+    w = image_size[0]
+    h = image_size[1]
+    scale = min(net_input_width / w, net_input_height / h)
+
+    pad_w = net_input_width - w * scale
+    pad_h = net_input_height - h * scale
+    pad_left = pad_w // 2
+    pad_top = pad_h // 2
+
+    # map predicted boxes back onto the source image
+    pbox = bbox
+    pbox[:, 0] = (bbox[:, 0] - pad_left) / scale
+    pbox[:, 1] = (bbox[:, 1] - pad_top) / scale
+    pbox[:, 2] = (bbox[:, 2] - pad_left) / scale
+    pbox[:, 3] = (bbox[:, 3] - pad_top) / scale
+    return pbox
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--bin_data_path", default="./result/dumpOutput_device0")
+    parser.add_argument("--test_annotation", default="./coco2017_jpg.info")
+    parser.add_argument("--det_results_path", default="./detection-results/")
+    parser.add_argument("--net_out_num", type=int, default=2)
+    parser.add_argument("--net_input_width", type=int, default=1216)
+    parser.add_argument("--net_input_height", type=int, default=1216)
+    parser.add_argument("--prob_thres", type=float, default=0.05)
+    parser.add_argument("--ifShowDetObj", action="store_true", help="draw and save detection results on images when set")
+    flags = parser.parse_args()
+    print(flags.ifShowDetObj, type(flags.ifShowDetObj))
+    # build a lookup table from the annotation info file
+    # to query the original width and height of each input image
+    img_size_dict = dict()
+    with open(flags.test_annotation) as f:
+        for line in f.readlines():
+            temp = line.split(" ")
+            img_file_path = temp[1]
+            img_name = temp[1].split("/")[-1].split(".")[0]
+            img_width = int(temp[2])
+            img_height = int(temp[3])
+            img_size_dict[img_name] = (img_width, img_height, img_file_path)
+
+    # read the bin files and generate prediction results
+    bin_path = flags.bin_data_path
+    det_results_path = flags.det_results_path
+    os.makedirs(det_results_path, exist_ok=True)
+    total_img = set([name[:name.rfind('_')]
+                     for name in os.listdir(bin_path) if "bin" in name])
+
+    for bin_file in sorted(total_img):
+        path_base = os.path.join(bin_path, bin_file)
+        # load all detected output tensors
+        res_buff = []
+        for num in range(0, flags.net_out_num):
+            if os.path.exists(path_base + "_" + str(num) + ".bin"):
+                if num == 0:
+                    buf = np.fromfile(path_base + "_" + str(num) + ".bin", dtype="float32")
+                    buf = np.reshape(buf, [100, 5])
+                elif num == 1:
+                    buf = np.fromfile(path_base + "_" + str(num) + ".bin", dtype="int64")
+                    buf = np.reshape(buf, [100, 1])
+                res_buff.append(buf)
+            else:
+                print("[ERROR] file not exist", path_base + "_" + str(num) + ".bin")
+        res_tensor = np.concatenate(res_buff, axis=1)
+        current_img_size = img_size_dict[bin_file]
+        print("[TEST]---------------------------concat{} imgsize{}".format(len(res_tensor), current_img_size))
+        predbox = coco_postprocess(res_tensor, current_img_size, flags.net_input_width, flags.net_input_height)
+
+        if flags.ifShowDetObj:
+            imgCur = cv2.imread(current_img_size[2])
+
+        det_results_str = ''
+        for idx, class_ind in enumerate(predbox[:, 5]):
+            if float(predbox[idx][4]) < float(flags.prob_thres):
+                continue
+            # skip invalid class indices
+            if class_ind < 0 or class_ind >= len(CLASSES):
+                continue
+
+            class_name = CLASSES[int(class_ind)]
+            det_results_str += "{} {} {} {} {} {}\n".format(class_name, str(predbox[idx][4]), predbox[idx][0],
+                                                            predbox[idx][1], predbox[idx][2], predbox[idx][3])
+            if flags.ifShowDetObj:
+                imgCur = cv2.rectangle(imgCur, (int(predbox[idx][0]), int(predbox[idx][1])),
+                                       (int(predbox[idx][2]), int(predbox[idx][3])), (0, 255, 0), 1)
+                # cv2.putText args: image, text, origin, font, scale, color, thickness
+                imgCur = cv2.putText(imgCur, class_name + '|' + str(predbox[idx][4]),
+                                     (int(predbox[idx][0]), int(predbox[idx][1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
+
+        if flags.ifShowDetObj:
+            print(os.path.join(det_results_path, bin_file + '.jpg'))
+            cv2.imwrite(os.path.join(det_results_path, bin_file + '.jpg'), imgCur, [int(cv2.IMWRITE_JPEG_QUALITY), 70])
+
+        det_results_file = os.path.join(det_results_path, bin_file + ".txt")
+        with open(det_results_file, "w") as detf:
+            detf.write(det_results_str)
+        print(det_results_str)
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_preprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_preprocess.py
new file mode 100644
index 0000000000..acf0029ccb
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_preprocess.py
@@ -0,0 +1,70 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import os
+import cv2
+import argparse
+import mmcv
+import torch
+from tqdm import tqdm
+
+dataset_config = {
+    'resize': (1216, 1216),
+    'mean': [123.675, 116.28, 103.53],
+    'std': [58.395, 57.12, 57.375],
+}
+
+tensor_height = 1216
+tensor_width = 1216
+
+def coco_preprocess(input_image, output_bin_path):
+    # define the output file name
+    img_name = input_image.split('/')[-1]
+    bin_name = img_name.split('.')[0] + ".bin"
+    bin_fl = os.path.join(output_bin_path, bin_name)
+
+    one_img = mmcv.imread(os.path.join(input_image), backend='cv2')
+    one_img = mmcv.imrescale(one_img, (tensor_height, tensor_width))
+    # calculate padding
+    h = one_img.shape[0]
+    w = one_img.shape[1]
+    pad_left = (tensor_width - w) // 2
+    pad_top = (tensor_height - h) // 2
+    pad_right = tensor_width - pad_left - w
+    pad_bottom = tensor_height - pad_top - h
+
+    mean = np.array(dataset_config['mean'], dtype=np.float32)
+    std = np.array(dataset_config['std'], dtype=np.float32)
+    one_img = mmcv.imnormalize(one_img, mean, std)
+    one_img = mmcv.impad(one_img, padding=(pad_left, pad_top, pad_right, pad_bottom), pad_val=0)
+    one_img = one_img.transpose(2, 0, 1)
+    # each output .bin stores one float32 CHW tensor of shape (3, 1216, 1216)
+    one_img.tofile(bin_fl)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='preprocess of FasterRCNN pytorch model')
+    parser.add_argument("--image_folder_path", default="./coco2014/", help='directory of the input images')
+    parser.add_argument("--bin_folder_path", default="./coco2014_bin/", help='output directory for the preprocessed .bin files')
+    flags = parser.parse_args()
+
+    if not os.path.exists(flags.bin_folder_path):
+        os.makedirs(flags.bin_folder_path)
+    print("Start to process images...")
+    images = tqdm(os.listdir(flags.image_folder_path))
+    for image_name in images:
+        if not (image_name.endswith(".jpeg") or image_name.endswith(".JPEG") or image_name.endswith(".jpg")):
+            continue
+        path_image = os.path.join(flags.image_folder_path, image_name)
+        coco_preprocess(path_image, flags.bin_folder_path)
+    
print("Done.") diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py new file mode 100644 index 0000000000..0f490cc893 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py @@ -0,0 +1,122 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import argparse +import numpy as np +from tqdm import tqdm +import torch +import torchvision + +import torch_aie + +def generate_options(): + parser = argparse.ArgumentParser() + parser.add_argument("--traced-model-path", type=str, help="path to traced-model") + parser.add_argument("--bin-path", type=str, help="path to bin files") + parser.add_argument("--batch-size", type=int, default=1, help="set batch size, default is 1") + parser.add_argument("--image-size", type=int, default=1216, help="set image size, default is 1216") + parser.add_argument("--ts-save-path", type=str, default="'./faster_rcnn_r50_fpn_aie.pt'", help="path to save results") + parser.add_argument("--results-save-path", type=str, default="./results", help="path to save results") + parser.add_argument("--device-id", type=int, default=0, help="set device, default is 0") + return parser.parse_args + +def compile(model_path, save_path, batch_size, image_size): + model = torch.jit.load(model_path) + model.eval() + print('Model loaded.') + + input_info = [torch_aie.Input((batch_size, 3, image_size, image_size))] + + compiled_model = torch_aie.compile( + model, + inputs=input_info, + precision_policy=torch_aie._enums.PrecisionPolicy.FP16, + soc_version="Ascend310P3") + print('Model compiled successfully.') + compiled_model.save(save_path) + return compiled_model + +def inference(val_bin_path, save_path, model, device_id, image_size): + torch_aie.set_device(device_id) + device = f'npu:{device_id}' + stream = torch_aie.npu.Stream(device) + + file_list = sorted(os.listdir(val_bin_path)) + pbar = tqdm(file_list) + for file_name in pbar: + # generate input + bin_file_path = os.path.join(val_bin_path, file_name) + data = np.fromfile(bin_file_path, dtype=np.float32).reshape(1, 3, image_size, image_size) + image = torch.from_numpy(data).to(device) + + # infer + with torch_aie.npu.stream(stream): + aie_results = model(image) + stream.synchronize() + + boxes = aie_results[0][0].to("cpu").numpy() + scores = aie_results[1][0].to("cpu").numpy() + base_name = file_name.split(".")[0] + result_0_save_path = os.path.join(save_path, base_name + "_0.bin") + result_1_save_path = os.path.join(save_path, base_name + "_1.bin") + boxes.tofile(result_0_save_path) + scores.tofile(result_1_save_path) + pbar.set_description_str("Process " + file_name + " done.") + +def performance_test(model, device_id, batch_size, image_size): + torch_aie.set_device(device_id) + + model = torch.jit.load(model) + model.eval() + print('Model loaded successfully.') + + random_input = torch.rand(batch_size, 3, image_size, image_size) + device = f'npu:{device_id}' + 
random_input = random_input.to(device) + stream = torch_aie.npu.Stream(device) + + # warm up + num_warmup = 50 + for _ in range(num_warmup): + with torch_aie.npu.stream(stream): + model(random_input) + stream.synchronize() + print('warmup done.') + + # performance test + print('Start performance test.') + num_infer = 500 + start = time.time() + for _ in range(num_infer): + with torch_aie.npu.stream(stream): + model(random_input) + stream.synchronize() + avg_time = (time.time() - start) / num_infer + fps = batch_size / avg_time + print(f'FPS: {fps:.4f}') + + +if __name__ == "__main__": + opts = generate_options() + + # compile + compiled_model = compile(opts.traced_model_path, opts.ts_save_path, opts.batch_size, opts.image_size) + + # infer + inference(opts.bin_path, opts.results_save_path, compiled_model, opts.device_id, opts.image_size) + + # performance test + performance_test(compiled_model, opts.device_id, opts.batch_size, opts.image_size) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/txt_to_json.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/txt_to_json.py new file mode 100644 index 0000000000..9f479da6a0 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/txt_to_json.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
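+
+# Each line of the detection txt files read here has the form
+# "<class name> <score> <left> <top> <right> <bottom>"; class names that
+# contain a space (e.g. "traffic light") are re-joined in get_predict_list()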
+
+import glob
+import os
+import sys
+import argparse
+import mmcv
+
+CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+           'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+           'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
+
+cat_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47,
+48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70,
+72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
+
+'''
+    0,0 ------> x (width)
+     |
+     |  (Left,Top)
+     |      *_________
+     |      |         |
+     |      |         |
+     y      |_________|
+  (height)            *
+                (Right,Bottom)
+'''
+
+def file_lines_to_list(path):
+    # open txt file lines to a list
+    with open(path) as f:
+        content = f.readlines()
+    # remove whitespace characters like `\n` at the end of each line
+    content = [x.strip() for x in content]
+    return content
+
+
+def error(msg):
+    print(msg)
+    sys.exit(1)
+
+
+def get_predict_list(file_path, gt_classes):
+    dr_files_list = glob.glob(file_path + '/*.txt')
+    dr_files_list.sort()
+
+    bounding_boxes = []
+    for txt_file in dr_files_list:
+        file_id = txt_file.split(".txt", 1)[0]
+        file_id = os.path.basename(os.path.normpath(file_id))
+        lines = file_lines_to_list(txt_file)
+        for line in lines:
+            try:
+                sl = line.split()
+                if len(sl) > 6:
+                    class_name = sl[0] + ' ' + sl[1]
+                    scores, left, top, right, bottom = sl[2:]
+                else:
+                    class_name, scores, left, top, right, bottom = sl
+                if float(scores) < 0.05:
+                    continue
+            except ValueError:
+                error_msg = "Error: File " + txt_file + " wrong format.\n"
+                error_msg += " Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n"
+                error_msg += " Received: " + line
+                error(error_msg)
+
+            # bbox = left + " " + top + " " + right + " " + bottom
+            left = float(left)
+            right = float(right)
+            top = float(top)
+            bottom = float(bottom)
+            bbox = [left, top, right-left, bottom-top]
+            bounding_boxes.append({"image_id": int(file_id), "bbox": bbox,
+                                   "score": float(scores), "category_id": cat_ids[CLASSES.index(class_name)]})
+    # sort detection-results by decreasing scores
+    # bounding_boxes.sort(key=lambda x: float(x['score']), reverse=True)
+    return bounding_boxes
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('mAP calculation')
+    parser.add_argument('--npu_txt_path', default="detection-results",
+                        help='the path of the predict result')
+    parser.add_argument("--json_output_file", default="coco_detection_result")
+    args = parser.parse_args()
+
+    res_bbox = get_predict_list(args.npu_txt_path, CLASSES)
+    mmcv.dump(res_bbox, args.json_output_file + '.json')
\ No newline at end of file
-- 
Gitee

From c8c17d1964d80ff6da86aabe6ad616f6e935481e Mon Sep 17 00:00:00 2001
From: hekuikui
Date: Mon, 11 Dec 2023 20:43:01 +0800
Subject: [PATCH 2/5] fix bug

---
 .../cv/detection/FasterRcnn/sample.py         | 24 +++++++++----------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py
index 0f490cc893..fe98e3a403 100644
--- a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py
@@ -24,14 +24,14 @@ import torch_aie
 
 def generate_options():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--traced-model-path", type=str, help="path to traced-model")
-    parser.add_argument("--bin-path", type=str, help="path to bin files")
-    parser.add_argument("--batch-size", type=int, default=1, help="set batch size, default is 1")
-    parser.add_argument("--image-size", type=int, default=1216, help="set image size, default is 1216")
-    parser.add_argument("--ts-save-path", type=str, default="'./faster_rcnn_r50_fpn_aie.pt'", help="path to save results")
-    parser.add_argument("--results-save-path", type=str, default="./results", help="path to save results")
-    parser.add_argument("--device-id", type=int, default=0, help="set device, default is 0")
-    return parser.parse_args
+    parser.add_argument("--traced_model_path", type=str, default="./faster_rcnn_r50_fpn_trace_20231130.pt", help="path to traced-model")
+    parser.add_argument("--bin_path", type=str, default="./coco_val_bin", help="path to bin files")
+    parser.add_argument("--batch_size", type=int, default=1, help="set batch size, default is 1")
+    parser.add_argument("--image_size", type=int, default=1216, help="set image size, default is 1216")
+    parser.add_argument("--ts_save_path", type=str, default="./faster_rcnn_r50_fpn_aie.pt", help="path to save results")
+    parser.add_argument("--results_save_path", type=str, default="./results", help="path to save results")
+    parser.add_argument("--device_id", type=int, default=0, help="set device, default is 0")
+    return parser.parse_args()
 
 def compile(model_path, save_path, batch_size, image_size):
     model = torch.jit.load(model_path)
@@ -53,7 +53,7 @@ def inference(val_bin_path, save_path, model, device_id, image_size):
     torch_aie.set_device(device_id)
     device = f'npu:{device_id}'
     stream = torch_aie.npu.Stream(device)
-    
+    model.eval()
     file_list = sorted(os.listdir(val_bin_path))
     pbar = tqdm(file_list)
     for file_name in pbar:
@@ -79,15 +79,13 @@ def inference(val_bin_path, save_path, model, device_id, image_size):
 
 def performance_test(model, device_id, batch_size, image_size):
     torch_aie.set_device(device_id)
 
-    model = torch.jit.load(model)
-    model.eval()
-    print('Model loaded successfully.')
-
     random_input = torch.rand(batch_size, 3, image_size, image_size)
     device = f'npu:{device_id}'
     random_input = random_input.to(device)
     stream = torch_aie.npu.Stream(device)
+    model.eval()
+
     # warm up
     num_warmup = 50
-- 
Gitee

From 87f458ef215e8ec596434615e081578d6c28c954 Mon Sep 17 00:00:00 2001
From: hekuikui
Date: Thu, 14 Dec 2023 11:52:16 +0800
Subject: [PATCH 3/5] add files

---
 .../cv/detection/FasterRcnn/README.md         | 263 ++++++++++++++++++
 .../mmdetection_coco_postprocess.py           |   2 +-
 .../cv/detection/FasterRcnn/requirements.txt  |  15 +
 .../built-in/cv/detection/FasterRcnn/trace.py |  94 +++++++
 4 files changed, 373 insertions(+), 1 deletion(-)
 create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/README.md
 create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/requirements.txt
 create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/trace.py

diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/README.md b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/README.md
new file mode 100644
index 0000000000..812e8b4972
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/README.md
@@ -0,0 +1,263 @@
+# Faster R-CNN ResNet50 Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Getting the Source Code](#section4622531142816)
+  - [Preparing the Dataset](#section183221994411)
+  - [Model Inference](#section741711594517)
+
+- [Inference Performance](#ZH-CN_TOPIC_0000001172201573)
+
+- [Environment Matrix](#ZH-CN_TOPIC_0000001126121892)
+
+  ******
+
+
+
+# Overview
+
+Faster R-CNN, proposed in 2016, integrates feature extraction, proposal generation, bounding box regression (rect refine) and classification into a single network, which substantially improves overall performance, most notably detection speed.
+
+- Reference implementation:
+
+  ```
+  url=https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn
+  branch=master
+  commit_id=a21eb25535f31634cef332b09fc27d28956fb24b
+  model_name=faster_rcnn_r50_fpn
+  ```
+
+
+
+
+## Input and Output Data
+
+- Input
+
+  | Input    | Data Type | Shape                       | Layout |
+  | :------: | :-------: | :-------------------------: | :----: |
+  | input    | RGB_FP32  | batchsize x 3 x 1216 x 1216 | NCHW   |
+
+
+- Output
+
+  | Output   | Shape | Data Type | Layout |
+  | :------: | :---: | :-------: | :----: |
+  | boxes    | 100x5 | FLOAT32   | ND     |
+  | labels   | 100   | INT32     | ND     |
+
+
+# Inference Environment Setup
+
+- The model requires the following software stack.
+
+  **Table 1** Version matrix
+
+| Component             | Version     |
+|-----------------------|-------------|
+| CANN                  | 7.0.RC1.3   |
+| Python                | 3.9         |
+| PyTorch               | 2.0.1       |
+| torchvision           | 0.15.2      |
+| Ascend-cann-torch-aie | >= 7.0.0    |
+| Ascend-cann-aie       | >= 7.0.0    |
+| NPU                   | Ascend310P3 |
+
+# Quick Start
+
+## Getting the Source Code
+
+1. Clone this repository
+   ```bash
+   git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+   cd ./ModelZoo-PyTorch/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn
+   ```
+
+   File description
+   ```
+   Faster_R-CNN_ResNet50
+   ├── README.md                        # this document
+   ├── coco_eval.py                     # script for verifying inference accuracy
+   ├── get_info.py                      # generates the info file for the image dataset
+   ├── sample.py                        # model inference script
+   ├── mmdetection_coco_postprocess.py  # post-processing script for inference results
+   ├── mmdetection_coco_preprocess.py   # dataset pre-processing script
+   ├── trace.py                         # model tracing script
+   └── txt_to_json.py                   # converts inference result txt files to the standard COCO json format for accuracy evaluation
+   ```
+
+2. Install dependencies
+   ```bash
+   pip3 install -r requirements.txt
+
+   # install mmpycocotools
+   pip3 install mmpycocotools==12.0.3
+
+   # build mmcv-full from source
+   git clone https://github.com/open-mmlab/mmcv.git
+   cd mmcv
+   git reset --hard 643009e4458109cb88ba5e669eec61a5e54c83be
+   pip3 install -r requirements.txt
+   MMCV_WITH_OPS=1 pip3 install -v -e .
+   cd ..
+   ```
+
+
+
+3. Get the model source code and install its dependencies
+
+   ```bash
+   git clone https://github.com/open-mmlab/mmdetection.git
+   cd mmdetection
+   git reset --hard a21eb25535f31634cef332b09fc27d28956fb24b
+   pip3 install -v -e .
+   cd ..
+   ```
+
+
+## Preparing the Dataset
+
+1. Download the dataset
+
+   The model is evaluated on the 5,000-image coco2017 validation set from the [COCO website](https://cocodataset.org/#download). Images are stored in ```val2017/``` and annotations in ```annotations/instances_val2017.json```.
+
+   ```bash
+   wget http://images.cocodataset.org/zips/val2017.zip --no-check-certificate
+   unzip -qo val2017.zip
+
+   wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip --no-check-certificate
+   unzip -qo annotations_trainval2017.zip
+   ```
+
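+   As an optional sanity check after downloading (a minimal sketch, assuming the
+   paths above), the annotation file can be verified with pycocotools:
+
+   ```python
+   from pycocotools.coco import COCO
+
+   # load the downloaded annotation file
+   coco = COCO('annotations/instances_val2017.json')
+   # the coco2017 validation set contains 5000 images
+   print(len(coco.getImgIds()))
+   ```
+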
+2. Data preprocessing
+   Convert the raw dataset into the model's input format.
+
+   Convert the raw images (.jpeg) into binary files (.bin). The conversion follows the mmdetection preprocessing pipeline to get the best accuracy: taking coco2017 as an example, each image is rescaled, normalized with the channel mean and standard deviation, and written out as a binary file.
+
+   Run the mmdetection_coco_preprocess.py script to perform the preprocessing.
+
+   ```bash
+   python3 mmdetection_coco_preprocess.py --image_folder_path val2017/ --bin_folder_path val2017_bin
+   ```
+
+   Parameter description:
+   - --image_folder_path: directory of the image dataset.
+   - --bin_folder_path: output directory for the binary files.
+
+
+3. Generate the info file for the jpg images
+
+
+   Post-processing needs an info file describing the dataset's .jpg images. The get_info.py script takes the downloaded image files as input and generates the info file for the image dataset.
+
+   Run the get_info.py script.
+
+   ```bash
+   python3 get_info.py jpg ./val2017/ coco2017_jpg.info
+   ```
+   Parameter description:
+
+   - The first argument is the format of the dataset files.
+   - The second argument is the **relative path** of the COCO image files.
+   - The third argument is the path where the generated info file is saved.
+
+
+   On success, ```coco2017_jpg.info``` is generated in the current directory.
+
+## Model Inference
+
+1. Load the model.
+
+
+   a. Download the weights.
+
+   ```
+   wget http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --no-check-certificate
+   ```
+   b. Load and trace the model
+   Run the trace.py script to trace the model.
+
+   ```
+   python3 trace.py --config ./mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --checkpoint ./faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --input-img ./coco/val2017/000000515350.jpg --shape 1216 --output-file ./faster_rcnn_torch_aie.pt --mmdet-ops-path ./mmdet_ops/build/libmmdet_ops.so
+   ```
+
+   Parameter description:
+
+   - --config: model config file.
+   - --checkpoint: model weights file.
+   - --input-img: a jpg image from the COCO dataset.
+   - --shape: input shape of the model.
+   - --output-file: path to save the traced .pt model.
+
+   This produces the ```faster_rcnn_torch_aie.pt``` file.
+
+2. Run inference and verification.
+
+   a. Run inference.
+
+   ```
+   python3 sample.py --traced_model_path ./faster_rcnn_torch_aie.pt --bin_path path/to/coco/val_2017_bin --ts_save_path compiled/model/save/path --results_save_path detection/results/save/path
+   ```
+   - Parameter description:
+      - --traced_model_path: path to the traced .pt file.
+      - --bin_path: path to the bin files of the COCO validation set.
+      - --ts_save_path: where to save the compiled TorchScript model.
+      - --results_save_path: directory for the detection result bin files.
+
+   By default, inference outputs are written to the results directory under the current path.
+
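+   For reference, the per-image outputs written by sample.py can be inspected
+   with a short snippet (a minimal sketch; the image id is only an example, and
+   the dtypes are the ones mmdetection_coco_postprocess.py expects):
+
+   ```python
+   import numpy as np
+
+   # <name>_0.bin: 100 detections as [x1, y1, x2, y2, score] (float32)
+   boxes = np.fromfile('results/000000515350_0.bin', dtype=np.float32).reshape(100, 5)
+   # <name>_1.bin: one class index per detection
+   labels = np.fromfile('results/000000515350_1.bin', dtype=np.int32)
+   print(boxes[0], labels[0])
+   ```
+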
+   b. Verify accuracy.
+
+   A post-processing script is provided to convert the binary outputs into txt files; run it as follows.
+
+   ```
+   python3 mmdetection_coco_postprocess.py --bin_data_path=result/${infer_result_dir} --prob_thres=0.05 --det_results_path=detection-results --test_annotation=coco2017_jpg.info
+   ```
+
+   - Parameter description:
+
+      - bin_data_path: inference output directory (replace it with the actual directory, e.g. ```2022_12_16-18_01_01/```).
+
+      - prob_thres: confidence threshold for the boxes.
+
+      - det_results: post-processing output directory.
+
+   Computing the mAP of the results requires the official pycocotools; first convert the post-processed txt files into the standard COCO json format for accuracy evaluation.
+
+   Run the conversion script.
+
+   ```
+   python3 txt_to_json.py --npu_txt_path detection-results --json_output_file coco_detection_result
+   ```
+   - Parameter description:
+
+      - --npu_txt_path: directory of the input txt files.
+
+      - --json_output_file: path of the output json file.
+
+
+   On success, the ```coco_detection_result.json``` file is generated.
+   Run the coco_eval.py script to produce a detailed evaluation report of the inference results.
+
+   ```
+   python3 coco_eval.py --detection_result coco_detection_result.json --ground_truth=annotations/instances_val2017.json
+   ```
+   - Parameter description:
+      - --detection_result: json file of the inference results.
+
+      - --ground_truth: path to ```instances_val2017.json```.
+
+
+
+# Inference Performance & Accuracy
+
+Performance measured with Torch_AIE inference is listed below for reference.
+
+| Device     | Batch Size | Dataset  | Accuracy (mAP) | Throughput (FPS) |
+| :--------: | :--------: | :------: | :------------: | :--------------: |
+| Ascend310P | 1          | coco2017 | 37.2           | 13.11            |
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py
index 054aad4bfd..5a6d3f9c47 100644
--- a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py
@@ -108,7 +108,7 @@ if __name__ == '__main__':
                 buf = np.fromfile(path_base + "_" + str(num) + ".bin", dtype="float32")
                 buf = np.reshape(buf, [100, 5])
             elif num == 1:
-                buf = np.fromfile(path_base + "_" + str(num) + ".bin", dtype="int64")
+                buf = np.fromfile(path_base + "_" + str(num) + ".bin", dtype="int32")
                 buf = np.reshape(buf, [100, 1])
             res_buff.append(buf)
         else:
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/requirements.txt b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/requirements.txt
new file mode 100644
index 0000000000..4ad80b6ccd
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/requirements.txt
@@ -0,0 +1,15 @@
+numpy
+decorator
+attrs
+psutil
+tqdm
+onnx==1.7.0
+Pillow==9.2.0
+opencv-python==4.6.0.66
+torch==1.8.1
+torchvision==0.9.1
+protobuf==3.20.0
+onnxruntime==1.12.1
+onnxoptimizer==0.2.7
+terminaltables==3.1.10
+
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/trace.py b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/trace.py
new file mode 100644
index 0000000000..8061599b3c
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/trace.py
@@ -0,0 +1,94 @@
+import argparse
+import os.path as osp
+
+import numpy as np
+import onnx
+import onnxruntime as rt
+import torch
+import sys,os
+
+
+from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model,
+                        preprocess_example_input)
+
+def trace(
+    config_path,
+    checkpoint_path,
+    input_img,
+    input_shape,
+    output_file='tmp.onnx',
+    normalize_cfg=None):
+
+    input_config = {
+        'input_shape': input_shape,
+        'input_path': input_img,
+        'normalize_cfg': normalize_cfg
+    }
+
+    # prepare original model and meta for verifying the onnx model
+    orig_model = build_model_from_cfg(config_path, checkpoint_path)
+    one_img, one_meta = preprocess_example_input(input_config)
+    model, tensor_data = generate_inputs_and_wrap_model(
+        config_path, checkpoint_path, input_config)
+
+    model.eval()
+    
torch.jit.trace(model, tensor_data).save('faster_rcnn_r50_fpn_trace_20231130.pt') + +def parse_args(): + parser = argparse.ArgumentParser( + description='Trace MMDetection models') + parser.add_argument('--config', help='test config file path') + parser.add_argument('--checkpoint', help='checkpoint file') + parser.add_argument('--input-img', type=str, help='Images for input') + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--mmdet-ops-path', type=str, default="") + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[800, 1216], + help='input image size') + parser.add_argument( + '--mean', + type=float, + nargs='+', + default=[123.675, 116.28, 103.53], + help='mean value used for preprocess input data') + parser.add_argument( + '--std', + type=float, + nargs='+', + default=[58.395, 57.12, 57.375], + help='variance value used for preprocess input data') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + + torch.ops.load_library(args.mmdet_ops_path) + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (1, 3) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + assert len(args.mean) == 3 + assert len(args.std) == 3 + + normalize_cfg = {'mean': args.mean, 'std': args.std} + + # convert model to onnx file + trace( + args.config, + args.checkpoint, + args.input_img, + input_shape, + args.output_file, + normalize_cfg=normalize_cfg,) + + +# python3 trace.py --config ./mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --checkpoint ./faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --input-img /data/datasets/coco/val2017/000000515350.jpg --shape 1216 --output-file "faster_rcnn_torch_aie_1214.pt" --mmdet-ops-path ./mmdet_ops/build/libmmdet_ops.so -- Gitee From d13aff35e8044df5bc823e05dfe5a2ace7a17e61 Mon Sep 17 00:00:00 2001 From: hekuikui Date: Fri, 15 Dec 2023 11:44:41 +0800 Subject: [PATCH 4/5] add trace.py and update readme --- .../README.md | 52 +- .../coco_eval.py | 0 .../get_info.py | 0 .../mmdet_ops/CMakeLists.txt | 13 + .../Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh | 6 + .../mmdet_ops/mmdet_ops.cpp | 66 +++ .../Faster_rcnn_r50_fpn_1x/mmdetection.patch | 519 ++++++++++++++++++ .../mmdetection_coco_postprocess.py | 0 .../mmdetection_coco_preprocess.py | 0 .../requirements.txt | 5 +- .../sample.py | 0 .../trace.py | 12 +- .../txt_to_json.py | 0 13 files changed, 641 insertions(+), 32 deletions(-) rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/README.md (93%) rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/coco_eval.py (100%) rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/get_info.py (100%) create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/CMakeLists.txt create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/mmdet_ops.cpp create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection.patch rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/mmdetection_coco_postprocess.py (100%) rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/mmdetection_coco_preprocess.py (100%) rename 
AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/requirements.txt (76%)
 rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/sample.py (100%)
 rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/trace.py (76%)
 rename AscendIE/TorchAIE/built-in/cv/detection/{FasterRcnn => Faster_rcnn_r50_fpn_1x}/txt_to_json.py (100%)

diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/README.md b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md
similarity index 93%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/README.md
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md
index 812e8b4972..7d08f52e8e 100644
--- a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/README.md
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md
@@ -75,12 +75,14 @@
 1. Clone this repository
    ```bash
    git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
-   cd ./ModelZoo-PyTorch/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn
+   cd ./ModelZoo-PyTorch/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x
    ```
 
+
+
    File description
    ```
-   Faster_R-CNN_ResNet50
+   FasterRcnn
    ├── README.md                        # this document
    ├── coco_eval.py                     # script for verifying inference accuracy
    ├── get_info.py                      # generates the info file for the image dataset
@@ -91,23 +93,17 @@
    └── txt_to_json.py                   # converts inference result txt files to the standard COCO json format for accuracy evaluation
    ```
 
+
+
 2. Install dependencies
    ```bash
    pip3 install -r requirements.txt
 
    # install mmpycocotools
    pip3 install mmpycocotools==12.0.3
-
-   # build mmcv-full from source
-   git clone https://github.com/open-mmlab/mmcv.git
-   cd mmcv
-   git reset --hard 643009e4458109cb88ba5e669eec61a5e54c83be
-   pip3 install -r requirements.txt
-   MMCV_WITH_OPS=1 pip3 install -v -e .
-   cd ..
    ```
-
+
 
 3. Get the model source code and install its dependencies
@@ -116,7 +112,16 @@
    cd mmdetection
    git reset --hard a21eb25535f31634cef332b09fc27d28956fb24b
    pip3 install -v -e .
-   cd ..
+   ```
+
+
+
+4. Patch the mmdetection source
+
+   Before tracing, the mmdetection (v2.8.0) source needs some modifications to adapt it to Ascend NPUs.
+
+   ```bash
+   patch -p1 < mmdetection.patch
    ```
 
 
@@ -175,26 +180,33 @@
 
    a. Download the weights.
 
+
    ```
    wget http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --no-check-certificate
    ```
+
+
    b. Load and trace the model
    Run the trace.py script to trace the model.
 
+
    ```
    python3 trace.py --config ./mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --checkpoint ./faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --input-img ./coco/val2017/000000515350.jpg --shape 1216 --output-file ./faster_rcnn_torch_aie.pt --mmdet-ops-path ./mmdet_ops/build/libmmdet_ops.so
    ```
-   Parameter description:
-
-   - --config: model config file.
-   - --checkpoint: model weights file.
-   - --input-img: a jpg image from the COCO dataset.
-   - --shape: input shape of the model.
-   - --output-file: path to save the traced .pt model.
-
-   This produces the ```faster_rcnn_torch_aie.pt``` file.
+
+   Parameter description:
+
+   - --config: model config file.
+   - --checkpoint: model weights file.
+   - --input-img: a jpg image from the COCO dataset.
+   - --shape: input shape of the model.
+   - --output-file: path to save the traced .pt model.
+
+   This produces the ```faster_rcnn_torch_aie.pt``` file.
+
+
 2. Run inference and verification.

diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/coco_eval.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/coco_eval.py
similarity index 100%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/coco_eval.py
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/coco_eval.py
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/get_info.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/get_info.py
similarity index 100%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/get_info.py
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/get_info.py
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/CMakeLists.txt b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/CMakeLists.txt
new file mode 100644
index 0000000000..44192fbdab
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/CMakeLists.txt
@@ -0,0 +1,13 @@
+cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
+project(mmdet_ops)
+# adjust CMAKE_PREFIX_PATH to the torch package of the Python environment in use
+set(CMAKE_PREFIX_PATH "/home/hekuikui/env/conda/envs/t181_v091/lib/python3.9/site-packages/torch")
+find_package(Torch REQUIRED)
+
+add_library(mmdet_ops SHARED mmdet_ops.cpp)
+
+target_compile_features(mmdet_ops PRIVATE cxx_std_14)
+
+target_link_libraries(mmdet_ops PUBLIC
+        c10
+        torch
+        torch_cpu)
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh
new file mode 100644
index 0000000000..4a5dbcf607
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/build.sh
@@ -0,0 +1,6 @@
+rm -r build
+mkdir build
+cd build
+
+cmake ..
+make -j 32
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/mmdet_ops.cpp b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/mmdet_ops.cpp
new file mode 100644
index 0000000000..81ba9fae86
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdet_ops/mmdet_ops.cpp
@@ -0,0 +1,66 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <tuple>
+#include <vector>
+
+#include <torch/script.h>
+
+// Shape-only stub: returns dummy tensors of the shapes and dtypes the op
+// promises, so that torch.jit.trace can record calls to torch.ops.aie.batch_nms.
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> batch_nms(
+    at::Tensor bbox,
+    at::Tensor scores,
+    double score_threshold,
+    double iou_threshold,
+    int64_t max_size_per_class,
+    int64_t max_total_size
+)
+{
+    auto boxBatch = bbox.sizes()[0];
+    auto boxFeat = bbox.sizes()[3];
+    auto scoreBatch = scores.sizes()[0];
+
+    auto outBox = torch::ones({boxBatch, max_total_size, boxFeat}).to(torch::kFloat16);
+    auto outScore = torch::ones({scoreBatch, max_total_size}).to(torch::kFloat16);
+    auto outClass = torch::ones({max_total_size, }).to(torch::kInt64);
+    auto outNum = torch::ones({1, }).to(torch::kFloat32);
+
+    return std::make_tuple(outBox, outScore, outClass, outNum);
+}
+
+// Shape-only stub for the RoI extractor, used in the same way as batch_nms above.
+at::Tensor roi_extractor(
+    std::vector<at::Tensor> feats,
+    at::Tensor rois,
+    bool aligned,
+    int64_t finest_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    c10::string_view pool_mode,
+    int64_t roi_scale_factor,
+    int64_t sample_num,
+    std::vector<double> spatial_scale
+)
+{
+    auto k = rois.sizes()[0];
+    auto c = feats[0].sizes()[1];
+    auto roi_feats = torch::ones({k, c, pooled_height, pooled_width}).to(torch::kFloat32);
+
+    return roi_feats;
+}
+
+TORCH_LIBRARY(aie, m) {
+    m.def("batch_nms", batch_nms);
+    m.def("roi_extractor", roi_extractor);
+}
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection.patch b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection.patch
new file mode 100644
index 0000000000..04bb741368
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection.patch
@@ -0,0 +1,519 @@
+From a098c07a11670cac3352493c5146540d13d10d19 Mon Sep 17 00:00:00 2001
+From: hekuikui
+Date: Fri, 15 Dec 2023 11:41:37 +0800
+Subject: [PATCH 1215/1215] mmdetection.patch
+
+---
+ .../core/bbox/coder/delta_xywh_bbox_coder.py  | 32 +++++--
+ mmdet/core/post_processing/bbox_nms.py        | 92 ++++++++++++++++++-
+ mmdet/models/dense_heads/rpn_head.py          | 81 +++++++++++++++-
+ mmdet/models/detectors/base.py                |  5 +
+ mmdet/models/detectors/single_stage.py        |  5 +-
+ mmdet/models/roi_heads/cascade_roi_head.py    |  5 +-
+ .../roi_heads/mask_heads/fcn_mask_head.py     | 14 ++-
+ .../single_level_roi_extractor.py             | 48 ++++++++--
+ mmdet/models/roi_heads/standard_roi_head.py   | 14 +--
+ mmdet/models/roi_heads/test_mixins.py         | 25 +++--
+ 10 files changed, 270 insertions(+), 51 deletions(-)
+
+diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+index e9eb3579..62562e75 100644
+--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
++++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+@@ -168,8 +168,13 @@ def delta2bbox(rois,
+                [0.0000, 0.3161, 4.1945, 0.6839],
+                [5.0000, 5.0000, 5.0000, 5.0000]])
+     """
+-    means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
+-    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
++    # fix shape for means and stds for onnx
++    if torch.onnx.is_in_onnx_export():
++        means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1).numpy() // 4)
++        stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1).numpy() // 4)
++    else:
++        means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
++        stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
+     denorm_deltas = deltas * stds + means
+     dx = denorm_deltas[:, 0::4]
+     dy = denorm_deltas[:, 1::4]
+@@ -178,12 +183,23 @@ def delta2bbox(rois,
+     max_ratio = np.abs(np.log(wh_ratio_clip))
+     dw = dw.clamp(min=-max_ratio,
max=max_ratio) + dh = dh.clamp(min=-max_ratio, max=max_ratio) +- # Compute center of each roi +- px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) +- py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) +- # Compute width/height of each roi +- pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw) +- ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh) ++ # improve gather performance on NPU ++ if torch.onnx.is_in_onnx_export(): ++ rois_perf = rois.permute(1, 0) ++ # Compute center of each roi ++ px = ((rois_perf[0, :] + rois_perf[2, :]) * 0.5).unsqueeze(1).expand_as(dx) ++ py = ((rois_perf[1, :] + rois_perf[3, :]) * 0.5).unsqueeze(1).expand_as(dy) ++ # Compute width/height of each roi ++ pw = (rois_perf[2, :] - rois_perf[0, :]).unsqueeze(1).expand_as(dw) ++ ph = (rois_perf[3, :] - rois_perf[1, :]).unsqueeze(1).expand_as(dh) ++ else: ++ rois_perf = rois.permute(1, 0) ++ # Compute center of each roi ++ px = ((rois_perf[0, :] + rois_perf[2, :]) * 0.5).unsqueeze(1).expand_as(dx) ++ py = ((rois_perf[1, :] + rois_perf[3, :]) * 0.5).unsqueeze(1).expand_as(dy) ++ # Compute width/height of each roi ++ pw = (rois_perf[2, :] - rois_perf[0, :]).unsqueeze(1).expand_as(dw) ++ ph = (rois_perf[3, :] - rois_perf[1, :]).unsqueeze(1).expand_as(dh) + # Use exp(network energy) to enlarge/shrink each roi + gw = pw * dw.exp() + gh = ph * dh.exp() +diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py +index 463fe2e4..1363e9c9 100644 +--- a/mmdet/core/post_processing/bbox_nms.py ++++ b/mmdet/core/post_processing/bbox_nms.py +@@ -4,6 +4,68 @@ from mmcv.ops.nms import batched_nms + from mmdet.core.bbox.iou_calculators import bbox_overlaps + + ++class BatchNMSOp(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size): ++ """ ++ boxes (torch.Tensor): boxes in shape (batch, N, C, 4). ++ scores (torch.Tensor): scores in shape (batch, N, C). ++ return: ++ nmsed_boxes: (1, N, 4) ++ nmsed_scores: (1, N) ++ nmsed_classes: (1, N) ++ nmsed_num: (1,) ++ """ ++ ++ # Phony implementation for onnx export ++ nmsed_boxes = bboxes[:, :max_total_size, 0, :] ++ nmsed_scores = scores[:, :max_total_size, 0] ++ nmsed_classes = torch.arange(max_total_size, dtype=torch.long) ++ nmsed_num = torch.Tensor([max_total_size]) ++ ++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num ++ ++ @staticmethod ++ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size): ++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS', ++ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr, ++ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4) ++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num ++ ++def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size): ++ """ ++ boxes (torch.Tensor): boxes in shape (N, 4). ++ scores (torch.Tensor): scores in shape (N, ). 
++ """ ++ if torch.onnx.is_in_onnx_export(): ++ if bboxes.dtype == torch.float32: ++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half() ++ scores = scores.reshape(1, scores.shape[0].numpy(), -1).half() ++ else: ++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4) ++ scores = scores.reshape(1, scores.shape[0].numpy(), -1) ++ else: ++ if bboxes.dtype == torch.float32: ++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4).half() ++ scores = scores.reshape(1, scores.shape[0], -1).half() ++ else: ++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4) ++ scores = scores.reshape(1, scores.shape[0], -1) ++ ++ batch_nms = torch.ops.aie.batch_nms ++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = batch_nms(bboxes, scores, ++ score_threshold, iou_threshold, ++ max_size_per_class, max_total_size) ++ # nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores, ++ # score_threshold, iou_threshold, max_size_per_class, max_total_size) ++ nmsed_boxes = nmsed_boxes.float() ++ nmsed_scores = nmsed_scores.float() ++ nmsed_classes = nmsed_classes.long() ++ dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1) ++ labels = nmsed_classes.reshape((max_total_size, )) ++ return dets, labels ++ ++ + def multiclass_nms(multi_bboxes, + multi_scores, + score_thr, +@@ -36,13 +98,30 @@ def multiclass_nms(multi_bboxes, + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: +- bboxes = multi_bboxes[:, None].expand( +- multi_scores.size(0), num_classes, 4) ++ # export expand operator to onnx more nicely ++ if torch.onnx.is_in_onnx_export: ++ bbox_shape_tensor = torch.ones(multi_scores.size(0), num_classes, 4) ++ bboxes = multi_bboxes[:, None].expand_as(bbox_shape_tensor) ++ else: ++ # bboxes = multi_bboxes[:, None].expand( ++ # multi_scores.size(0), num_classes, 4) ++ bbox_shape_tensor = torch.ones(multi_scores.size(0), num_classes, 4) ++ bboxes = multi_bboxes[:, None].expand_as(bbox_shape_tensor) ++ + + scores = multi_scores[:, :-1] + if score_factors is not None: + scores = scores * score_factors[:, None] + ++ # npu ++ if torch.onnx.is_in_onnx_export(): ++ dets, labels = batch_nms_op(bboxes, scores, score_thr, nms_cfg.get("iou_threshold"), max_num, max_num) ++ return dets, labels ++ else: ++ dets, labels = batch_nms_op(bboxes, scores, score_thr, nms_cfg.get("iou_threshold"), max_num, max_num) ++ return dets, labels ++ ++ # cpu and gpu + labels = torch.arange(num_classes, dtype=torch.long) + labels = labels.view(1, -1).expand_as(scores) + +@@ -53,11 +132,13 @@ def multiclass_nms(multi_bboxes, + # remove low scoring boxes + valid_mask = scores > score_thr + inds = valid_mask.nonzero(as_tuple=False).squeeze(1) ++ # vals, inds = torch.topk(scores, 1000) ++ + bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] + if inds.numel() == 0: +- if torch.onnx.is_in_onnx_export(): +- raise RuntimeError('[ONNX Error] Can not record NMS ' +- 'as it has not been executed this time') ++ # if torch.onnx.is_in_onnx_export(): ++ raise RuntimeError('[ONNX Error] Can not record NMS ' ++ 'as it has not been executed this time') + if return_inds: + return bboxes, labels, inds + else: +@@ -76,6 +157,7 @@ def multiclass_nms(multi_bboxes, + return dets, labels[keep] + + ++ + def fast_nms(multi_bboxes, + multi_scores, + multi_coeffs, +diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py +index f565d1a4..d95765c4 100644 +--- 
a/mmdet/models/dense_heads/rpn_head.py ++++ b/mmdet/models/dense_heads/rpn_head.py +@@ -9,6 +9,67 @@ from .anchor_head import AnchorHead + from .rpn_test_mixin import RPNTestMixin + + ++class BatchNMSOp(torch.autograd.Function): ++ @staticmethod ++ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size): ++ """ ++ boxes (torch.Tensor): boxes in shape (batch, N, C, 4). ++ scores (torch.Tensor): scores in shape (batch, N, C). ++ return: ++ nmsed_boxes: (1, N, 4) ++ nmsed_scores: (1, N) ++ nmsed_classes: (1, N) ++ nmsed_num: (1,) ++ """ ++ ++ # Phony implementation for onnx export ++ nmsed_boxes = bboxes[:, :max_total_size, 0, :] ++ nmsed_scores = scores[:, :max_total_size, 0] ++ nmsed_classes = torch.arange(max_total_size, dtype=torch.long) ++ nmsed_num = torch.Tensor([max_total_size]) ++ ++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num ++ ++ @staticmethod ++ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size): ++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS', ++ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr, ++ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4) ++ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num ++ ++def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size): ++ """ ++ boxes (torch.Tensor): boxes in shape (N, 4). ++ scores (torch.Tensor): scores in shape (N, ). ++ """ ++ if torch.onnx.is_in_onnx_export(): ++ if bboxes.dtype == torch.float32: ++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half() ++ scores = scores.reshape(1, scores.shape[0].numpy(), -1).half() ++ else: ++ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4) ++ scores = scores.reshape(1, scores.shape[0].numpy(), -1) ++ else: ++ if bboxes.dtype == torch.float32: ++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4).half() ++ scores = scores.reshape(1, scores.shape[0], -1).half() ++ else: ++ bboxes = bboxes.reshape(1, bboxes.shape[0], -1, 4) ++ scores = scores.reshape(1, scores.shape[0], -1) ++ ++ batch_nms = torch.ops.aie.batch_nms ++ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = batch_nms(bboxes, scores, ++ score_threshold, iou_threshold, ++ max_size_per_class, max_total_size) ++ # nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores, ++ # score_threshold, iou_threshold, max_size_per_class, max_total_size) # max_total_size num_bbox ++ nmsed_boxes = nmsed_boxes.float() ++ nmsed_scores = nmsed_scores.float() ++ nmsed_classes = nmsed_classes.long() ++ dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1) ++ labels = nmsed_classes.reshape((max_total_size, )) ++ return dets, labels ++ + @HEADS.register_module() + class RPNHead(RPNTestMixin, AnchorHead): + """RPN head. 
+@@ -132,9 +193,12 @@ class RPNHead(RPNTestMixin, AnchorHead):
+         if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
+             # sort is faster than topk
+             # _, topk_inds = scores.topk(cfg.nms_pre)
+-            ranked_scores, rank_inds = scores.sort(descending=True)
+-            topk_inds = rank_inds[:cfg.nms_pre]
+-            scores = ranked_scores[:cfg.nms_pre]
++            # topk exports to ONNX more cleanly than sort, so both branches use it
++            if torch.onnx.is_in_onnx_export():
++                scores, topk_inds = torch.topk(scores, cfg.nms_pre)
++            else:
++                scores, topk_inds = torch.topk(scores, cfg.nms_pre)
++
+             rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+             anchors = anchors[topk_inds, :]
+             mlvl_scores.append(scores)
+@@ -164,5 +228,12 @@ class RPNHead(RPNTestMixin, AnchorHead):
+ 
+         # TODO: remove the hard coded nms type
+         nms_cfg = dict(type='nms', iou_threshold=cfg.nms_thr)
+-        dets, keep = batched_nms(proposals, scores, ids, nms_cfg)
+-        return dets[:cfg.nms_post]
++        # npu return: proposals are filtered by the AIE batch NMS op
++        if torch.onnx.is_in_onnx_export():
++            dets, labels = batch_nms_op(proposals, scores, 0.0, nms_cfg.get("iou_threshold"), cfg.nms_post, cfg.nms_post)
++            return dets
++        # cpu and gpu return: same op, keeping eager execution aligned with export
++        else:
++            dets, labels = batch_nms_op(proposals, scores, 0.0, nms_cfg.get("iou_threshold"), cfg.nms_post,
++                                        cfg.nms_post)
++            return dets
+diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
+index 7c6d5e96..7e053ae7 100644
+--- a/mmdet/models/detectors/base.py
++++ b/mmdet/models/detectors/base.py
+@@ -131,6 +131,11 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
+                 augs (multiscale, flip, etc.) and the inner list indicates
+                 images in a batch.
+         """
++        if not isinstance(imgs, list):
++            imgs = [imgs]
++        if not isinstance(img_metas, list):
++            img_metas = [img_metas]
++
+         for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+             if not isinstance(var, list):
+                 raise TypeError(f'{name} must be a list, but got {type(var)}')
+diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
+index 96c4acac..929bf2a0 100644
+--- a/mmdet/models/detectors/single_stage.py
++++ b/mmdet/models/detectors/single_stage.py
+@@ -114,8 +114,9 @@ class SingleStageDetector(BaseDetector):
+         bbox_list = self.bbox_head.get_bboxes(
+             *outs, img_metas, rescale=rescale)
+         # skip post-processing when exporting to ONNX
+-        if torch.onnx.is_in_onnx_export():
+-            return bbox_list
++        # if torch.onnx.is_in_onnx_export():
++        #     return bbox_list
++        return bbox_list
+ 
+         bbox_results = [
+             bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
+index 45b6f36a..6f8b9245 100644
+--- a/mmdet/models/roi_heads/cascade_roi_head.py
++++ b/mmdet/models/roi_heads/cascade_roi_head.py
+@@ -349,8 +349,9 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+             det_bboxes.append(det_bbox)
+             det_labels.append(det_label)
+ 
+-        if torch.onnx.is_in_onnx_export():
+-            return det_bboxes, det_labels
++        # if torch.onnx.is_in_onnx_export():
++        #     return det_bboxes, det_labels
++        return det_bboxes, det_labels
+         bbox_results = [
+             bbox2result(det_bboxes[i], det_labels[i],
+                         self.bbox_head[-1].num_classes)
+diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
+index 0cba3cda..38726b0c 100644
+--- a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
++++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
+@@ -204,6 +204,14 @@ class FCNMaskHead(nn.Module):
+                 if thr > 0:
+                     masks = masks >= thr
+                 return masks
++        
else: ++ from torchvision.models.detection.roi_heads \ ++ import paste_masks_in_image ++ masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2]) ++ thr = rcnn_test_cfg.get('mask_thr_binary', 0) ++ if thr > 0: ++ masks = masks >= thr ++ return masks + + N = len(mask_pred) + # The actual implementation split the input into chunks, +@@ -316,9 +324,9 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + +- if torch.onnx.is_in_onnx_export(): +- raise RuntimeError( +- 'Exporting F.grid_sample from Pytorch to ONNX is not supported.') ++ # if torch.onnx.is_in_onnx_export(): ++ raise RuntimeError( ++ 'Exporting F.grid_sample from Pytorch to ONNX is not supported.') + img_masks = F.grid_sample( + masks.to(dtype=torch.float32), grid, align_corners=False) + +diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py +index c0eebc4a..d3beb385 100644 +--- a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py ++++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py +@@ -4,6 +4,30 @@ from mmcv.runner import force_fp32 + from mmdet.models.builder import ROI_EXTRACTORS + from .base_roi_extractor import BaseRoIExtractor + ++import torch.onnx.symbolic_helper as sym_help ++ ++class RoiExtractor(torch.autograd.Function): ++ @staticmethod ++ def forward(self, f0, f1, f2, f3, rois, aligned=1, finest_scale=56, pooled_height=7, pooled_width=7, ++ pool_mode='avg', roi_scale_factor=0, sample_num=0, spatial_scale=[0.25, 0.125, 0.0625, 0.03125]): ++ """ ++ feats (torch.Tensor): feats in shape (batch, 256, H, W). ++ rois (torch.Tensor): rois in shape (k, 5). 
++        return:
++            roi_feats (torch.Tensor): (k, 256, pooled_height, pooled_width)
++        """
++
++        # phony implementation for shape inference
++        k = rois.size()[0]
++        roi_feats = torch.ones(k, 256, pooled_height, pooled_width)
++        return roi_feats
++
++    @staticmethod
++    def symbolic(g, f0, f1, f2, f3, rois):
++        # TODO: support tensor list type for feats
++        roi_feats = g.op('RoiExtractor', f0, f1, f2, f3, rois, aligned_i=1, finest_scale_i=56, pooled_height_i=7, pooled_width_i=7,
++            pool_mode_s='avg', roi_scale_factor_i=0, sample_num_i=0, spatial_scale_f=[0.25, 0.125, 0.0625, 0.03125], outputs=1)
++        return roi_feats
+ 
+ @ROI_EXTRACTORS.register_module()
+ class SingleRoIExtractor(BaseRoIExtractor):
+@@ -52,6 +76,18 @@ class SingleRoIExtractor(BaseRoIExtractor):
+ 
+     @force_fp32(apply_to=('feats', ), out_fp16=True)
+     def forward(self, feats, rois, roi_scale_factor=None):
++        # Workaround to export ONNX for the npu
++        if torch.onnx.is_in_onnx_export():
++            roi_feats = RoiExtractor.apply(feats[0], feats[1], feats[2], feats[3], rois)
++            # roi_feats = RoiExtractor.apply(list(feats), rois)
++            return roi_feats
++        else:
++            # roi_feats = RoiExtractor.apply(feats[0], feats[1], feats[2], feats[3], rois)
++            # roi_feats = RoiExtractor.apply(list(feats), rois)
++            roi_extractor = torch.ops.aie.roi_extractor
++            roi_feats = roi_extractor([feats[0], feats[1], feats[2], feats[3]], rois, 1, 56, 7, 7, 'avg', 0, 0, [0.25, 0.125, 0.0625, 0.03125])
++            return roi_feats
++
+         """Forward function."""
+         out_size = self.roi_layers[0].output_size
+         num_levels = len(feats)
+@@ -82,12 +118,12 @@ class SingleRoIExtractor(BaseRoIExtractor):
+             mask = target_lvls == i
+             inds = mask.nonzero(as_tuple=False).squeeze(1)
+             # TODO: make it nicer when exporting to onnx
+-            if torch.onnx.is_in_onnx_export():
+-                # To keep all roi_align nodes exported to onnx
+-                rois_ = rois[inds]
+-                roi_feats_t = self.roi_layers[i](feats[i], rois_)
+-                roi_feats[inds] = roi_feats_t
+-                continue
++            # if torch.onnx.is_in_onnx_export():
++            # To keep all roi_align nodes exported to onnx
++            rois_ = rois[inds]
++            roi_feats_t = self.roi_layers[i](feats[i], rois_)
++            roi_feats[inds] = roi_feats_t
++            continue
+             if inds.numel() > 0:
+                 rois_ = rois[inds]
+                 roi_feats_t = self.roi_layers[i](feats[i], rois_)
+diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py
+index c530f2a5..bacba384 100644
+--- a/mmdet/models/roi_heads/standard_roi_head.py
++++ b/mmdet/models/roi_heads/standard_roi_head.py
+@@ -246,13 +246,13 @@ class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+ 
+         det_bboxes, det_labels = self.simple_test_bboxes(
+             x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+-        if torch.onnx.is_in_onnx_export():
+-            if self.with_mask:
+-                segm_results = self.simple_test_mask(
+-                    x, img_metas, det_bboxes, det_labels, rescale=rescale)
+-                return det_bboxes, det_labels, segm_results
+-            else:
+-                return det_bboxes, det_labels
++        # if torch.onnx.is_in_onnx_export():
++        if self.with_mask:
++            segm_results = self.simple_test_mask(
++                x, img_metas, det_bboxes, det_labels, rescale=rescale)
++            return det_bboxes, det_labels, segm_results
++        else:
++            return det_bboxes, det_labels
+ 
+         bbox_results = [
+             bbox2result(det_bboxes[i], det_labels[i],
+diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
+index 0e675d6e..171a21c3 100644
+--- a/mmdet/models/roi_heads/test_mixins.py
++++ b/mmdet/models/roi_heads/test_mixins.py
+@@ -211,19 +211,18 @@ class MaskTestMixin(object):
+                     mask_result = self._mask_forward(x, mask_rois)
+                     mask_preds.append(mask_result['mask_pred'])
+         else:
+-            _bboxes = [
+-                det_bboxes[i][:, :4] *
+-                scale_factors[i] if rescale else det_bboxes[i][:, :4]
+-                for i in range(len(det_bboxes))
+-            ]
+-            mask_rois = bbox2roi(_bboxes)
+-            mask_results = self._mask_forward(x, mask_rois)
+-            mask_pred = mask_results['mask_pred']
+-            # split batch mask prediction back to each image
+-            num_mask_roi_per_img = [
+-                det_bbox.shape[0] for det_bbox in det_bboxes
+-            ]
+-            mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
++            # avoid mask_pred.split with a static number of predictions
++            mask_preds = []
++            _bboxes = []
++            for i, boxes in enumerate(det_bboxes):
++                boxes = boxes[:, :4]
++                if rescale:
++                    boxes *= scale_factors[i]
++                _bboxes.append(boxes)
++                img_inds = boxes[:, :1].clone() * 0 + i
++                mask_rois = torch.cat([img_inds, boxes], dim=-1)
++                mask_result = self._mask_forward(x, mask_rois)
++                mask_preds.append(mask_result['mask_pred'])
+ 
+         # apply mask post-processing to each image individually
+         segm_results = []
+-- 
+2.25.1
+
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_postprocess.py
similarity index 100%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_postprocess.py
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_postprocess.py
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_preprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_preprocess.py
similarity index 100%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/mmdetection_coco_preprocess.py
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_preprocess.py
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/requirements.txt b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/requirements.txt
similarity index 76%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/requirements.txt
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/requirements.txt
index 4ad80b6ccd..fa06f33281 100644
--- a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/requirements.txt
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/requirements.txt
@@ -6,9 +6,10 @@ tqdm
 onnx==1.7.0
 Pillow==9.2.0
 opencv-python==4.6.0.66
-torch==1.8.1
-torchvision==0.9.1
+torch==1.10.0
+torchvision==0.11.1
 protobuf==3.20.0
+mmcv-full==1.2.4
 onnxruntime==1.12.1
 onnxoptimizer==0.2.7
 terminaltables==3.1.10
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/sample.py
similarity index 100%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/sample.py
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/sample.py
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/trace.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/trace.py
similarity index 76%
rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/trace.py
rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/trace.py
index 8061599b3c..bb07fcc8ba 100644
--- a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/trace.py
+++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/trace.py
@@ -7,7 +7,6 @@
 import onnxruntime as rt
 import 
torch import sys,os - from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model, preprocess_example_input) @@ -16,7 +15,7 @@ def trace( checkpoint_path, input_img, input_shape, - output_file='tmp.onnx', + output_file='tmp.pt', normalize_cfg=None): input_config = { @@ -25,14 +24,11 @@ def trace( 'normalize_cfg': normalize_cfg } - # prepare original model and meta for verifying the onnx model - orig_model = build_model_from_cfg(config_path, checkpoint_path) - one_img, one_meta = preprocess_example_input(input_config) model, tensor_data = generate_inputs_and_wrap_model( config_path, checkpoint_path, input_config) model.eval() - torch.jit.trace(model, tensor_data).save('faster_rcnn_r50_fpn_trace_20231130.pt') + torch.jit.trace(model, tensor_data).save(output_file) def parse_args(): parser = argparse.ArgumentParser( @@ -81,7 +77,6 @@ if __name__ == '__main__': normalize_cfg = {'mean': args.mean, 'std': args.std} - # convert model to onnx file trace( args.config, args.checkpoint, @@ -89,6 +84,3 @@ if __name__ == '__main__': input_shape, args.output_file, normalize_cfg=normalize_cfg,) - - -# python3 trace.py --config ./mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py --checkpoint ./faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --input-img /data/datasets/coco/val2017/000000515350.jpg --shape 1216 --output-file "faster_rcnn_torch_aie_1214.pt" --mmdet-ops-path ./mmdet_ops/build/libmmdet_ops.so diff --git a/AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/txt_to_json.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/txt_to_json.py similarity index 100% rename from AscendIE/TorchAIE/built-in/cv/detection/FasterRcnn/txt_to_json.py rename to AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/txt_to_json.py -- Gitee From b4e7fab877c5b5e082d4e021b70b97519ed07521 Mon Sep 17 00:00:00 2001 From: hekuikui Date: Mon, 18 Dec 2023 16:30:37 +0800 Subject: [PATCH 5/5] rm redundant files --- .../Faster_rcnn_r50_fpn_1x/README.md | 14 +- .../Faster_rcnn_r50_fpn_1x/coco_eval.py | 78 --------- .../Faster_rcnn_r50_fpn_1x/get_info.py | 46 ------ .../mmdetection_coco_postprocess.py | 149 ------------------ .../mmdetection_coco_preprocess.py | 70 -------- .../Faster_rcnn_r50_fpn_1x/txt_to_json.py | 114 -------------- 6 files changed, 8 insertions(+), 463 deletions(-) delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/coco_eval.py delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/get_info.py delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_postprocess.py delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_preprocess.py delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/txt_to_json.py diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md index 7d08f52e8e..46e2bac2a3 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md +++ b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/README.md @@ -84,13 +84,11 @@ ``` FasterRcnn ├── README.md # 此文档 - ├── coco_eval.py # 验证推理精度的脚本 - ├── get_info.py # 用于获取图像数据集的info文件 ├── sample.py # 模型推理脚本文件 - ├── mmdetection_coco_postprocess.py # 推理结果后处理脚本 - ├── mmdetection_coco_preprocess.py # 数据集预处理脚本 - ├── trace.py # 模型trace脚本 - └── txt_to_json.py # 将推理结果txt文件转换为coco数据集评测精度的标准json格式 + ├── 
mmdetection.patch          # patch applied to the mmdetection repository
+   ├── requirements.txt           # dependency list
+   ├── trace.py                   # model tracing script
+   └── mmdet_ops                  # custom operator registration sources
    ```
@@ -124,7 +122,11 @@
       patch -p1 < mmdetection.patch
       ```
 
+5. Obtain the pre- and post-processing scripts
+   ```bash
+   cp ModelZoo-PyTorch/ACL_PyTorch/contrib/cv/detection/Faster_R-CNN_ResNet50/* ./
+   ```
 
 ## 准备数据集
 1. 获取原始数据集和验证集
diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/coco_eval.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/coco_eval.py
deleted file mode 100644
index 69d0adcf83..0000000000
--- a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/coco_eval.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import argparse
-import numpy as np
-from pycocotools.coco import COCO
-from pycocotools.cocoeval import COCOeval
-
-CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
-           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
-           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
-           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
-           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
-           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
-           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
-           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
-           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
-           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
-           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
-           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
-           'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
-           'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
-
-def coco_evaluation(annotation_json, result_json):
-    cocoGt = COCO(annotation_json)
-    cocoDt = cocoGt.loadRes(result_json)
-    iou_thrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
-    iou_type = 'bbox'
-
-    cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
-    cocoEval.params.catIds = cocoGt.getCatIds(catNms=CLASSES)
-    cocoEval.params.imgIds = cocoGt.getImgIds()
-    cocoEval.params.maxDets = [100, 300, 1000] # proposal number for evaluating recalls/mAPs.
- cocoEval.params.iouThrs = iou_thrs - - cocoEval.evaluate() - cocoEval.accumulate() - cocoEval.summarize() - - # mapping of cocoEval.stats - coco_metric_names = { - 'mAP': 0, - 'mAP_50': 1, - 'mAP_75': 2, - 'mAP_s': 3, - 'mAP_m': 4, - 'mAP_l': 5, - 'AR@100': 6, - 'AR@300': 7, - 'AR@1000': 8, - 'AR_s@1000': 9, - 'AR_m@1000': 10, - 'AR_l@1000': 11 - } - - metric_items = ['mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'] - eval_results = {} - - for metric_item in metric_items: - key = f'bbox_{metric_item}' - val = float( - f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}' - ) - eval_results[key] = val - ap = cocoEval.stats[:6] - eval_results['bbox_mAP_copypaste'] = ( - f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' - f'{ap[4]:.3f} {ap[5]:.3f}') - - return eval_results - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--ground_truth", default="instances_val2017.json") - parser.add_argument("--detection_result", default="coco_detection_result.json") - args = parser.parse_args() - result = coco_evaluation(args.ground_truth, args.detection_result) - print(result) - with open('./coco_detection_result.txt', 'w') as f: - for key, value in result.items(): - f.write(key + ': ' + str(value) + '\n') \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/get_info.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/get_info.py deleted file mode 100644 index 7b14c54b90..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/get_info.py +++ /dev/null @@ -1,46 +0,0 @@ -import os -import sys -import cv2 -from glob import glob - - -def get_bin_info(file_path, info_name, width, height): - bin_images = glob(os.path.join(file_path, '*.bin')) - with open(info_name, 'w') as file: - for index, img in enumerate(bin_images): - content = ' '.join([str(index), img, width, height]) - file.write(content) - file.write('\n') - - -def get_jpg_info(file_path, info_name): - extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] - image_names = [] - for extension in extensions: - image_names.append(glob(os.path.join(file_path, '*.' 
+ extension))) - with open(info_name, 'w') as file: - for image_name in image_names: - if len(image_name) == 0: - continue - else: - for index, img in enumerate(image_name): - img_cv = cv2.imread(img) - shape = img_cv.shape - width, height = shape[1], shape[0] - content = ' '.join([str(index), img, str(width), str(height)]) - file.write(content) - file.write('\n') - - -if __name__ == '__main__': - file_type = sys.argv[1] - file_path = sys.argv[2] - info_name = sys.argv[3] - if file_type == 'bin': - width = sys.argv[4] - height = sys.argv[5] - assert len(sys.argv) == 6, 'The number of input parameters must be equal to 5' - get_bin_info(file_path, info_name, width, height) - elif file_type == 'jpg': - assert len(sys.argv) == 4, 'The number of input parameters must be equal to 3' - get_jpg_info(file_path, info_name) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_postprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_postprocess.py deleted file mode 100644 index 5a6d3f9c47..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_postprocess.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import numpy as np -import argparse -import cv2 - -CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', - 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', - 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', - 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', - 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', - 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', - 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', - 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', - 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] - -def coco_postprocess(bbox: np.ndarray, image_size, - net_input_width, net_input_height): - """ - This function is postprocessing for FasterRCNN output. - - Before calling this function, reshape the raw output of FasterRCNN to - following form - numpy.ndarray: - [x, y, width, height, confidence, probability of 80 classes] - shape: (100,) - The postprocessing restore the bounding rectangles of FasterRCNN output - to origin scale and filter with non-maximum suppression. 
- - :param bbox: a numpy array of the FasterRCNN output - :param image_path: a string of image path - :return: three list for best bound, class and score - """ - w = image_size[0] - h = image_size[1] - scale = min(net_input_width / w, net_input_height / h) - - pad_w = net_input_width - w * scale - pad_h = net_input_height - h * scale - pad_left = pad_w // 2 - pad_top = pad_h // 2 - - # cal predict box on the image src - pbox = bbox - pbox[:, 0] = (bbox[:, 0] - pad_left) / scale - pbox[:, 1] = (bbox[:, 1] - pad_top) / scale - pbox[:, 2] = (bbox[:, 2] - pad_left) / scale - pbox[:, 3] = (bbox[:, 3] - pad_top) / scale - return pbox - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--bin_data_path", default="./result/dumpOutput_device0") - parser.add_argument("--test_annotation", default="./coco2017_jpg.info") - parser.add_argument("--det_results_path", default="./detection-results/") - parser.add_argument("--net_out_num", default=2) - parser.add_argument("--net_input_width", default=1216) - parser.add_argument("--net_input_height", default=1216) - parser.add_argument("--prob_thres", default=0.05) - parser.add_argument("--ifShowDetObj", action="store_true", help="if input the para means True, neither False.") - flags = parser.parse_args() - print(flags.ifShowDetObj, type(flags.ifShowDetObj)) - # generate dict according to annotation file for query resolution - # load width and height of input images - img_size_dict = dict() - with open(flags.test_annotation)as f: - for line in f.readlines(): - temp = line.split(" ") - img_file_path = temp[1] - img_name = temp[1].split("/")[-1].split(".")[0] - img_width = int(temp[2]) - img_height = int(temp[3]) - img_size_dict[img_name] = (img_width, img_height, img_file_path) - - # read bin file for generate predict result - bin_path = flags.bin_data_path - det_results_path = flags.det_results_path - os.makedirs(det_results_path, exist_ok=True) - total_img = set([name[:name.rfind('_')] - for name in os.listdir(bin_path) if "bin" in name]) - - for bin_file in sorted(total_img): - path_base = os.path.join(bin_path, bin_file) - # load all detected output tensor - res_buff = [] - for num in range(0, flags.net_out_num): - if os.path.exists(path_base + "_" + str(num) + ".bin"): - if num == 0: - buf = np.fromfile(path_base + "_" + str(num) + ".bin", dtype="float32") - buf = np.reshape(buf, [100, 5]) - elif num == 1: - buf = np.fromfile(path_base + "_" + str(num) + ".bin", dtype="int32") - buf = np.reshape(buf, [100, 1]) - res_buff.append(buf) - else: - print("[ERROR] file not exist", path_base + "_" + str(num) + ".bin") - res_tensor = np.concatenate(res_buff, axis=1) - current_img_size = img_size_dict[bin_file] - print("[TEST]---------------------------concat{} imgsize{}".format(len(res_tensor), current_img_size)) - predbox = coco_postprocess(res_tensor, current_img_size, flags.net_input_width, flags.net_input_height) - - if flags.ifShowDetObj == True: - imgCur = cv2.imread(current_img_size[2]) - - det_results_str = '' - for idx, class_ind in enumerate(predbox[:,5]): - if float(predbox[idx][4]) < float(flags.prob_thres): - continue - # skip negative class index - if class_ind < 0 or class_ind > 80: - continue - - class_name = CLASSES[int(class_ind)] - det_results_str += "{} {} {} {} {} {}\n".format(class_name, str(predbox[idx][4]), predbox[idx][0], - predbox[idx][1], predbox[idx][2], predbox[idx][3]) - if flags.ifShowDetObj == True: - imgCur=cv2.rectangle(imgCur, (int(predbox[idx][0]), int(predbox[idx][1])), - 
(int(predbox[idx][2]), int(predbox[idx][3])), (0,255,0), 1) - imgCur = cv2.putText(imgCur, class_name+'|'+str(predbox[idx][4]), - (int(predbox[idx][0]), int(predbox[idx][1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) - # 图像,文字内容, 坐标 ,字体,大小,颜色,字体厚度 - - if flags.ifShowDetObj == True: - print(os.path.join(det_results_path, bin_file +'.jpg')) - cv2.imwrite(os.path.join(det_results_path, bin_file +'.jpg'), imgCur, [int(cv2.IMWRITE_JPEG_QUALITY),70]) - - det_results_file = os.path.join(det_results_path, bin_file + ".txt") - with open(det_results_file, "w") as detf: - detf.write(det_results_str) - print(det_results_str) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_preprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_preprocess.py deleted file mode 100644 index acf0029ccb..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/mmdetection_coco_preprocess.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import os -import cv2 -import argparse -import mmcv -import torch -from tqdm import tqdm - -dataset_config = { - 'resize': (1216, 1216), - 'mean': [123.675, 116.28, 103.53], - 'std': [58.395, 57.12, 57.375], -} - -tensor_height = 1216 -tensor_width = 1216 - -def coco_preprocess(input_image, output_bin_path): - #define the output file name - img_name = input_image.split('/')[-1] - bin_name = img_name.split('.')[0] + ".bin" - bin_fl = os.path.join(output_bin_path, bin_name) - - one_img = mmcv.imread(os.path.join(input_image), backend='cv2') - one_img = mmcv.imrescale(one_img, (tensor_height, tensor_width)) - # calculate padding - h = one_img.shape[0] - w = one_img.shape[1] - pad_left = (tensor_width - w) // 2 - pad_top = (tensor_height - h) // 2 - pad_right = tensor_width - pad_left - w - pad_bottom = tensor_height - pad_top - h - - mean = np.array(dataset_config['mean'], dtype=np.float32) - std = np.array(dataset_config['std'], dtype=np.float32) - one_img = mmcv.imnormalize(one_img, mean, std) - one_img = mmcv.impad(one_img, padding=(pad_left, pad_top, pad_right, pad_bottom), pad_val=0) - one_img = one_img.transpose(2, 0, 1) - one_img.tofile(bin_fl) - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='preprocess of FasterRCNN pytorch model') - parser.add_argument("--image_folder_path", default="./coco2014/", help='image of dataset') - parser.add_argument("--bin_folder_path", default="./coco2014_bin/", help='Preprocessed image buffer') - flags = parser.parse_args() - - if not os.path.exists(flags.bin_folder_path): - os.makedirs(flags.bin_folder_path) - print("Start to process images...") - images = tqdm(os.listdir(flags.image_folder_path)) - for image_name in images: - if not (image_name.endswith(".jpeg") or image_name.endswith(".JPEG") or image_name.endswith(".jpg")): - continue - path_image = os.path.join(flags.image_folder_path, image_name) - 
coco_preprocess(path_image, flags.bin_folder_path) - print("Done.") diff --git a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/txt_to_json.py b/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/txt_to_json.py deleted file mode 100644 index 9f479da6a0..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/Faster_rcnn_r50_fpn_1x/txt_to_json.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import glob -import os -import sys -import argparse -import mmcv - -CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', - 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', - 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', - 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', - 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', - 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', - 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', - 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', - 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] - -cat_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, -24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, -48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, -72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] - -''' - 0,0 ------> x (width) - | - | (Left,Top) - | *_________ - | | | - | | - y |_________| - (height) * - (Right,Bottom) -''' - -def file_lines_to_list(path): - # open txt file lines to a list - with open(path) as f: - content = f.readlines() - # remove whitespace characters like `\n` at the end of each line - content = [x.strip() for x in content] - return content - - -def error(msg): - print(msg) - sys.exit(0) - - -def get_predict_list(file_path, gt_classes): - dr_files_list = glob.glob(file_path + '/*.txt') - dr_files_list.sort() - - bounding_boxes = [] - for txt_file in dr_files_list: - file_id = txt_file.split(".txt", 1)[0] - file_id = os.path.basename(os.path.normpath(file_id)) - lines = file_lines_to_list(txt_file) - for line in lines: - try: - sl = line.split() - if len(sl) > 6: - class_name = sl[0] + ' ' + sl[1] - scores, left, top, right, bottom = sl[2:] - else: - class_name, scores, left, top, right, bottom = sl - if float(scores) < 0.05: - continue - except ValueError: - error_msg = "Error: File " + txt_file + " wrong format.\n" - error_msg += " Expected: \n" - error_msg += " Received: " + line - error(error_msg) - - # bbox = left + " " + top + " " + right + " " + bottom - left = float(left) - right = 
float(right) - top = float(top) - bottom = float(bottom) - bbox = [left, top, right-left, bottom-top] - bounding_boxes.append({"image_id": int(file_id), "bbox": bbox, - "score": float(scores), "category_id": cat_ids[CLASSES.index(class_name)]}) - # sort detection-results by decreasing scores - # bounding_boxes.sort(key=lambda x: float(x['score']), reverse=True) - return bounding_boxes - - - -if __name__ == '__main__': - parser = argparse.ArgumentParser('mAp calculate') - parser.add_argument('--npu_txt_path', default="detection-results", - help='the path of the predict result') - parser.add_argument("--json_output_file", default="coco_detection_result") - args = parser.parse_args() - - res_bbox = get_predict_list(args.npu_txt_path, CLASSES) - mmcv.dump(res_bbox, args.json_output_file + '.json') \ No newline at end of file -- Gitee
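
For reference, the five patches above leave Faster_rcnn_r50_fpn_1x with sample.py, trace.py, mmdetection.patch, mmdet_ops and requirements.txt, plus the pre- and post-processing scripts copied in from ModelZoo-PyTorch (README step 5). The sketch below strings these together for one accuracy run. It assumes the copied ModelZoo scripts keep the argument names shown in the deleted script versions above; the dataset paths, the traced-model file name, and the bare sample.py invocation are illustrative placeholders, not part of the patches.

```bash
# Trace the patched mmdetection model to TorchScript
# (mirrors the sample command removed from trace.py by the rename patch).
python3 trace.py \
    --config ./mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    --checkpoint ./faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
    --input-img ./coco/val2017/000000515350.jpg \
    --shape 1216 \
    --output-file faster_rcnn_r50_fpn.pt \
    --mmdet-ops-path ./mmdet_ops/build/libmmdet_ops.so

# Pad and normalize COCO val2017 images into 1216x1216 .bin inputs, and
# record each image's original width/height for the postprocess step.
python3 mmdetection_coco_preprocess.py --image_folder_path ./coco/val2017/ --bin_folder_path ./val2017_bin/
python3 get_info.py jpg ./coco/val2017/ coco2017_jpg.info

# Run inference on the NPU (sample.py interface assumed; adjust as needed),
# then map the dumped [100, 5] boxes and [100, 1] labels back to image scale.
python3 sample.py
python3 mmdetection_coco_postprocess.py --bin_data_path ./result/dumpOutput_device0 \
    --test_annotation coco2017_jpg.info --det_results_path ./detection-results/

# Convert the per-image txt results to COCO json and compute bbox mAP.
python3 txt_to_json.py --npu_txt_path detection-results --json_output_file coco_detection_result
python3 coco_eval.py --ground_truth instances_val2017.json --detection_result coco_detection_result.json
```

The --bin_data_path value must point at wherever the inference step actually writes its output tensors; the directory shown is only the default taken from the postprocess script's argparse definition.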