From 7aede7cfc8a5cf93918e859e9a0f51f2947088ab Mon Sep 17 00:00:00 2001 From: "Shine.Wang" Date: Tue, 15 Apr 2025 16:04:30 +0800 Subject: [PATCH] small bug fix --- chat2db/app/service/sql_generate_service.py | 32 ++++++++++----------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/chat2db/app/service/sql_generate_service.py b/chat2db/app/service/sql_generate_service.py index dfd667d..8354a72 100644 --- a/chat2db/app/service/sql_generate_service.py +++ b/chat2db/app/service/sql_generate_service.py @@ -49,8 +49,8 @@ class SqlGenerateService(): if len(matches) == 0: return '' tmp = matches[0] - tmp.replace('\'', '\"') - tmp.replace(',', ',') + tmp = tmp.replace('\'', '\"') + tmp = tmp.replace(',', ',') return tmp @staticmethod @@ -119,7 +119,7 @@ class SqlGenerateService(): except Exception as e: logging.error(f'数据库{database_id}信息获取失败由于{e}') return [] - database_type = 'psotgres' + database_type = 'postgres' if 'mysql' in database_url: database_type = 'mysql' del database_url @@ -132,21 +132,20 @@ class SqlGenerateService(): data_frame_list = [] if table_id_list is None: if use_llm_enhancements: - table_id_list = await asyncio.wait_for(SqlGenerateService.get_most_similar_table_id_list(database_id, question, table_choose_cnt)) + table_id_list = await asyncio.wait_for(SqlGenerateService.get_most_similar_table_id_list(database_id, question, table_choose_cnt)) else: try: table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) - table_id_list = [] - for table_info in table_info_list: - table_id_list.append(table_info['table_id']) - max_retry=3 + table_id_list = [table_info['table_id'] for table_info in table_info_list] + max_retry = 3 sql_example_list = [] for _ in range(max_retry): try: - sql_example_list = await asyncio.wait_for(SqlExampleManager.get_topk_sql_example_by_cos_dis( + sql_example_list = await asyncio.wait_for(SqlExampleManager.get_topk_sql_example_by_cos_dis( question_vector=question_vector, table_id_list=table_id_list, topk=table_choose_cnt * 2), timeout=1 + ) break except Exception as e: @@ -196,17 +195,16 @@ class SqlGenerateService(): except Exception as e: sql_example_list = [] logging.error(f'id为{table_id}的表的最相近的{topk}条sql案例获取失败由于:{e}') - question_sql_list = [] - for i in range(len(sql_example_list)): - question_sql_list.append( - {'question': sql_example_list[i]['question'], - 'sql': sql_example_list[i]['sql']}) + question_sql_list = [{ + 'question': ex['question'], + 'sql': ex['sql'] + } for ex in sql_example_list] data_frame_list.append({'table_id': table_id, 'table_info': table_info, 'column_info_list': column_info_list, 'sql_example_list': question_sql_list}) return data_frame_list @staticmethod - async def megre_sql_example(sql_example_list): + async def merge_sql_example(sql_example_list): sql_example = '' for i in range(len(sql_example_list)): sql_example += '问题'+str(i)+':\n'+sql_example_list[i].get('question', @@ -225,7 +223,7 @@ class SqlGenerateService(): return sql @staticmethod - async def generate_sql_base_on_exmpale( + async def generate_sql_base_on_example( database_id, question, table_id_list=None, sql_generate_cnt=1, use_llm_enhancements=False): try: database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) @@ -234,7 +232,7 @@ class SqlGenerateService(): return {} if database_url is None: raise Exception('数据库配置不存在') - database_type = 'psotgres' + database_type = 'postgres' if 'mysql' in database_url: database_type = 'mysql' data_frame_list = await SqlGenerateService.find_most_similar_sql_example(database_id, table_id_list, question, use_llm_enhancements) -- Gitee