From f057eb95e70e89a3df96e25342cf83fdb8940e8e Mon Sep 17 00:00:00 2001 From: aisiqi <239171919@qq.com> Date: Mon, 1 Aug 2022 15:38:28 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BC=97=E6=99=BA][CRNN][PyTorch]=20=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=96=87=E4=BB=B6=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [众智][CRNN][PyTorch] 更正导入模型参数时参数的名称格式 [众智][CRNN][PyTorch] 修改模型参数文件路径 [众智][CRNN][PyTorch] 更正img_key和lab_key的格式 [众智][CRNN][PyTorch] 更正导入模型参数时参数的名称格式 [众智][CRNN][PyTorch] 更正img_key和lab_key的格式 删除了设置环境变量的语句 --- .../classification/CRNN_for_PyTorch/LMDB_config.yaml | 2 +- .../cv/classification/CRNN_for_PyTorch/main.py | 12 ++++++++++-- .../CRNN_for_PyTorch/sdk_infer/extract_lmdb.py | 4 ++-- .../CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh | 6 ------ .../cv/classification/CRNN_for_PyTorch/utils.py | 7 +++---- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml index 5672906c8a..fd70997bed 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml @@ -24,7 +24,7 @@ TRAIN: OPTIMIZER: 'adadelta' RESUME: IS_RESUME: False - FILE: '/home/CRNN_Chinese_Characters_Rec-stable_o2_epoch100/npu/output/2020-10-10-12-45/checkpoints/checkpoint_10_acc_0.7927.pth' + FILE: "./checkpoint.pth" TEST: MODEL_FILE: '' diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py index 9222664a27..84c04631e4 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py @@ -27,7 +27,7 @@ import torch.nn.parallel from torch.utils.data import DataLoader from apex import amp from easydict import EasyDict as edict - +from collections import OrderedDict def parse_arg(): parser = argparse.ArgumentParser(description="train crnn") @@ -78,7 +78,15 @@ def main(): checkpoint = torch.load(model_state_file, map_location=device) if 'state_dict' in checkpoint.keys(): last_epoch = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if(k[0:7]=="module."): + name = k[7:] + else: + name = k[0:] + new_state_dict[name] = v + model.load_state_dict(new_state_dict) best_acc = checkpoint['best_acc'] optimizer.load_state_dict(checkpoint['optimizer']) if config.TRAIN.AMP: diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py index a2d814e2ec..6921a319df 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py @@ -47,9 +47,9 @@ class LmdbDataLoader(object): index = self.cur_index self.cur_index += 1 with self.env.begin(write=False) as txn: - img_key = b'image-%09d' % index + img_key = b'img_%d' % index imgbuf = txn.get(img_key) - label_key = b'label-%09d' % index + label_key = b'lab_%d' % index label = txn.get(label_key).decode('utf-8').lower() print(f"read img {img_key} {label}") diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh index a6fc8f6f9d..1d2b84b483 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh @@ -24,13 +24,7 @@ CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run.sh" ; exi info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; } warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; } -#export MX_SDK_HOME=${CUR_PATH}/../../.. -export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH} -export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner -export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins -#to set PYTHONPATH, import the StreamManagerApi.py -export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python python3 main.py $image_path $result_dir exit 0 diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py index 54dcdbb5ab..396d5101ef 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py @@ -171,7 +171,6 @@ class strLabelConverter(object): index += l return texts - class resizeNormalize(object): def __init__(self, size, interpolation=Image.BICUBIC): self.size = size @@ -244,7 +243,7 @@ class lmdbDataset(Dataset): self.filtered_index_list = [] for index in range(self.nSamples): index += 1 - label_key = 'label-%09d'.encode() % index + label_key = 'lab_%d'.encode() % index label = txn.get(label_key).decode('utf-8') out_of_char = f'[^{self.alphabets}]' if re.search(out_of_char, label.lower()): @@ -262,7 +261,7 @@ class lmdbDataset(Dataset): index = self.filtered_index_list[index] with self.env.begin(write=False) as txn: - img_key = 'image-%09d'.encode() % index + img_key = 'img_%d'.encode() % index imgbuf = txn.get(img_key) buf = six.BytesIO() buf.write(imgbuf) @@ -274,7 +273,7 @@ class lmdbDataset(Dataset): return self[index + 1] if self.transform is not None: img = self.transform(img) - label_key = 'label-%09d'.encode() % index + label_key = 'lab_%d'.encode() % index label = txn.get(label_key).decode('utf-8') label = label.lower() return (img, label) -- Gitee