diff --git a/test/test_qa.py b/test/test_qa.py
index 369a69e7b9fea18c935fc4bd6c251a09c6d61bd2..78145504a4178dabb894cf0bc35100a7e6753e02 100644
--- a/test/test_qa.py
+++ b/test/test_qa.py
@@ -132,19 +132,24 @@ def list_chunks(session_cookie: str, csrf_cookie: str, document_id: str,
 def parser():
     # 创建 ArgumentParser 对象
     parser = argparse.ArgumentParser(description="Script to process document and generate QA pairs.")
-
-    # 添加必需的参数
-    parser.add_argument('-n', '--name', type=str, required=True, help='User name')
-    parser.add_argument('-p', '--password', type=str, required=True, help='User password')
-    parser.add_argument('-k', '--kb_id', type=str, required=True, help='KnowledgeBase ID')
-    parser.add_argument('-u', '--url', type=str, required=True, help='URL for witChainD')
+    subparser = parser.add_subparsers(dest='mode', required=True, help='Mode of operation')
+
+    # 离线模式参数
+    offline = subparser.add_parser('offline', help='Offline mode for processing documents')  # noqa: F841
+    offline.add_argument("-i", "--input_path", required=True, default="./document", help="Path of document names",)
+    # 在线模式所需添加的参数
+    online = subparser.add_parser('online', help='Online mode for processing documents')
+    online.add_argument('-n', '--name', type=str, required=True, help='User name')
+    online.add_argument('-p', '--password', type=str, required=True, help='User password')
+    online.add_argument('-k', '--kb_id', type=str, required=True, help='KnowledgeBase ID')
+    online.add_argument('-u', '--url', type=str, required=True, help='URL for witChainD')
 
     # 添加可选参数,并设置默认值
-    parser.add_argument('-q', '--qa_count', type=int, default=1,
+    online.add_argument('-q', '--qa_count', type=int, default=1,
                         help='Number of QA pairs to generate per text block (default: 1)')
 
     # 添加文件名列表参数
-    parser.add_argument('-d', '--doc_names', nargs='+', required=False, default=[], help='List of document names')
+    online.add_argument('-d', '--doc_names', nargs='+', required=False, default=[], help='List of document names')
 
     # 解析命令行参数
     args = parser.parse_args()
@@ -172,7 +177,6 @@ llm = LLM(model_name=config['MODEL_NAME'],
           request_timeout=60,
           temperature=0.35)
 
-
 def get_random_number(l, r):
     return random.randint(l, r-1)
 
@@ -474,6 +478,19 @@ def list_documents(session_cookie, csrf_cookie, kb_id, base_url="http://0.0.0.0:
     # 返回响应文本
     return documents
 
+def get_document(dir):
+    documents = []
+    print(os.listdir(dir))
+    for file in os.listdir(dir):
+        if file.endswith('.xlsx'):
+            file_path = os.path.join(dir, file)
+            df = pd.read_excel(file_path)
+            documents.append(df.to_dict(orient='records'))
+        if file.endswith('.csv'):
+            file_path = os.path.join(dir, file)
+            df = pd.read_csv(file_path, )
+            documents.append(df.to_dict(orient='records'))
+    return documents
 
 if __name__ == '__main__':
     """
@@ -487,71 +504,90 @@ if __name__ == '__main__':
     需要在.env中配置好LLM和witChainD相关的config,以及prompt路径
     """
     args = parser()
-    js, tmp_dict = login_and_get_tokens(args.name, args.password, args.url)
-    session_cookie = tmp_dict['ECSESSION']
-    csrf_cookie = tmp_dict['csrf_token']
-    print('login success')
-    documents = list_documents(session_cookie, csrf_cookie, args.kb_id, args.url)
-    print('get document success')
     QAs = []
-    # print(documents)
-    for document in documents:
-        # print('refresh tokens')
-        # print(json.dumps(document, indent=4, ensure_ascii=False))
-        if args.doc_names != [] and document['name'] not in args.doc_names:
-            # args.doc_names = []
-            continue
-        else:
-            args.doc_names = []
+    if args.mode == 'online':
         js, tmp_dict = login_and_get_tokens(args.name, args.password, args.url)
         session_cookie = tmp_dict['ECSESSION']
         csrf_cookie = tmp_dict['csrf_token']
-        args.doc_id = document['id']
-        args.doc_name = document['name']
-        count = 0
-        while count < 5:
+        print('login success')
+        documents = list_documents(session_cookie, csrf_cookie, args.kb_id, args.url)
+        print('get document success')
+        print(documents)
+        for document in documents:
+            # print('refresh tokens')
+            # print(json.dumps(document, indent=4, ensure_ascii=False))
+            if args.doc_names != [] and document['name'] not in args.doc_names:
+                # args.doc_names = []
+                continue
+            else:
+                args.doc_names = []
+            js, tmp_dict = login_and_get_tokens(args.name, args.password, args.url)
+            session_cookie = tmp_dict['ECSESSION']
+            csrf_cookie = tmp_dict['csrf_token']
+            args.doc_id = document['id']
+            args.doc_name = document['name']
+            count = 0
+            while count < 5:
+                try:
+                    js = list_chunks(session_cookie, csrf_cookie, str(args.doc_id), base_url=args.url)
+                    print(f'js: {js}')
+                    count = 10
+                except Exception as e:
+                    print(f"document {args.doc_name} check failed {e} with retry {count}")
+                    count = count + 1
+                    time.sleep(1)
+                    continue
+            if count == 5:
+                print(f"document {args.doc_name} check failed")
+                continue
+            chunks = js['data']['data_list']
+            new_chunks = []
+            for chunk in chunks:
+                new_chunk = {
+                    'text': chunk['text'],
+                    'type': chunk['type'],
+                }
+                new_chunks.append(new_chunk)
+            chunks = new_chunks
+            model = QAgenerator()
             try:
-                js = list_chunks(session_cookie, csrf_cookie, str(args.doc_id), base_url=args.url)
-                count = 10
+                print('正在生成QA对...')
+                t_QAs = asyncio.run(model.qa_generate(chunks=chunks, file=args.doc_name))
+                print("QA对生成完毕,正在获取答案...")
+                tt_QAs = asyncio.run(get_QAs_answers(t_QAs, args.kb_id, session_cookie, csrf_cookie, args.url))
+                print(f"tt_QAs: {tt_QAs}")
+                print("答案获取完毕,正在计算答案正确性...")
+                ttt_QAs = asyncio.run(QAScore().get_scores(tt_QAs))
+                print(f"ttt_QAs: {ttt_QAs}")
+                for QA in t_QAs:
+                    QAs.append(QA)
+                df = pd.DataFrame(QAs)
+                df.astype(str)
+                print(document['name'], 'down')
+                print('sample:', t_QAs[0]['question'][:40])
+                df.to_excel(current_dir / 'temp_answer.xlsx', index=False)
+                print(f'temp_Excel结果已输出到{current_dir}/temp_answer.xlsx')
             except Exception as e:
-                print(f"document {args.doc_name} check failed {e} with retry {count}")
-                count = count + 1
-                time.sleep(1)
+                import traceback
+                print(traceback.print_exc())
+                print(f"document {args.doc_name} failed {e}")
                 continue
-        if count == 5:
-            print(f"document {args.doc_name} check failed")
-            continue
-        chunks = js['data']['data_list']
-        new_chunks = []
-        for chunk in chunks:
-            new_chunk = {
-                'text': chunk['text'],
-                'type': chunk['type'],
+    else:
+        # 离线模式
+        # print(document_path)
+        t_QAs = get_document(args.input_path)
+        print(f"获取到{len(t_QAs)}个文档")
+        for item in t_QAs[0]:
+            single_item = {
+                "question": item['问题'],
+                "answer": item['标准答案'],
+                "witChainD_answer": item['llm的回答'],
+                "text": item['原始片段'],
+                "witChainD_source": item['检索片段'],
             }
-            new_chunks.append(new_chunk)
-        chunks = new_chunks
-        model = QAgenerator()
-        try:
-            print('正在生成QA对...')
-            t_QAs = asyncio.run(model.qa_generate(chunks=chunks, file=args.doc_name))
-            print("QA对生成完毕,正在获取答案...")
-            tt_QAs = asyncio.run(get_QAs_answers(t_QAs, args.kb_id, session_cookie, csrf_cookie, args.url))
-            print("答案获取完毕,正在计算答案正确性...")
-            ttt_QAs = asyncio.run(QAScore().get_scores(tt_QAs))
-            for QA in t_QAs:
-                QAs.append(QA)
-            df = pd.DataFrame(QAs)
-            df.astype(str)
-            print(document['name'], 'down')
-            print('sample:', t_QAs[0]['question'][:40])
-            df.to_excel(current_dir / 'temp_answer.xlsx', index=False)
-            print(f'temp_Excel结果已输出到{current_dir}/temp_answer.xlsx')
-        except Exception as e:
-            import traceback
-            print(traceback.print_exc())
-            print(f"document {args.doc_name} failed {e}")
-            continue
-    #
+            # print(single_item)
+            ttt_QAs = asyncio.run(QAScore().get_score(single_item))
+            QAs.append(ttt_QAs)
     # # 输出QAs到xlsx中
     # exit(0)
     newQAs = []
@@ -563,7 +599,7 @@ if __name__ == '__main__':
         "lcs_score(最大公共子串)": [],
         "jac_score(杰卡德距离)": [],
         "leve_score(编辑距离)": [],
-        "time_cost": {
+        "time_cost": {
             "keyword_searching": [],
             "text_to_vector": [],
             "vector_searching": [],
@@ -573,16 +609,17 @@ if __name__ == '__main__':
         },
     }
 
-    time_cost_metrics = list(total["time_cost"].keys())
-
+    time_cost_metrics = list(total["time_cost"].keys())
+
     for QA in QAs:
+        print(QA)
         try:
             if 'time_cost' in QA.keys():
                 ReOrderedQA = {
                     '领域': str(QA['type']),
                     '问题': str(QA['question']),
                     '标准答案': str(QA['answer']),
-                    'witChainD 回答': str(QA['witChainD_answer']),
+                    'llm的回答': str(QA['witChainD_answer']),
                     'context_relevancy(上下文相关性)': str(QA['context_relevancy']),
                     'context_recall(召回率)': str(QA['context_recall']),
                     'faithfulness(忠实性)': str(QA['faithfulness']),
@@ -604,7 +641,7 @@ if __name__ == '__main__':
                     '领域': str(QA['type']),
                     '问题': str(QA['question']),
                     '标准答案': str(QA['answer']),
-                    'witChainD 回答': str(QA['witChainD_answer']),
+                    'llm的回答': str(QA['witChainD_answer']),
                     'context_relevancy(上下文相关性)': str(QA['context_relevancy']),
                     'context_recall(召回率)': str(QA['context_recall']),
                     'faithfulness(忠实性)': str(QA['faithfulness']),
@@ -617,13 +654,13 @@ if __name__ == '__main__':
                 }
             print(ReOrderedQA)
            newQAs.append(ReOrderedQA)
-
+
             for metric in total.keys():
                 if metric != "time_cost":  # 跳过time_cost(特殊处理)
                     value = ReOrderedQA.get(metric)
                     if value is not None:
                         total[metric].append(float(value))
-
+
             if "time_cost" in QA:
                 for sub_metric in time_cost_metrics:
                     value = QA["time_cost"].get(sub_metric)
@@ -631,7 +668,7 @@ if __name__ == '__main__':
                         total["time_cost"][sub_metric].append(float(value))
         except Exception as e:
             print(f"QA {QA} error {e}")
-
+
     # 计算平均值
     avg = {}
     for metric, values in total.items():
@@ -646,19 +683,22 @@ if __name__ == '__main__':
             avg[metric] = avg_time_cost
 
     print(f"生成测试结果: {avg}")
-
-    excel_path = current_dir / "answer.xlsx"
-    with pd.ExcelWriter(excel_path, engine="xlsxwriter") as writer:
+
+    excel_path = current_dir / 'answer.xlsx'
+    with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
         # 写入第一个sheet(测试样例)
         df = pd.DataFrame(newQAs).astype(str)
         df.to_excel(writer, sheet_name="测试样例", index=False)
 
         # 写入第二个sheet(测试结果)
+        filtered_time_cost = {k: v for k, v in avg["time_cost"].items() if v != 0}
         flat_avg = {
             **{k: v for k, v in avg.items() if k != "time_cost"},
-            **{f"time_cost_{k}": v for k, v in avg["time_cost"].items()},
+            **{f"time_cost_{k}": v for k, v in filtered_time_cost.items()},
         }
-        avg_df = pd.DataFrame([flat_avg])  # 转换为DataFrame(一行数据)
+        avg_df = pd.DataFrame([flat_avg])
         avg_df.to_excel(writer, sheet_name="测试结果", index=False)
-    print(f"测试样例和结果已输出到{excel_path}")
+
+    print(f'测试样例和结果已输出到{excel_path}')
+
 
diff --git a/test/tools/config.py b/test/tools/config.py
index 52b0e556ca0a2e884b6639c632a17f39ec8250a7..e677904723fd6896f8497af2aa6f543e6564b5a8 100644
--- a/test/tools/config.py
+++ b/test/tools/config.py
@@ -1,10 +1,10 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 config = {
-    'PROMPT_PATH': './tools/prompt.yaml',
-    'MODEL_NAME': 'your_model_name',  # Replace with your actual model name
-    'OPENAI_API_BASE': "your_openai_api_base_url",
-    'OPENAI_API_KEY': 'your_openai_api_key',
-    'REQUEST_TIMEOUT': 120,
-    'MAX_TOKENS': 8096,
-    'MODEL_ENH': 'false',
+    "PROMPT_PATH": "./tools/prompt.yaml",
+    "MODEL_NAME": "your_model_name",  # Replace with your actual model name
+    "OPENAI_API_BASE": "your_openai_api_base_url",
+    "OPENAI_API_KEY": "your_openai_api_key",
+    "REQUEST_TIMEOUT": 120,
+    "MAX_TOKENS": 8096,
+    "MODEL_ENH": "false",
 }
diff --git "a/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm" "b/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm"
new file mode 100644
index 0000000000000000000000000000000000000000..dd3b2489e2a5a895fa19daacfd79623f6d78e4e1
Binary files /dev/null and "b/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm" differ
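Note on the new offline mode (illustration only, not part of the patch): get_document() loads every .xlsx or .csv file in the --input_path directory with pandas, and the offline branch then reads the columns 问题, 标准答案, llm的回答, 原始片段 and 检索片段 from each row before scoring it with QAScore().get_score. A minimal sketch of preparing a compatible input file, plus the assumed command lines; the file name, sample values and invocations below are illustrative assumptions, not taken from the repository:

# Hypothetical helper (not part of the patch): writes a CSV that the new
# offline branch of test/test_qa.py can consume via get_document().
import os

import pandas as pd

rows = [
    {
        "问题": "示例问题",          # question posed to the system under test
        "标准答案": "示例标准答案",   # reference (ground-truth) answer
        "llm的回答": "示例模型回答",  # answer actually returned by the LLM
        "原始片段": "示例原始片段",   # source chunk the QA pair was built from
        "检索片段": "示例检索片段",   # chunk retrieved by witChainD
    },
]

os.makedirs("./document", exist_ok=True)  # "./document" is the default --input_path in the patch
pd.DataFrame(rows).to_csv("./document/sample.csv", index=False)

# Assumed invocations, based on the subparsers added above:
#   python test/test_qa.py offline -i ./document
#   python test/test_qa.py online -n <name> -p <password> -k <kb_id> -u <url>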