diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d43094efb15915912a7e6287732779861d5b5dcb --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py @@ -0,0 +1,97 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from tqdm import tqdm +import os + +import torch +import numpy as np + + +USE_NPU = True +INPUT_WIDTH = 300 +INPUT_HEIGHT = 300 + + +def parse_args(): + args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.") + args.add_argument('--ts_path',help='MobilenetV1 ts file path', type=str, + default='./ssd300_coco.ts' + ) + args.add_argument("--batch_size", type=int, default=1, help="batch size.") + args.add_argument('--img_bin_path',help='image bin path', type=str, + default='./coco2017_bin' + ) + args.add_argument('--save_dir',help='result save dir', type=str, + default='./pyinfer_res_npu' + ) + return args.parse_args() + + +if __name__ == '__main__': + infer_times = 100 + om_cost = 0 + pt_cost = 0 + opts = parse_args() + batch_size = opts.batch_size + directory_path = opts.img_bin_path + save_dir = opts.save_dir + + model = torch.jit.load(opts.ts_path) + + if USE_NPU: + + import torch_aie + from torch_aie import _enums + + input_info = [torch_aie.Input((batch_size, 3, INPUT_WIDTH, INPUT_HEIGHT))] + torch_aie.set_device(0) + print("start compile") + model = torch_aie.compile( + model, + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP32, + soc_version='Ascend310P3', + optimization_level=0 + ) + print("end compile") + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + print(f"Directory '{save_dir}' created.") + + # Iterate through each file in the directory + for filename in tqdm(os.listdir(directory_path)): + filepath = os.path.join(directory_path, filename) + if os.path.isfile(filepath): + try: + with open(filepath, 'rb') as file: + array_data = np.fromfile(file, dtype=np.float32).reshape((batch_size, 3, INPUT_WIDTH, INPUT_HEIGHT)) + torch_tensor = torch.tensor(array_data) + + if USE_NPU: + input_tensor_npu = torch_tensor.to("npu:0") + aieout_npu = model(input_tensor_npu) + first_out = aieout_npu[0].to("cpu").detach().numpy() + second_out = aieout_npu[1].to("cpu").detach().numpy() + else: + tsout = model(torch_tensor) + first_out = tsout[0].detach().numpy() + second_out = tsout[1].detach().numpy() + + first_out.tofile(os.path.join(save_dir, filename.split(".")[0] + "_0.bin")) + second_out.tofile(os.path.join(save_dir, filename.split(".")[0] + "_1.bin")) + except Exception as e: + print(f'Error reading {filename}: {e}') diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/coco_eval.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/coco_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2a3ab83f15d3d29b2e94bd6badcca4960ec1b615 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/coco_eval.py @@ -0,0 +1,98 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from tqdm import tqdm + + +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + + +CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') + + +def coco_evaluation(annotation_json, result_json): + cocoGt = COCO(annotation_json) + cocoDt = cocoGt.loadRes(result_json) + iou_thrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + iou_type = 'bbox' + + cocoEval = COCOeval(cocoGt, cocoDt, iou_type) + cocoEval.params.catIds = cocoGt.get_cat_ids(cat_names=CLASSES) + cocoEval.params.imgIds = cocoGt.get_img_ids() + cocoEval.params.maxDets = [100, 300, 1000] # proposal number for evaluating recalls/mAPs. + cocoEval.params.iouThrs = iou_thrs + + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + + metric_items = ['mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'] + eval_results = {} + + for metric_item in tqdm(metric_items): + key = f'bbox_{metric_item}' + val = float( + f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}' + ) + eval_results[key] = val + ap = cocoEval.stats[:6] + eval_results['bbox_mAP_copypaste'] = ( + f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' + f'{ap[4]:.3f} {ap[5]:.3f}') + + return eval_results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--ground_truth", default="instances_val2017.json") + parser.add_argument("--detection_result", default="coco_detection_result.json") + args = parser.parse_args() + result = coco_evaluation(args.ground_truth, args.detection_result) + print(result) + with open('./coco_detection_result.txt', 'w') as f: + for key, value in result.items(): + f.write(key + ': ' + str(value) + '\n') diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py new file mode 100644 index 0000000000000000000000000000000000000000..50610187dd4166227e57f37ecff9eae4f47e269a --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py @@ -0,0 +1,97 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse + +import torch + +from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model) + + +def pytorch2onnx(config_path, + checkpoint_path, + input_img, + input_shape, + normalize_cfg=None): + + input_config = { + 'input_shape': input_shape, + 'input_path': input_img, + 'normalize_cfg': normalize_cfg + } + + # prepare original model and meta for verifying the onnx model + orig_model = build_model_from_cfg(config_path, checkpoint_path) + print("type of orig_model:", type(orig_model)) + model, tensor_data = generate_inputs_and_wrap_model( + config_path, checkpoint_path, input_config) + + ts_model = torch.jit.trace(model, tensor_data) + ts_model.save("./ssd300_coco.ts") + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert MMDetection models to ONNX') + parser.add_argument('--checkpoint', help='checkpoint file', type=str, default='./ssd300_coco_20200307-a92d2092.pth') + parser.add_argument('--mmdet_path',help='mmdetection repo folder path', type=str, + default='./mmdetection' + ) + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[300], + help='input image size') + parser.add_argument( + '--mean', + type=float, + nargs='+', + default=[123.675, 116.28, 103.53], + help='mean value used for preprocess input data') + parser.add_argument( + '--std', + type=float, + nargs='+', + default=[1, 1, 1], + help='variance value used for preprocess input data') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + cfg = os.path.join(args.mmdet_path, "configs/ssd/ssd300_coco.py") + input_img = os.path.join(args.mmdet_path, "tests/data/color.jpg") + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (1, 3) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + assert len(args.mean) == 3 + assert len(args.std) == 3 + + normalize_cfg = {'mean': args.mean, 'std': args.std} + + # convert model to onnx file + pytorch2onnx( + cfg, + args.checkpoint, + input_img, + input_shape, + normalize_cfg=normalize_cfg) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfd9c751596be3ac77c8ea199413e8e4d3d4e72 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py @@ -0,0 +1,63 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from glob import glob + +import cv2 + + +def get_bin_info(bin_file_path, bin_info_name, bin_width, bin_height): + """get_bin_info""" + bin_images = glob(os.path.join(bin_file_path, '*.bin')) + with open(bin_info_name, 'w') as info_file: + for index, img in enumerate(bin_images): + content = ' '.join([str(index), img, bin_width, bin_height]) + info_file.write(content) + info_file.write('\n') + + +def get_jpg_info(jpg_file_path, jpg_info_name): + """get_jpg_info""" + extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] + image_names = [] + for extension in extensions: + image_names.append(glob(os.path.join(jpg_file_path, '*.' + extension))) + with open(jpg_info_name, 'w') as jpg_file: + for image_name in image_names: + if len(image_name) == 0: + continue + else: + for index, img in enumerate(image_name): + img_cv = cv2.imread(img) + shape = img_cv.shape + jpg_width, jpg_height = shape[1], shape[0] + content = ' '.join([str(index), img, str(jpg_width), str(jpg_height)]) + jpg_file.write(content) + jpg_file.write('\n') + + +if __name__ == '__main__': + file_type = sys.argv[1] + file_path = sys.argv[2] + info_name = sys.argv[3] + if file_type == 'bin': + width = sys.argv[4] + height = sys.argv[5] + assert len(sys.argv) == 6, 'The number of input parameters must be equal to 5' + get_bin_info(file_path, info_name, width, height) + elif file_type == 'jpg': + assert len(sys.argv) == 4, 'The number of input parameters must be equal to 3' + get_jpg_info(file_path, info_name) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py new file mode 100644 index 0000000000000000000000000000000000000000..d815e1ec7bb044ab2df618dd793ab310e355fedf --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py @@ -0,0 +1,116 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import time +from tqdm import tqdm + +import torch +import numpy as np + +import torch_aie +from torch_aie import _enums + + +INPUT_WIDTH = 300 +INPUT_HEIGHT = 300 + +def parse_args(): + args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.") + args.add_argument("--mode", choices=["om", "ts"], required=True, help="Specify the mode ('om' or 'ts').") + args.add_argument('--om_path',help='MobilenetV1 om file path', type=str, + default='/onnx/mobilenetv1/mobilenet-v1_bs1.om' + ) + args.add_argument('--ts_path',help='MobilenetV1 ts file path', type=str, + default='/onnx/ssd/ssd300_coco.ts' + ) + args.add_argument("--batch-size", type=int, default=4, help="batch size.") + return args.parse_args() + +if __name__ == '__main__': + infer_times = 100 + om_cost = 0 + pt_cost = 0 + opts = parse_args() + TS_PATH = opts.ts_path + OM_PATH = opts.om_path + BATCH_SIZE = opts.batch_size + + if opts.mode == "om": + om_model = InferSession(0, OM_PATH) + for _ in tqdm(range(0, infer_times)): + dummy_input = np.random.randn(BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT).astype(np.uint8) + start = time.time() + output = om_model.infer([dummy_input], 'static', custom_sizes=90000000) # revise static + # output = om_model.infer([dummy_input], 'dymshape', custom_sizes=4000) # revise dynm fp32为4个字节,输出为1x1000 + cost = time.time() - start + om_cost += cost + + if opts.mode == "ts": + ts_model = torch.jit.load(TS_PATH) + + input_info = [torch_aie.Input((BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT))] + + torch_aie.set_device(0) + print("start compile") + torchaie_model = torch_aie.compile( + ts_model, + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version='Ascend310P3', + ) + print("end compile") + torchaie_model.eval() + + print("start export") + torch_aie.export_engine(ts_model, + "forward", + "ssd.om", + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version='Ascend310P3') + print("end export") + + dummy_input = np.random.randn(BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT).astype(np.float32) + input_tensor = torch.Tensor(dummy_input) + loops = 100 + warm_ctr = 10 + + default_stream = torch_aie.npu.default_stream() + time_cost = 0 + + input_tensor = input_tensor.to("npu") + while warm_ctr: + _ = torchaie_model(input_tensor) + default_stream.synchronize() + warm_ctr -= 1 + + print("send to npu") + input_tensor = input_tensor.to("npu") + print("finish sent") + for i in range(loops): + t0 = time.time() + _ = torchaie_model(input_tensor) # tuple of 2 lists of len 6 + default_stream.synchronize() + t1 = time.time() + time_cost += (t1 - t0) + print(i) + + print(f"fps: {loops} * {BATCH_SIZE} / {time_cost : .3f} samples/s") + print("torch_aie fps: ", loops * BATCH_SIZE / time_cost) + + from datetime import datetime + current_time = datetime.now() + formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S") + print("Current Time:", formatted_time) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..27c4ebd5e927592e41a54ae86416dfab190c02e6 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt @@ -0,0 +1,7 @@ +protobuf==3.20.0 +Cython==0.29.35 +matplotlib==3.5.3 +mmpycocotools==12.0.3 +torch==1.8.1 +torchvision==0.9.1 +tqdm==4.66.1 \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc86186d8619490f723b40a7c55140e686c958c --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py @@ -0,0 +1,68 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from tqdm import tqdm + +import numpy as np +import mmcv + + +dataset_config = { + 'resize': (300, 300), + 'mean': [123.675, 116.28, 103.53], + 'std': [1, 1, 1], +} + +tensor_height = 300 +tensor_width = 300 + + +def coco_preprocess(input_image, output_bin_path): + """coco_preprocess""" + # define the output file name + img_name = input_image.split('/')[-1] + bin_name = img_name.split('.')[0] + ".bin" + bin_fl = os.path.join(output_bin_path, bin_name) + + one_img = mmcv.imread(input_image, backend='cv2') + one_img = mmcv.imresize(one_img, (tensor_height, tensor_width)) + # calculate padding + mean = np.array(dataset_config['mean'], dtype=np.float32) + std = np.array(dataset_config['std'], dtype=np.float32) + one_img = mmcv.imnormalize(one_img, mean, std) + one_img = one_img.transpose(2, 0, 1) + print("one_img.dtype: ", one_img.dtype) + print("one_img.shape: ", one_img.shape) + one_img.tofile(bin_fl) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='preprocess of FasterRCNN pytorch model') + parser.add_argument("--image_folder_path", + default="/home/ascend/coco2017/val2017", help='image of dataset') + parser.add_argument( + "--bin_folder_path", default="/home/ascend/coco2017_bin/", help='Preprocessed image buffer') + flags = parser.parse_args() + + if not os.path.exists(flags.bin_folder_path): + os.makedirs(flags.bin_folder_path) + images = os.listdir(flags.image_folder_path) + for image_name in tqdm(images, desc="Starting to process image..."): + if not (image_name.endswith(".jpeg") or image_name.endswith(".JPEG") or image_name.endswith(".jpg")): + continue + path_image = os.path.join(flags.image_folder_path, image_name) + coco_preprocess(path_image, flags.bin_folder_path) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py new file mode 100644 index 0000000000000000000000000000000000000000..328b2319c84f68e8bdcd626b62ce330b3c908033 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import sys +import argparse +import mmcv + +CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] + +cat_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, +24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, +48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, +72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] + +''' + 0,0 ------> x (width) + | + | (Left,Top) + | *_________ + | | | + | | + y |_________| + (height) * + (Right,Bottom) +''' + +def file_lines_to_list(path): + """file_lines_to_list""" + # open txt file lines to a list + with open(path) as f: + content = f.readlines() + # remove whitespace characters like `\n` at the end of each line + content = [x.strip() for x in content] + return content + + +def error(msg): + """error""" + print(msg) + sys.exit(0) + + +def get_predict_list(file_path): + """get_predict_list""" + dr_files_list = glob.glob(file_path + '/*.txt') + dr_files_list.sort() + + bounding_boxes = [] + for txt_file in dr_files_list: + file_id = txt_file.split(".txt", 1)[0] + file_id = os.path.basename(os.path.normpath(file_id)) + lines = file_lines_to_list(txt_file) + for line in lines: + try: + sl = line.split() + if len(sl) > 6: + class_name = sl[0] + ' ' + sl[1] + scores, left, top, right, bottom = sl[2:] + else: + class_name, scores, left, top, right, bottom = sl + if float(scores) < 0.02: + continue + except ValueError: + error_msg = "Error: File " + txt_file + " wrong format.\n" + error_msg += " Expected: \n" + error_msg += " Received: " + line + error(error_msg) + + # bbox = left + " " + top + " " + right + " " + bottom + left = float(left) + right = float(right) + top = float(top) + bottom = float(bottom) + bbox = [left, top, right - left, bottom - top] + bounding_boxes.append({"image_id": int(file_id), "bbox": bbox, "score": float(scores), + "category_id": cat_ids[CLASSES.index(class_name)]}) + return bounding_boxes + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('mAp calculate') + parser.add_argument('--npu_txt_path', default="detection-results", + help='the path of the predict result') + parser.add_argument("--json_output_file", default="coco_detection_result") + args = parser.parse_args() + + res_bbox = get_predict_list(args.npu_txt_path) + mmcv.dump(res_bbox, args.json_output_file + '.json')