diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml index 5672906c8a2a8dc49c56afad970334ae7d122650..fd70997bed1618b2604136af2efab4f5a30fb022 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/LMDB_config.yaml @@ -24,7 +24,7 @@ TRAIN: OPTIMIZER: 'adadelta' RESUME: IS_RESUME: False - FILE: '/home/CRNN_Chinese_Characters_Rec-stable_o2_epoch100/npu/output/2020-10-10-12-45/checkpoints/checkpoint_10_acc_0.7927.pth' + FILE: "./checkpoint.pth" TEST: MODEL_FILE: '' diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py index 9222664a27eedd09564a0adbded731c5274a9eb3..84c04631e46ad36c75fc70f7e550167fb26ec4a8 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/main.py @@ -27,7 +27,7 @@ import torch.nn.parallel from torch.utils.data import DataLoader from apex import amp from easydict import EasyDict as edict - +from collections import OrderedDict def parse_arg(): parser = argparse.ArgumentParser(description="train crnn") @@ -78,7 +78,15 @@ def main(): checkpoint = torch.load(model_state_file, map_location=device) if 'state_dict' in checkpoint.keys(): last_epoch = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if(k[0:7]=="module."): + name = k[7:] + else: + name = k[0:] + new_state_dict[name] = v + model.load_state_dict(new_state_dict) best_acc = checkpoint['best_acc'] optimizer.load_state_dict(checkpoint['optimizer']) if config.TRAIN.AMP: diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py index a2d814e2ecadc0e266551d258509f39d8c3d77b6..6921a319df4ac1dc8a5296ba3d172e7f80cba972 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/extract_lmdb.py @@ -47,9 +47,9 @@ class LmdbDataLoader(object): index = self.cur_index self.cur_index += 1 with self.env.begin(write=False) as txn: - img_key = b'image-%09d' % index + img_key = b'img_%d' % index imgbuf = txn.get(img_key) - label_key = b'label-%09d' % index + label_key = b'lab_%d' % index label = txn.get(label_key).decode('utf-8').lower() print(f"read img {img_key} {label}") diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh index a6fc8f6f9dda8d3adbda3ee6d47ef5a52f77442d..1d2b84b4837979cb66edda94a75e407651ce0804 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/sdk_infer/sdk_run_infer/run.sh @@ -24,13 +24,7 @@ CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run.sh" ; exi info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; } warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; } -#export MX_SDK_HOME=${CUR_PATH}/../../.. -export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH} -export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner -export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins -#to set PYTHONPATH, import the StreamManagerApi.py -export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python python3 main.py $image_path $result_dir exit 0 diff --git a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py index 54dcdbb5abdf21cdf705fc6086748622d9d2904c..396d5101effe05ed34a5de47e2b9f02fde481737 100644 --- a/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py +++ b/PyTorch/built-in/cv/classification/CRNN_for_PyTorch/utils.py @@ -171,7 +171,6 @@ class strLabelConverter(object): index += l return texts - class resizeNormalize(object): def __init__(self, size, interpolation=Image.BICUBIC): self.size = size @@ -244,7 +243,7 @@ class lmdbDataset(Dataset): self.filtered_index_list = [] for index in range(self.nSamples): index += 1 - label_key = 'label-%09d'.encode() % index + label_key = 'lab_%d'.encode() % index label = txn.get(label_key).decode('utf-8') out_of_char = f'[^{self.alphabets}]' if re.search(out_of_char, label.lower()): @@ -262,7 +261,7 @@ class lmdbDataset(Dataset): index = self.filtered_index_list[index] with self.env.begin(write=False) as txn: - img_key = 'image-%09d'.encode() % index + img_key = 'img_%d'.encode() % index imgbuf = txn.get(img_key) buf = six.BytesIO() buf.write(imgbuf) @@ -274,7 +273,7 @@ class lmdbDataset(Dataset): return self[index + 1] if self.transform is not None: img = self.transform(img) - label_key = 'label-%09d'.encode() % index + label_key = 'lab_%d'.encode() % index label = txn.get(label_key).decode('utf-8') label = label.lower() return (img, label)