From 49fddcea009d9dd0366fe6323df21fc87dcba7a8 Mon Sep 17 00:00:00 2001 From: yinin Date: Tue, 22 Mar 2022 03:52:02 +0000 Subject: [PATCH 01/12] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20rawnet2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ACL_PyTorch/contrib/audio/rawnet2/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/.keep diff --git a/ACL_PyTorch/contrib/audio/rawnet2/.keep b/ACL_PyTorch/contrib/audio/rawnet2/.keep new file mode 100644 index 0000000000..e69de29bb2 -- Gitee From f13c6517afe8cc526733b0a20ba7503b3b4baf63 Mon Sep 17 00:00:00 2001 From: yinin Date: Tue, 22 Mar 2022 03:52:17 +0000 Subject: [PATCH 02/12] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20AC?= =?UTF-8?q?L=5FPyTorch/contrib/audio/rawnet2/.keep?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ACL_PyTorch/contrib/audio/rawnet2/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 ACL_PyTorch/contrib/audio/rawnet2/.keep diff --git a/ACL_PyTorch/contrib/audio/rawnet2/.keep b/ACL_PyTorch/contrib/audio/rawnet2/.keep deleted file mode 100644 index e69de29bb2..0000000000 -- Gitee From f7965757f1a109b8e8fd9530e4014af1f45cfa8a Mon Sep 17 00:00:00 2001 From: yinin Date: Tue, 22 Mar 2022 03:52:35 +0000 Subject: [PATCH 03/12] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20rawnet2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ACL_PyTorch/contrib/audio/rawnet2/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/.keep diff --git a/ACL_PyTorch/contrib/audio/rawnet2/.keep b/ACL_PyTorch/contrib/audio/rawnet2/.keep new file mode 100644 index 0000000000..e69de29bb2 -- Gitee From 4bc512a667de37d2383b26821acb62ee3599208a Mon Sep 17 00:00:00 2001 From: yinin Date: Tue, 22 Mar 2022 03:54:39 +0000 Subject: [PATCH 04/12] =?UTF-8?q?rawnet2=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ACL_PyTorch/contrib/audio/rawnet2/LICENSE | 201 ++++++++++++++++++ ACL_PyTorch/contrib/audio/rawnet2/README.md | 50 +++++ .../audio/rawnet2/RawNet2_postprocess.py | 141 ++++++++++++ .../audio/rawnet2/RawNet2_preprocess.py | 84 ++++++++ .../contrib/audio/rawnet2/RawNet2_pth2onnx.py | 34 +++ ACL_PyTorch/contrib/audio/rawnet2/env.sh | 6 + .../contrib/audio/rawnet2/fusion_switch.cfg | 7 + .../contrib/audio/rawnet2/modelzoo_level.txt | 3 + ACL_PyTorch/contrib/audio/rawnet2/perf_t4.sh | 21 ++ .../contrib/audio/rawnet2/rawnet2.patch | 39 ++++ .../contrib/audio/rawnet2/requirements.txt | 10 + .../audio/rawnet2/test/eval_acc_perf.sh | 92 ++++++++ .../contrib/audio/rawnet2/test/parse.py | 32 +++ .../contrib/audio/rawnet2/test/pth2om.sh | 14 ++ 14 files changed, 734 insertions(+) create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/LICENSE create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/README.md create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/RawNet2_pth2onnx.py create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/env.sh create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/fusion_switch.cfg create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/modelzoo_level.txt create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/perf_t4.sh create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/requirements.txt create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/test/eval_acc_perf.sh create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/test/parse.py create mode 100644 ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh diff --git a/ACL_PyTorch/contrib/audio/rawnet2/LICENSE b/ACL_PyTorch/contrib/audio/rawnet2/LICENSE new file mode 100644 index 0000000000..261eeb9e9f --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ACL_PyTorch/contrib/audio/rawnet2/README.md b/ACL_PyTorch/contrib/audio/rawnet2/README.md new file mode 100644 index 0000000000..4ead1dd821 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/README.md @@ -0,0 +1,50 @@ +### 1.环境准备 + +1.安装必要的依赖,测试环境可能已经安装其中的一些不同版本的库了,故手动测试时不推荐使用该命令安装 + +``` +pip3 install -r requirements.txt +``` + +2.获取,开源模型代码 + +``` +git clone https://github.com/Jungjee/RawNet.git +cd RawNet +patch -p1 < ../rawnet2.patch +cd .. +``` + +3.获取权重文件 + +通过2获得代码仓后,权重文件位置:RawNet\python\RawNet2\Pre-trained_model\rawnet2_best_weights.pt,将其放到当前目录 + +4.获取数据集 [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) ,下载Audio files测试集,重命名VoxCeleb1,确保VoxCeleb1下全部为id1xxxx文件夹,放到/root/datasets目录,注:该路径为绝对路径 + +5.获取 [msame工具](https://gitee.com/ascend/tools/tree/master/msame) 和 [benchmark工具](https://gitee.com/ascend/cann-benchmark/tree/master/infer) + +将msame和benchmark.x86_64放到与test文件夹同级目录下。 + +### 2.离线推理 + +310上执行,执行时使npu-smi info查看设备状态,确保device空闲 + +备注: + +1.需要对onnx模型进行onnxsim优化,否则无法达到精度要求,pth2om.sh脚本首先将pth文件转换为onnx模型,然后分别对bs1和bs16进行onnxsim优化,最后分别转化为om模型 + +2.eval_acc_perf.sh脚本逐步完成数据前处理bin文件输出、bs1和bs16模型推理、bs1和bs16精度测试,以及bs1和bs16benchmark的性能测试 + +``` +bash test/pth2om.sh + +bash test/eval_acc_perf.sh --datasets_path=/root/datasets/VoxCeleb1/ +``` + +评测结果: + +| 模型 | 官网pth精度 | 310精度 | 基准性能 | 310性能 | +| --------------------- | ----------------------------------------------- | --------- | -------- | ------- | +| Baseline-RawNet2 bs1 | [EER 2.49%](https://github.com/Jungjee/RawNet/) | EER 2.50% | 285.7fps | 72.8fps | +| Baseline-RawNet2 bs16 | [EER 2.49%](https://github.com/Jungjee/RawNet/) | EER 2.50% | 489.3fps | 77.6fps | + diff --git a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py new file mode 100644 index 0000000000..56900f7023 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py @@ -0,0 +1,141 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import sys +import numpy as np +from sklearn.metrics import roc_curve +from scipy.optimize import brentq +from scipy.interpolate import interp1d +from tqdm import tqdm +import argparse + +sys.path.append('RawNet/python/RawNet2/') +from utils import cos_sim + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--input', help='bin path', default="", required=True) + parser.add_argument('--batch_size', help='batch size', required=True) + parser.add_argument('--output', help='result path', default="result/") + args = parser.parse_args() + batch_size = int(args.batch_size) + base = args.input + save_dir = args.output + d_embeddings = {} + if batch_size == 1: + for path, dirs, files in os.walk(base): + temp = "" + l_embeddings = [] + index = 0 + l_utt = [] + l_code = [] + with tqdm(total=len(files), ncols=70) as pbar: + files = sorted(files) + for f in files: + t = np.loadtxt(path + "/" + f) + t = t.astype(np.float32) + index += 1 + key = f.replace("$", "/", 2).split("$")[0] + if (temp == ""): + temp = key + l_utt.append(key) + if (key == temp): + l_code.append(t) + if (key != temp): + l_embeddings.append(np.mean(l_code, axis=0)) + temp = key + l_utt.append(key) + l_code = [] + l_code.append(t) + if (index == len(files)): + l_embeddings.append(np.mean(l_code, axis=0)) + pbar.update(1) + if not len(l_utt) == len(l_embeddings): + print(len(l_utt), len(l_embeddings)) + exit() + for k, v in zip(l_utt, l_embeddings): + d_embeddings[k] = v + + else: + with open('bs16_key.txt', 'r') as f: + l_val = f.readlines() + bs16_out = [] + for path, dirs, files in os.walk(base): + files = sorted(files, key=lambda x: [int(x.split('_')[0])]) + for f in files: + t = np.loadtxt(path + "/" + f) + for i in t: + i.reshape(1024, ) + bs16_out.append(i) + bs16_out_embeddings = {} + if not len(l_val) == len(bs16_out): + print(len(l_val), len(bs16_out)) + exit() + for k, v in zip(l_val, bs16_out): + bs16_out_embeddings[k] = v + temp = "" + l_embeddings = [] + index = 0 + l_utt = [] + l_code = [] + with tqdm(total=len(bs16_out_embeddings), ncols=70) as pbar: + for key in bs16_out_embeddings.keys(): + index += 1 + xxx = key + key = key.replace("$", "/", 2).split("$")[0] + if (temp == ""): + temp = key + l_utt.append(key) + if (key == temp): + l_code.append(bs16_out_embeddings[xxx]) + if (key != temp): + l_embeddings.append(np.mean(l_code, axis=0)) + temp = key + l_utt.append(key) + l_code = [] + l_code.append(bs16_out_embeddings[xxx]) + if (index == len(bs16_out_embeddings.keys())): + l_embeddings.append(np.mean(l_code, axis=0)) + pbar.update(1) + if not len(l_utt) == len(l_embeddings): + print(len(l_utt), len(l_embeddings)) + exit() + for k, v in zip(l_utt, l_embeddings): + d_embeddings[k] = v + + with open('RawNet/trials/vox_original.txt', 'r') as f: + l_val_trial = f.readlines() + y_score = [] + y = [] + + f_res = open(save_dir + 'result_detail_bs{}.txt'.format(batch_size), 'w') + for line in l_val_trial: + trg, utt_a, utt_b = line.strip().split(' ') + y.append(int(trg)) + y_score.append(cos_sim(d_embeddings[utt_a], d_embeddings[utt_b])) + f_res.write('{score} {target}\n'.format(score=y_score[-1], target=y[-1])) + f_res.close() + fpr, tpr, _ = roc_curve(y, y_score, pos_label=1) + eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) + f_eer_301 = open(save_dir + 'result_eer_{}.txt'.format(batch_size), 'w') + f_eer_301.write('bs{dir} evaluation EER: {eer}\n'.format(dir=batch_size, eer=eer)) + f_eer_301.close() + print('bs{dir} evaluation EER: {eer}\n'.format(dir=batch_size, eer=eer)) + + +if __name__ == '__main__': + main() diff --git a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py new file mode 100644 index 0000000000..1f6015fe56 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py @@ -0,0 +1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import numpy as np +import argparse +import sys +sys.path.append('RawNet/python/RawNet2/') +from dataloader import TA_Dataset_VoxCeleb2 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--input', help='dataset path', default="/root/datasets/VoxCeleb1/") + parser.add_argument('--batch_size', help='batch size', default=1) + parser.add_argument('--output', help='out bin path', default="bin_out_bs1/") + args = parser.parse_args() + base_dir = args.input + out_dir = args.output + batch_size = int(args.batch_size) + + def get_utt_list(src_dir): + l_utt = [] + for path, dirs, files in os.walk(src_dir): + path = path.replace('\\', '/') + base = '/'.join(path.split('/')[-2:]) + '/' + for f in files: + if f[-3:] != 'wav': + continue + l_utt.append(base + f) + return l_utt + + l_val = sorted(get_utt_list(base_dir)) + TA_evalset = TA_Dataset_VoxCeleb2(list_IDs=l_val, + return_label=True, + window_size=11810, + nb_samp=59049, + base_dir=base_dir) + if batch_size == 1: + for item in TA_evalset: + n = 0 + for i in item[0]: + i.tofile(out_dir + item[1].replace('/', '$') + "$" + str(n) + ".bin") + n += 1 + else: + bs16_key = open('bs16_key.txt', mode='w') + bs16 = [] + n = 0 + i = 0 + for item in TA_evalset: + l = 0 + for t in item[0]: + bs16_key.write(item[1].replace('/', '$') + "$" + str(n) + ".bin" + "$" + str(l) + "\n") + l += 1 + n += 1 + bs16.append(t) + if n == 16: + np.vstack(bs16).tofile(out_dir + str(i) + ".bin") + i += 1 + bs16 = [] + n = 0 + if n % 16 == 0: + return + for j in range(16 - (n % 16)): + bs16_key.write("temp$" + str(j) + "\n") + bs16.append(np.empty((59049,), dtype='float32')) + bs16_key.close() + np.vstack(bs16).tofile(out_dir + str(i) + ".bin") + + +if __name__ == '__main__': + main() diff --git a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_pth2onnx.py b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_pth2onnx.py new file mode 100644 index 0000000000..8cdefee1d3 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_pth2onnx.py @@ -0,0 +1,34 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import sys +import torch + +sys.path.append('RawNet/python/RawNet2/Pre-trained_model') +from RawNet.python.RawNet2.parser import get_args +from model_RawNet2_original_code import RawNet + +ptfile = "rawnet2_best_weights.pt" +args = get_args() +args.model['nb_classes'] = 6112 +model = RawNet(args.model, device="cpu") +model.load_state_dict(torch.load(ptfile, map_location=torch.device('cpu'))) +input_names = ["wav"] +output_names = ["class"] +dynamic_axes = {'wav': {0: '-1'}, 'class': {0: '-1'}} +dummy_input = torch.randn(1, 59049) +export_onnx_file = "RawNet2.onnx" +torch.onnx.export(model, dummy_input, export_onnx_file, input_names=input_names, dynamic_axes=dynamic_axes, + output_names=output_names, opset_version=11, verbose=True) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/env.sh b/ACL_PyTorch/contrib/audio/rawnet2/env.sh new file mode 100644 index 0000000000..ea514e10c5 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/env.sh @@ -0,0 +1,6 @@ +export install_path=/usr/local/Ascend/ascend-toolkit/latest +export PATH=/usr/local/python3.7.5/bin:${install_path}/atc/ccec_compiler/bin:${install_path}/atc/bin:$PATH +export PYTHONPATH=${install_path}/atc/python/site-packages:$PYTHONPATH +export LD_LIBRARY_PATH=${install_path}/atc/lib64:${install_path}/acllib/lib64:$LD_LIBRARY_PATH +export ASCEND_OPP_PATH=${install_path}/opp +export ASCEND_AICPU_PATH=/usr/local/Ascend/ascend-toolkit/latest diff --git a/ACL_PyTorch/contrib/audio/rawnet2/fusion_switch.cfg b/ACL_PyTorch/contrib/audio/rawnet2/fusion_switch.cfg new file mode 100644 index 0000000000..ece7f48ff2 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/fusion_switch.cfg @@ -0,0 +1,7 @@ +{ + "Switch":{ + "UBFusion":{ + "TbeEltwiseFusionPass":"off" + } + } +} \ No newline at end of file diff --git a/ACL_PyTorch/contrib/audio/rawnet2/modelzoo_level.txt b/ACL_PyTorch/contrib/audio/rawnet2/modelzoo_level.txt new file mode 100644 index 0000000000..20e36b3f78 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/modelzoo_level.txt @@ -0,0 +1,3 @@ +FuncStatus:OK +PrecisionStatus:OK +PerfStatus:NOK \ No newline at end of file diff --git a/ACL_PyTorch/contrib/audio/rawnet2/perf_t4.sh b/ACL_PyTorch/contrib/audio/rawnet2/perf_t4.sh new file mode 100644 index 0000000000..26aa1c9748 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/perf_t4.sh @@ -0,0 +1,21 @@ +#! /bin/bash + +trtexec --onnx=RawNet2_sim_bs1.onnx --fp16 --shapes=wav:1x59049 > RawNet2_bs1.log +perf_str=`grep "GPU.* mean.*ms$" ReID_bs1.log` +if [ -n "$perf_str" ]; then + perf_num=`echo $perf_str | awk -F' ' '{print $16}'` +else + perf_str=`grep "mean.*ms$" ReID_bs1.log` + perf_num=`echo $perf_str | awk -F' ' '{print $4}'` +fi +awk 'BEGIN{printf "gpu bs1 fps:%.3f\n", 1000*1/('$perf_num'/1)}' + +trtexec --onnx=RawNet2_sim_bs16.onnx --fp16 --shapes=wav:16x59049 > RawNet2_bs16.log +perf_str=`grep "GPU.* mean.*ms$" ReID_bs16.log` +if [ -n "$perf_str" ]; then + perf_num=`echo $perf_str | awk -F' ' '{print $16}'` +else + perf_str=`grep "mean.*ms$" ReID_bs16.log` + perf_num=`echo $perf_str | awk -F' ' '{print $4}'` +fi +awk 'BEGIN{printf "gpu bs16 fps:%.3f\n", 1000*1/('$perf_num'/16)}' diff --git a/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch b/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch new file mode 100644 index 0000000000..0d4f84a687 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch @@ -0,0 +1,39 @@ +From 0876124bd723f26561911fff951d1e6524b3f332 Mon Sep 17 00:00:00 2001 +From: +Date: Thu, 4 Nov 2021 13:04:37 +0800 +Subject: [PATCH] rawnet2 + +--- + python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py | 2 +- + python/RawNet2/dataloader.py | 2 +- + 2 files changed, 2 insertions(+), 2 deletions(-) + +diff --git a/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py b/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py +index c5981fc..9e3df1d 100644 +--- a/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py ++++ b/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py +@@ -265,7 +265,7 @@ class RawNet(nn.Module): + + self.sig = nn.Sigmoid() + +- def forward(self, x, y = 0, is_test=False, is_TS=False): ++ def forward(self, x, y = 0, is_test=True, is_TS=False): + #follow sincNet recipe + nb_samp = x.shape[0] + len_seq = x.shape[1] +diff --git a/python/RawNet2/dataloader.py b/python/RawNet2/dataloader.py +index c1791e2..67a4d2d 100644 +--- a/python/RawNet2/dataloader.py ++++ b/python/RawNet2/dataloader.py +@@ -114,7 +114,7 @@ class TA_Dataset_VoxCeleb2(data.Dataset): + if not self.return_label: + return list_X + y = self.labels[ID.split('/')[0]] +- return list_X, y ++ return list_X, ID + + def _normalize_scale(self, x): + ''' +-- +2.33.0.windows.2 + diff --git a/ACL_PyTorch/contrib/audio/rawnet2/requirements.txt b/ACL_PyTorch/contrib/audio/rawnet2/requirements.txt new file mode 100644 index 0000000000..e67279c448 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/requirements.txt @@ -0,0 +1,10 @@ +python3.7.5 +pytorch == 1.5.0 +torchvision == 0.6.0 +onnx == 1.10.1 +onnx-simplifier==0.3.6 +numpy +scikit-learn +scipy +tqdm +soundfile \ No newline at end of file diff --git a/ACL_PyTorch/contrib/audio/rawnet2/test/eval_acc_perf.sh b/ACL_PyTorch/contrib/audio/rawnet2/test/eval_acc_perf.sh new file mode 100644 index 0000000000..9d2ae11dea --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/test/eval_acc_perf.sh @@ -0,0 +1,92 @@ +datasets_path="/root/datasets/VoxCeleb1/" + +for para in $* +do + if [[ $para == --datasets_path* ]]; then + datasets_path=`echo ${para#*=}` + fi +done + +# 数据预处理输出bin文件夹 +if [ -e "bin_out_bs1" ]; then + rm -r bin_out_bs1 +fi +if [ -e "bin_out_bs16" ]; then + rm -r bin_out_bs16 +fi +mkdir bin_out_bs1 +mkdir bin_out_bs16 + +python3.7 RawNet2_preprocess.py --input=${datasets_path} --batch_size=1 --output="bin_out_bs1/" + +if [ $? != 0 ]; then + echo "fail preprocess" +fi + +python3.7 RawNet2_preprocess.py --input=${datasets_path} --batch_size=16 --output="bin_out_bs16/" + +if [ $? != 0 ]; then + echo "fail preprocess" +else + echo "success" +fi + +#om推理 +if [ -e "om_bs1" ]; then + rm -r om_bs1 +fi +if [ -e "om_bs16" ]; then + rm -r om_bs16 +fi +if [ -e "result" ]; then + rm -r result +fi + +mkdir om_bs1 +mkdir om_bs16 +mkdir result + +source env.sh + +./msame --model RawNet2_sim_bs1.om --input bin_out_bs1 --output om_bs1 --outfmt TXT --device 0 +if [ $? != 0 ]; then + echo "fail msame!" +fi +./msame --model RawNet2_sim_bs16.om --input bin_out_bs16 --output om_bs16 --outfmt TXT --device 0 + +if [ $? != 0 ]; then + echo "fail msame!" +fi + +#om精度判断 + +python3.7 RawNet2_postprocess.py --input="om_bs1/" --batch_size=1 + +if [ $? != 0 ]; then + echo "fail!" +else + echo "success" +fi + +python3.7 RawNet2_postprocess.py --input="om_bs16/" --batch_size=16 + +if [ $? != 0 ]; then + echo "fail!" +else + echo "success" +fi + +arch=`uname -m` +./benchmark.${arch} -round=10 -om_path=RawNet2_sim_bs1.om -device_id=0 -batch_size=1 +./benchmark.${arch} -round=10 -om_path=RawNet2_sim_bs16.om -device_id=0 -batch_size=16 +python3.7 test/parse.py result/PureInfer_perf_of_RawNet2_sim_bs1_in_device_0.txt +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi + +python3.7 test/parse.py result/PureInfer_perf_of_RawNet2_sim_bs16_in_device_0.txt +if [ $? != 0 ]; then + echo "fail!" + exit -1 +fi diff --git a/ACL_PyTorch/contrib/audio/rawnet2/test/parse.py b/ACL_PyTorch/contrib/audio/rawnet2/test/parse.py new file mode 100644 index 0000000000..6cdf1420bd --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/test/parse.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import json +import re + +if __name__ == '__main__': + if sys.argv[1].endswith('.json'): + result_json = sys.argv[1] + with open(result_json, 'r') as f: + content = f.read() + tops = [i.get('value') for i in json.loads(content).get('value') if 'Top' in i.get('key')] + print('om {} top1:{} top5:{}'.format(result_json.split('_')[1].split('.')[0], tops[0], tops[4])) + elif sys.argv[1].endswith('.txt'): + result_txt = sys.argv[1] + with open(result_txt, 'r') as f: + content = f.read() + txt_data_list = [i.strip() for i in re.findall(r':(.*?),', content.replace('\n', ',') + ',')] + fps = float(txt_data_list[7].replace('samples/s', '')) * 4 + print('310 bs{} fps:{}'.format(result_txt.split('_')[3], fps)) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh b/ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh new file mode 100644 index 0000000000..93affbf0f0 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh @@ -0,0 +1,14 @@ + + +python RawNet2_pth2onnx.py -name x + + +python3.7 -m onnxsim --input-shape="1,59049" RawNet2.onnx RawNet2_sim_bs1.onnx + +python3.7 -m onnxsim --input-shape="16,59049" RawNet2.onnx RawNet2_sim_bs16.onnx + +source env.sh + +atc --framework=5 --model=RawNet2_sim_bs1.onnx --output=RawNet2_sim_bs1 --input_format=ND --input_shape="wav:1,59049" --log=error --soc_version=Ascend310 --fusion_switch_file=fusion_switch.cfg + +atc --framework=5 --model=RawNet2_sim_bs16.onnx --output=RawNet2_sim_bs16 --input_format=ND --input_shape="wav:16,59049" --log=error --soc_version=Ascend310 --fusion_switch_file=fusion_switch.cfg \ No newline at end of file -- Gitee From ba6496aa5841a53f3d8263ae8f662be6c1f60253 Mon Sep 17 00:00:00 2001 From: yinin Date: Tue, 22 Mar 2022 03:54:49 +0000 Subject: [PATCH 05/12] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20AC?= =?UTF-8?q?L=5FPyTorch/contrib/audio/rawnet2/.keep?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ACL_PyTorch/contrib/audio/rawnet2/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 ACL_PyTorch/contrib/audio/rawnet2/.keep diff --git a/ACL_PyTorch/contrib/audio/rawnet2/.keep b/ACL_PyTorch/contrib/audio/rawnet2/.keep deleted file mode 100644 index e69de29bb2..0000000000 -- Gitee From 13afc437ab3743ed96244962b126413e7d6bcddd Mon Sep 17 00:00:00 2001 From: yinin Date: Tue, 22 Mar 2022 15:17:20 +0000 Subject: [PATCH 06/12] update ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch. --- .../contrib/audio/rawnet2/rawnet2.patch | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch b/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch index 0d4f84a687..c32d7ecc5e 100644 --- a/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch +++ b/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch @@ -1,15 +1,15 @@ -From 0876124bd723f26561911fff951d1e6524b3f332 Mon Sep 17 00:00:00 2001 -From: -Date: Thu, 4 Nov 2021 13:04:37 +0800 -Subject: [PATCH] rawnet2 +From 63d3f6f71de11066a3df2781c62836dd09f3b1f5 Mon Sep 17 00:00:00 2001 +From: yinin +Date: Tue, 22 Mar 2022 23:11:27 +0800 +Subject: [PATCH] patch --- - python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py | 2 +- - python/RawNet2/dataloader.py | 2 +- - 2 files changed, 2 insertions(+), 2 deletions(-) + .../RawNet2/Pre-trained_model/model_RawNet2_original_code.py | 2 +- + python/RawNet2/dataloader.py | 4 ++-- + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py b/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py -index c5981fc..9e3df1d 100644 +index c5981fc..9e3df1d 100755 --- a/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py +++ b/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py @@ -265,7 +265,7 @@ class RawNet(nn.Module): @@ -22,14 +22,16 @@ index c5981fc..9e3df1d 100644 nb_samp = x.shape[0] len_seq = x.shape[1] diff --git a/python/RawNet2/dataloader.py b/python/RawNet2/dataloader.py -index c1791e2..67a4d2d 100644 +index c1791e2..3dfa802 100644 --- a/python/RawNet2/dataloader.py +++ b/python/RawNet2/dataloader.py -@@ -114,7 +114,7 @@ class TA_Dataset_VoxCeleb2(data.Dataset): +@@ -113,8 +113,8 @@ class TA_Dataset_VoxCeleb2(data.Dataset): + if not self.return_label: return list_X - y = self.labels[ID.split('/')[0]] +- y = self.labels[ID.split('/')[0]] - return list_X, y ++ #y = self.labels[ID.split('/')[0]] + return list_X, ID def _normalize_scale(self, x): -- Gitee From 052440e481156bc6b0ff922a7606829a3344fb74 Mon Sep 17 00:00:00 2001 From: yinin Date: Sat, 26 Mar 2022 03:48:13 +0000 Subject: [PATCH 07/12] update ACL_PyTorch/contrib/audio/rawnet2/requirements.txt. --- ACL_PyTorch/contrib/audio/rawnet2/requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/requirements.txt b/ACL_PyTorch/contrib/audio/rawnet2/requirements.txt index e67279c448..4c42a0cc0a 100644 --- a/ACL_PyTorch/contrib/audio/rawnet2/requirements.txt +++ b/ACL_PyTorch/contrib/audio/rawnet2/requirements.txt @@ -1,5 +1,4 @@ -python3.7.5 -pytorch == 1.5.0 +torch == 1.5.0 torchvision == 0.6.0 onnx == 1.10.1 onnx-simplifier==0.3.6 -- Gitee From a1ed21232a0b3fefff08b7bb6f8225084f2469b8 Mon Sep 17 00:00:00 2001 From: yinin Date: Sat, 26 Mar 2022 08:42:01 +0000 Subject: [PATCH 08/12] update ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch. --- .../contrib/audio/rawnet2/rawnet2.patch | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch b/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch index c32d7ecc5e..2c7a30355c 100644 --- a/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch +++ b/ACL_PyTorch/contrib/audio/rawnet2/rawnet2.patch @@ -1,12 +1,13 @@ -From 63d3f6f71de11066a3df2781c62836dd09f3b1f5 Mon Sep 17 00:00:00 2001 +From 2d6205ea3f2b1b61f4eb3063ca3ccf49138c4465 Mon Sep 17 00:00:00 2001 From: yinin -Date: Tue, 22 Mar 2022 23:11:27 +0800 +Date: Sat, 26 Mar 2022 16:26:20 +0800 Subject: [PATCH] patch --- .../RawNet2/Pre-trained_model/model_RawNet2_original_code.py | 2 +- python/RawNet2/dataloader.py | 4 ++-- - 2 files changed, 3 insertions(+), 3 deletions(-) + python/RawNet2/parser.py | 2 +- + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py b/python/RawNet2/Pre-trained_model/model_RawNet2_original_code.py index c5981fc..9e3df1d 100755 @@ -36,6 +37,19 @@ index c1791e2..3dfa802 100644 def _normalize_scale(self, x): ''' +diff --git a/python/RawNet2/parser.py b/python/RawNet2/parser.py +index bea9112..4c15f34 100644 +--- a/python/RawNet2/parser.py ++++ b/python/RawNet2/parser.py +@@ -14,7 +14,7 @@ def str2bool(v): + def get_args(): + parser = argparse.ArgumentParser() + #dir +- parser.add_argument('-name', type = str, required = True) ++ parser.add_argument('-name', type = str, default = 'rawnet2') + parser.add_argument('-save_dir', type = str, default = 'DNNs/') + parser.add_argument('-DB', type = str, default = 'DB/VoxCeleb1/') + parser.add_argument('-DB_vox2', type = str, default = 'DB/VoxCeleb2/') -- 2.33.0.windows.2 -- Gitee From c2b7dd51a5881f7ac5c8b0082becb0b4cb404fdc Mon Sep 17 00:00:00 2001 From: yinin Date: Sat, 26 Mar 2022 08:42:57 +0000 Subject: [PATCH 09/12] update ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh. --- ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh b/ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh index 93affbf0f0..86513dd740 100644 --- a/ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh +++ b/ACL_PyTorch/contrib/audio/rawnet2/test/pth2om.sh @@ -1,6 +1,6 @@ -python RawNet2_pth2onnx.py -name x +python RawNet2_pth2onnx.py python3.7 -m onnxsim --input-shape="1,59049" RawNet2.onnx RawNet2_sim_bs1.onnx -- Gitee From 8e738ccd4021ca9671636b047ddd2e2db8dc3bcf Mon Sep 17 00:00:00 2001 From: yinin Date: Mon, 28 Mar 2022 14:22:04 +0000 Subject: [PATCH 10/12] update ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py. --- ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py index 1f6015fe56..a0e421fb77 100644 --- a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py +++ b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_preprocess.py @@ -66,14 +66,14 @@ def main(): l += 1 n += 1 bs16.append(t) - if n == 16: + if n == batch_size: np.vstack(bs16).tofile(out_dir + str(i) + ".bin") i += 1 bs16 = [] n = 0 - if n % 16 == 0: + if n % batch_size == 0: return - for j in range(16 - (n % 16)): + for j in range(batch_size - (n % batch_size)): bs16_key.write("temp$" + str(j) + "\n") bs16.append(np.empty((59049,), dtype='float32')) bs16_key.close() -- Gitee From 5c55c6e378297384e6ffc1991e5484f2b0769115 Mon Sep 17 00:00:00 2001 From: yinin Date: Thu, 31 Mar 2022 04:09:33 +0000 Subject: [PATCH 11/12] =?UTF-8?q?=E5=90=8E=E5=A4=84=E7=90=86=E8=84=9A?= =?UTF-8?q?=E6=9C=AC=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../audio/rawnet2/RawNet2_postprocess.py | 98 ++++++++----------- 1 file changed, 41 insertions(+), 57 deletions(-) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py index 56900f7023..c47a7f36be 100644 --- a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py +++ b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py @@ -25,6 +25,40 @@ import argparse sys.path.append('RawNet/python/RawNet2/') from utils import cos_sim +def get_l_embeddings(list_embeddings,bs,path="def"): + temp = "" + l_embeddings = [] + index = 0 + l_utt = [] + l_code = [] + with tqdm(total=len(list_embeddings), ncols=70) as pbar: + if bs==1: + files = sorted(list_embeddings) + else: + files = list_embeddings.keys() + for f in files: + if bs==1: + t = np.loadtxt(path + "/" + f) + t = t.astype(np.float32) + else: + t = list_embeddings[f] + index += 1 + key = f.replace("$", "/", 2).split("$")[0] + if (temp == ""): + temp = key + l_utt.append(key) + if (key == temp): + l_code.append(t) + else: + l_embeddings.append(np.mean(l_code, axis=0)) + temp = key + l_utt.append(key) + l_code = [] + l_code.append(t) + if (index == len(files)): + l_embeddings.append(np.mean(l_code, axis=0)) + pbar.update(1) + return l_utt,l_embeddings def main(): parser = argparse.ArgumentParser() @@ -38,38 +72,12 @@ def main(): d_embeddings = {} if batch_size == 1: for path, dirs, files in os.walk(base): - temp = "" - l_embeddings = [] - index = 0 - l_utt = [] - l_code = [] - with tqdm(total=len(files), ncols=70) as pbar: - files = sorted(files) - for f in files: - t = np.loadtxt(path + "/" + f) - t = t.astype(np.float32) - index += 1 - key = f.replace("$", "/", 2).split("$")[0] - if (temp == ""): - temp = key - l_utt.append(key) - if (key == temp): - l_code.append(t) - if (key != temp): - l_embeddings.append(np.mean(l_code, axis=0)) - temp = key - l_utt.append(key) - l_code = [] - l_code.append(t) - if (index == len(files)): - l_embeddings.append(np.mean(l_code, axis=0)) - pbar.update(1) + l_utt,l_embeddings = get_l_embeddings(files,batch_size,path); if not len(l_utt) == len(l_embeddings): print(len(l_utt), len(l_embeddings)) exit() for k, v in zip(l_utt, l_embeddings): d_embeddings[k] = v - else: with open('bs16_key.txt', 'r') as f: l_val = f.readlines() @@ -87,41 +95,17 @@ def main(): exit() for k, v in zip(l_val, bs16_out): bs16_out_embeddings[k] = v - temp = "" - l_embeddings = [] - index = 0 - l_utt = [] - l_code = [] - with tqdm(total=len(bs16_out_embeddings), ncols=70) as pbar: - for key in bs16_out_embeddings.keys(): - index += 1 - xxx = key - key = key.replace("$", "/", 2).split("$")[0] - if (temp == ""): - temp = key - l_utt.append(key) - if (key == temp): - l_code.append(bs16_out_embeddings[xxx]) - if (key != temp): - l_embeddings.append(np.mean(l_code, axis=0)) - temp = key - l_utt.append(key) - l_code = [] - l_code.append(bs16_out_embeddings[xxx]) - if (index == len(bs16_out_embeddings.keys())): - l_embeddings.append(np.mean(l_code, axis=0)) - pbar.update(1) - if not len(l_utt) == len(l_embeddings): - print(len(l_utt), len(l_embeddings)) - exit() - for k, v in zip(l_utt, l_embeddings): - d_embeddings[k] = v + l_utt,l_embeddings = get_l_embeddings(bs16_out_embeddings,batch_size); + if not len(l_utt) == len(l_embeddings): + print(len(l_utt), len(l_embeddings)) + exit() + for k, v in zip(l_utt, l_embeddings): + d_embeddings[k] = v with open('RawNet/trials/vox_original.txt', 'r') as f: l_val_trial = f.readlines() y_score = [] y = [] - f_res = open(save_dir + 'result_detail_bs{}.txt'.format(batch_size), 'w') for line in l_val_trial: trg, utt_a, utt_b = line.strip().split(' ') -- Gitee From fc716faca4832de529969e1a609bc77654058209 Mon Sep 17 00:00:00 2001 From: yinin Date: Thu, 31 Mar 2022 04:17:53 +0000 Subject: [PATCH 12/12] =?UTF-8?q?=E5=90=8E=E5=A4=84=E7=90=86=E8=84=9A?= =?UTF-8?q?=E6=9C=AC=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../audio/rawnet2/RawNet2_postprocess.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py index c47a7f36be..a0263c606a 100644 --- a/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py +++ b/ACL_PyTorch/contrib/audio/rawnet2/RawNet2_postprocess.py @@ -58,7 +58,13 @@ def get_l_embeddings(list_embeddings,bs,path="def"): if (index == len(files)): l_embeddings.append(np.mean(l_code, axis=0)) pbar.update(1) - return l_utt,l_embeddings + if not len(l_utt) == len(l_embeddings): + print(len(l_utt), len(l_embeddings)) + exit() + d_embeddings = {} + for k, v in zip(l_utt, l_embeddings): + d_embeddings[k] = v + return d_embeddings def main(): parser = argparse.ArgumentParser() @@ -72,12 +78,7 @@ def main(): d_embeddings = {} if batch_size == 1: for path, dirs, files in os.walk(base): - l_utt,l_embeddings = get_l_embeddings(files,batch_size,path); - if not len(l_utt) == len(l_embeddings): - print(len(l_utt), len(l_embeddings)) - exit() - for k, v in zip(l_utt, l_embeddings): - d_embeddings[k] = v + d_embeddings = get_l_embeddings(files,batch_size,path); else: with open('bs16_key.txt', 'r') as f: l_val = f.readlines() @@ -95,12 +96,7 @@ def main(): exit() for k, v in zip(l_val, bs16_out): bs16_out_embeddings[k] = v - l_utt,l_embeddings = get_l_embeddings(bs16_out_embeddings,batch_size); - if not len(l_utt) == len(l_embeddings): - print(len(l_utt), len(l_embeddings)) - exit() - for k, v in zip(l_utt, l_embeddings): - d_embeddings[k] = v + d_embeddings = get_l_embeddings(bs16_out_embeddings,batch_size); with open('RawNet/trials/vox_original.txt', 'r') as f: l_val_trial = f.readlines() -- Gitee