diff --git a/chat2db/app/__init__.py b/chat2db/app/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/chat2db/app/app.py b/chat2db/app/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce8892578b2f1d4151e0ec210d8b212cd5647bf
--- /dev/null
+++ b/chat2db/app/app.py
@@ -0,0 +1,37 @@
+import uvicorn
+from fastapi import FastAPI
+from chat2db.app.router import sql_example
+from chat2db.app.router import sql_generate
+from chat2db.app.router import database
+from chat2db.app.router import table
+from chat2db.config.config import config
+import logging
+
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+app = FastAPI()
+
+app.include_router(sql_example.router)
+app.include_router(sql_generate.router)
+app.include_router(database.router)
+app.include_router(table.router)
+
+if __name__ == '__main__':
+    try:
+        ssl_enable = config["SSL_ENABLE"]
+        if ssl_enable:
+            uvicorn.run(app, host=config["UVICORN_IP"], port=int(config["UVICORN_PORT"]),
+                        proxy_headers=True, forwarded_allow_ips='*',
+                        ssl_certfile=config["SSL_CERTFILE"],
+                        ssl_keyfile=config["SSL_KEYFILE"],
+                        )
+        else:
+            uvicorn.run(app, host=config["UVICORN_IP"], port=int(config["UVICORN_PORT"]),
+                        proxy_headers=True, forwarded_allow_ips='*'
+                        )
+    except Exception as e:
+        # Log the startup failure so it is recorded in app.log before exiting;
+        # a bare exit(1) would hide the root cause entirely.
+        logging.error(f'chat2db service failed to start: {e}')
+        exit(1)
diff --git a/chat2db/app/base/ac_automation.py b/chat2db/app/base/ac_automation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a408abbd9c87f4d1b0f66d806e8b09ca2e9f11f
--- /dev/null
+++ b/chat2db/app/base/ac_automation.py
@@ -0,0 +1,87 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+import copy +import logging + + +class Node: + def __init__(self, dep, pre_id): + self.dep = dep + self.pre_id = pre_id + self.pre_nearest_children_id = {} + self.children_id = {} + self.data_frame = None + + +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + + +class DictTree: + def __init__(self): + self.root = 0 + self.node_list = [Node(0, -1)] + + def load_data(self, data_dict): + for key in data_dict: + self.insert_data(key, data_dict[key]) + self.init_pre() + + def insert_data(self, keyword, data_frame): + if not isinstance(keyword,str): + return + if len(keyword) == 0: + return + node_index = self.root + try: + for i in range(len(keyword)): + if keyword[i] not in self.node_list[node_index].children_id.keys(): + self.node_list.append(Node(self.node_list[node_index].dep+1, 0)) + self.node_list[node_index].children_id[keyword[i]] = len(self.node_list)-1 + node_index = self.node_list[node_index].children_id[keyword[i]] + except Exception as e: + logging.error(f'关键字插入失败由于:{e}') + return + self.node_list[node_index].data_frame = data_frame + + def init_pre(self): + q = [self.root] + l = 0 + r = 1 + try: + while l < r: + node_index = q[l] + self.node_list[node_index].pre_nearest_children_id = self.node_list[self.node_list[node_index].pre_id].children_id.copy() + l += 1 + for key, val in self.node_list[node_index].children_id.items(): + q.append(val) + r += 1 + if key in self.node_list[node_index].pre_nearest_children_id.keys(): + pre_id = self.node_list[node_index].pre_nearest_children_id[key] + self.node_list[val].pre_id = pre_id + self.node_list[node_index].pre_nearest_children_id[key] = val + except Exception as e: + logging.error(f'字典树前缀构建失败由于:{e}') + return + + def get_results(self, content: str): + content = content.lower() + pre_node_index = self.root + nex_node_index = None + results = [] + logging.info(f'当前问题{content}') + try: + for i in range(len(content)): + 
if content[i] in self.node_list[pre_node_index].pre_nearest_children_id.keys(): + nex_node_index = self.node_list[pre_node_index].pre_nearest_children_id[content[i]] + else: + nex_node_index = 0 + if self.node_list[pre_node_index].dep >= self.node_list[nex_node_index].dep: + if self.node_list[pre_node_index].data_frame is not None: + results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame)) + pre_node_index = nex_node_index + logging.info(f'当前深度{self.node_list[pre_node_index].dep}') + if self.node_list[pre_node_index].data_frame is not None: + results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame)) + except Exception as e: + logging.error(f'结果获取失败由于:{e}') + return results diff --git a/chat2db/app/base/mysql.py b/chat2db/app/base/mysql.py new file mode 100644 index 0000000000000000000000000000000000000000..176ddd908b99ac027851841a6ebc2479cf3b4491 --- /dev/null +++ b/chat2db/app/base/mysql.py @@ -0,0 +1,215 @@ + +import asyncio +import aiomysql +import concurrent.futures +import logging +from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine, text +from concurrent.futures import ThreadPoolExecutor +from urllib.parse import urlparse +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + + +class Mysql(): + executor = ThreadPoolExecutor(max_workers=10) + + async def test_database_connection(database_url): + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(Mysql._connect_and_query, database_url) + result = future.result(timeout=0.1) + return result + except concurrent.futures.TimeoutError: + logging.error('mysql数据库连接超时') + return False + except Exception as e: + logging.error(f'mysql数据库连接失败由于{e}') + return False + + @staticmethod + def _connect_and_query(database_url): + try: + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + 
pool_recycle=300, + pool_pre_ping=True + ) + session = sessionmaker(bind=engine)() + session.execute(text("SELECT 1")) + session.close() + return True + except Exception as e: + raise e + + @staticmethod + async def drop_table(database_url, table_name): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with sessionmaker(engine)() as session: + sql_str = f"DROP TABLE IF EXISTS {table_name};" + session.execute(text(sql_str)) + + @staticmethod + async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword): + try: + url = urlparse(database_url) + db_config = { + 'host': url.hostname or 'localhost', + 'port': int(url.port or 3306), + 'user': url.username or 'root', + 'password': url.password or '', + 'db': url.path.strip('/') + } + + async with aiomysql.create_pool(**db_config) as pool: + async with pool.acquire() as conn: + async with conn.cursor() as cur: + primary_key_query = """ + SELECT + COLUMNS.column_name + FROM + information_schema.tables AS TABLES + INNER JOIN information_schema.columns AS COLUMNS ON TABLES.table_name = COLUMNS.table_name + WHERE + TABLES.table_schema = %s AND TABLES.table_name = %s AND COLUMNS.column_key = 'PRI'; + """ + + # 尝试执行查询 + await cur.execute(primary_key_query, (db_config['db'], table_name)) + primary_key_list = await cur.fetchall() + if not primary_key_list: + return [] + primary_key_names = ', '.join([record[0] for record in primary_key_list]) + columns = f'{primary_key_names}, {keyword}' + query = f'SELECT {columns} FROM {table_name};' + await cur.execute(query) + results = await cur.fetchall() + + def _process_results(results, primary_key_list): + tmp_dict = {} + for row in results: + key = str(row[-1]) + if key not in tmp_dict: + tmp_dict[key] = [] + pk_values = [str(row[i]) for i in range(len(primary_key_list))] + tmp_dict[key].append(pk_values) + + return { + 'primary_key_list': [record[0] for record in primary_key_list], + 
'keyword_value_dict': tmp_dict + } + result = await asyncio.get_event_loop().run_in_executor( + Mysql.executor, + _process_results, + results, + primary_key_list + ) + return result + + except Exception as e: + logging.error(f'mysql数据检索失败由于 {e}') + + @staticmethod + async def assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list): + sql_str = f'SELECT * FROM {table_name} where ' + for i in range(len(primary_key_list)): + sql_str += primary_key_list[i]+'= \''+primary_key_value_list[i]+'\'' + if i != len(primary_key_list)-1: + sql_str += ' and ' + sql_str += ';' + return sql_str + + @staticmethod + async def get_table_info(database_url, table_name): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with sessionmaker(engine)() as session: + sql_str = f"""SELECT TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table_name}';""" + table_note = session.execute(text(sql_str)).one()[0] + if table_note == '': + table_note = table_name + table_note = { + 'table_name': table_name, + 'table_note': table_note + } + return table_note + + @staticmethod + async def get_column_info(database_url, table_name): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with engine.connect() as conn: + sql_str = f""" + SELECT column_name, column_type, column_comment FROM information_schema.columns where TABLE_NAME='{table_name}'; + """ + results = conn.execute(text(sql_str), {'table_name': table_name}).all() + column_info_list = [] + for result in results: + column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]}) + return column_info_list + + @staticmethod + async def get_all_table_name_from_database_url(database_url): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with 
engine.connect() as connection: + result = connection.execute(text("SHOW TABLES")) + table_name_list = [row[0] for row in result] + return table_name_list + + @staticmethod + async def get_rand_data(database_url, table_name, cnt=10): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + try: + with sessionmaker(engine)() as session: + sql_str = f'''SELECT * + FROM {table_name} + ORDER BY RAND() + LIMIT {cnt};''' + dataframe = str(session.execute(text(sql_str)).all()) + except Exception as e: + dataframe = '' + logging.error(f'随机从数据库中获取数据失败由于{e}') + return dataframe + + @staticmethod + async def try_excute(database_url, sql_str): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with sessionmaker(engine)() as session: + result=session.execute(text(sql_str)).all() + return result \ No newline at end of file diff --git a/chat2db/app/base/postgres.py b/chat2db/app/base/postgres.py new file mode 100644 index 0000000000000000000000000000000000000000..11dc22f86fa1752cca991f0db29837c275eb76a2 --- /dev/null +++ b/chat2db/app/base/postgres.py @@ -0,0 +1,233 @@ +import asyncio +import asyncpg +import concurrent.futures +import logging +from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine, text +from concurrent.futures import ThreadPoolExecutor + +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + + +def handler(signum, frame): + raise TimeoutError("超时") + + +class Postgres(): + executor = ThreadPoolExecutor(max_workers=10) + + async def test_database_connection(database_url): + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(Postgres._connect_and_query, database_url) + result = future.result(timeout=0.1) + return result + except 
concurrent.futures.TimeoutError: + logging.error('postgres数据库连接超时') + return False + except Exception as e: + logging.error(f'postgres数据库连接失败由于{e}') + return False + + @staticmethod + def _connect_and_query(database_url): + try: + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + session = sessionmaker(bind=engine)() + session.execute(text("SELECT 1")) + session.close() + return True + except Exception as e: + raise e + + @staticmethod + async def drop_table(database_url, table_name): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with sessionmaker(engine)() as session: + sql_str = f"DROP TABLE IF EXISTS {table_name};" + session.execute(text(sql_str)) + + @staticmethod + async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword): + try: + dsn = database_url.replace('+psycopg2', '') + conn = await asyncpg.connect(dsn=dsn) + primary_key_query = """ + SELECT + kcu.column_name + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + WHERE + tc.constraint_type = 'PRIMARY KEY' + AND tc.table_name = $1; + """ + primary_key_list = await conn.fetch(primary_key_query, table_name) + if not primary_key_list: + return [] + columns = ', '.join([record['column_name'] for record in primary_key_list]) + f', {keyword}' + query = f'SELECT {columns} FROM {table_name};' + results = await conn.fetch(query) + + def _process_results(results, primary_key_list): + tmp_dict = {} + for row in results: + key = str(row[-1]) + if key not in tmp_dict: + tmp_dict[key] = [] + pk_values = [str(row[i]) for i in range(len(primary_key_list))] + tmp_dict[key].append(pk_values) + + return { + 'primary_key_list': [record['column_name'] for record in primary_key_list], + 'keyword_value_dict': tmp_dict + } + result = await 
asyncio.get_event_loop().run_in_executor( + Postgres.executor, + _process_results, + results, + primary_key_list + ) + await conn.close() + + return result + except Exception as e: + logging.error(f'postgres数据检索失败由于 {e}') + return None + + @staticmethod + async def assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list): + sql_str = f'SELECT * FROM {table_name} where ' + for i in range(len(primary_key_list)): + sql_str += primary_key_list[i]+'='+'\''+primary_key_value_list[i]+'\'' + if i != len(primary_key_list)-1: + sql_str += ' and ' + sql_str += ';' + return sql_str + + @staticmethod + async def get_table_info(database_url, table_name): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with engine.connect() as conn: + sql_str = """ + SELECT + d.description AS table_description + FROM + pg_class t + JOIN + pg_description d ON t.oid = d.objoid + WHERE + t.relkind = 'r' AND + d.objsubid = 0 AND + t.relname = :table_name; """ + table_note = conn.execute(text(sql_str), {'table_name': table_name}).one()[0] + if table_note == '': + table_note = table_name + table_note = { + 'table_name': table_name, + 'table_note': table_note + } + return table_note + + @staticmethod + async def get_column_info(database_url, table_name): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with engine.connect() as conn: + sql_str = """ + SELECT + a.attname as 字段名, + format_type(a.atttypid,a.atttypmod) as 类型, + col_description(a.attrelid,a.attnum) as 注释 + FROM + pg_class as c,pg_attribute as a + where + a.attrelid = c.oid + and + a.attnum>0 + and + c.relname = :table_name; + """ + results = conn.execute(text(sql_str), {'table_name': table_name}).all() + column_info_list = [] + for result in results: + column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]}) + 
return column_info_list + + @staticmethod + async def get_all_table_name_from_database_url(database_url): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with engine.connect() as connection: + sql_str = ''' + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public'; + ''' + result = connection.execute(text(sql_str)) + table_name_list = [row[0] for row in result] + return table_name_list + + @staticmethod + async def get_rand_data(database_url, table_name, cnt=10): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + try: + with sessionmaker(engine)() as session: + sql_str = f'''SELECT * + FROM {table_name} + ORDER BY RANDOM() + LIMIT {cnt};''' + dataframe = str(session.execute(text(sql_str)).all()) + except Exception as e: + dataframe = '' + logging.error(f'随机从数据库中获取数据失败由于{e}') + return dataframe + + @staticmethod + async def try_excute(database_url, sql_str): + engine = create_engine( + database_url, + pool_size=20, + max_overflow=80, + pool_recycle=300, + pool_pre_ping=True + ) + with sessionmaker(engine)() as session: + result=session.execute(text(sql_str)).all() + return result diff --git a/chat2db/app/base/vectorize.py b/chat2db/app/base/vectorize.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe4af3fe090d2ccae66b6ea61121fa4931cebc --- /dev/null +++ b/chat2db/app/base/vectorize.py @@ -0,0 +1,46 @@ +import requests +import urllib3 +from chat2db.config.config import config +import json +import logging + +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +class Vectorize(): + @staticmethod + async def vectorize_embedding(text): + if config['EMBEDDING_TYPE']=='openai': + headers = { + "Authorization": 
f"Bearer {config['EMBEDDING_API_KEY']}" + } + data = { + "input": text, + "model": config["EMBEDDING_MODEL_NAME"], + "encoding_format": "float" + } + try: + res = requests.post(url=config["EMBEDDING_ENDPOINT"],headers=headers, json=data, verify=False) + if res.status_code != 200: + return None + return res.json()['data'][0]['embedding'] + except Exception as e: + logging.error(f"Embedding error failed due to: {e}") + return None + elif config['EMBEDDING_TYPE'] =='mindie': + try: + data = { + "inputs": text, + } + res = requests.post(url=config["EMBEDDING_ENDPOINT"], json=data, verify=False) + if res.status_code != 200: + return None + return json.loads(res.text)[0] + except Exception as e: + logging.error(f"Embedding error failed due to: {e}") + return None + else: + return None diff --git a/chat2db/app/router/database.py b/chat2db/app/router/database.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9a3f735bbf5fa227970406c717f1a0dc121edd --- /dev/null +++ b/chat2db/app/router/database.py @@ -0,0 +1,114 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +import logging +import uuid +from fastapi import APIRouter, status + +from chat2db.model.request import DatabaseAddRequest, DatabaseDelRequest +from chat2db.model.response import ResponseData +from chat2db.manager.database_info_manager import DatabaseInfoManager +from chat2db.app.service.diff_database_service import DiffDatabaseService + +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + +router = APIRouter( + prefix="/database" +) + + +@router.post("/add", response_model=ResponseData) +async def add_database_info(request: DatabaseAddRequest): + database_url = request.database_url + if 'mysql' not in database_url and 'postgres' not in database_url: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="不支持当前数据库", + result={} + ) + database_type = 'postgres' + if 'mysql' in database_url: + database_type = 'mysql' + flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="无法连接当前数据库", + result={} + ) + database_id = await DatabaseInfoManager.add_database(database_url) + if database_id is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="数据库连接添加失败,当前存在重复数据库配置", + result={} + ) + return ResponseData( + code=status.HTTP_200_OK, + message="success", + result={'database_id': database_id} + ) + + +@router.post("/del", response_model=ResponseData) +async def del_database_info(request: DatabaseDelRequest): + database_id = request.database_id + flag = await DatabaseInfoManager.del_database_by_id(database_id) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="删除数据库配置失败,数据库配置不存在", + result={} + ) + return ResponseData( + code=status.HTTP_200_OK, + message="删除数据库配置成功", + result={} + ) + + +@router.get("/query", response_model=ResponseData) 
+async def query_database_info(): + database_info_list = await DatabaseInfoManager.get_all_database_info() + return ResponseData( + code=status.HTTP_200_OK, + message="查询数据库配置成功", + result={'database_info_list': database_info_list} + ) + + +@router.get("/list", response_model=ResponseData) +async def list_table_in_database(database_id: uuid.UUID, table_filter: str = ''): + database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + if database_url is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="查询数据库内表格配置失败,数据库配置不存在", + result={} + ) + if 'mysql' not in database_url and 'postgres' not in database_url: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="不支持当前数据库", + result={} + ) + database_type = 'postgres' + if 'mysql' in database_url: + database_type = 'mysql' + flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="无法连接当前数据库", + result={} + ) + table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url) + results = [] + for table_name in table_name_list: + if table_filter in table_name: + results.append(table_name) + return ResponseData( + code=status.HTTP_200_OK, + message="查询数据库配置成功", + result={'table_name_list': results} + ) diff --git a/chat2db/app/router/sql_example.py b/chat2db/app/router/sql_example.py new file mode 100644 index 0000000000000000000000000000000000000000..5b794e25e53459c90d6d666756a7d7066adc6124 --- /dev/null +++ b/chat2db/app/router/sql_example.py @@ -0,0 +1,136 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +import logging +import uuid +from fastapi import APIRouter, status + +from chat2db.model.request import SqlExampleAddRequest, SqlExampleDelRequest, SqlExampleUpdateRequest, SqlExampleGenerateRequest +from chat2db.model.response import ResponseData +from chat2db.manager.database_info_manager import DatabaseInfoManager +from chat2db.manager.table_info_manager import TableInfoManager +from chat2db.manager.sql_example_manager import SqlExampleManager +from chat2db.app.service.sql_generate_service import SqlGenerateService +from chat2db.app.base.vectorize import Vectorize +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + +router = APIRouter( + prefix="/sql/example" +) + + +@router.post("/add", response_model=ResponseData) +async def add_sql_example(request: SqlExampleAddRequest): + table_id = request.table_id + table_info = await TableInfoManager.get_table_info_by_table_id(table_id) + if table_info is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="表格不存在", + result={} + ) + database_id = table_info['database_id'] + question = request.question + question_vector = await Vectorize.vectorize_embedding(question) + sql = request.sql + try: + sql_example_id = await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector) + except Exception as e: + logging.error(f'sql案例添加失败由于{e}') + return ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="sql案例添加失败", + result={} + ) + return ResponseData( + code=status.HTTP_200_OK, + message="success", + result={'sql_example_id': sql_example_id} + ) + + +@router.post("/del", response_model=ResponseData) +async def del_sql_example(request: SqlExampleDelRequest): + sql_example_id = request.sql_example_id + flag = await SqlExampleManager.del_sql_example_by_id(sql_example_id) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="sql案例不存在", + 
result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="sql案例删除成功",
+        result={}
+    )
+
+
+@router.get("/query", response_model=ResponseData)
+async def query_sql_example(table_id: uuid.UUID):
+    sql_example_list = await SqlExampleManager.query_sql_example_by_table_id(table_id)
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="查询sql案例成功",
+        result={'sql_example_list': sql_example_list}
+    )
+
+
+@router.post("/update", response_model=ResponseData)
+async def update_sql_example(request: SqlExampleUpdateRequest):
+    sql_example_id = request.sql_example_id
+    question = request.question
+    question_vector = await Vectorize.vectorize_embedding(question)
+    sql = request.sql
+    flag = await SqlExampleManager.update_sql_example_by_id(sql_example_id, question, sql, question_vector)
+    if not flag:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="sql案例不存在",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="sql案例更新成功",
+        result={}
+    )
+
+
+@router.post("/generate", response_model=ResponseData)
+async def generate_sql_example(request: SqlExampleGenerateRequest):
+    table_id = request.table_id
+    generate_cnt = request.generate_cnt
+    table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
+    if table_info is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="表格不存在",
+            result={}
+        )
+    table_name = table_info['table_name']
+    database_id = table_info['database_id']
+    database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+    sql_var = request.sql_var
+    sql_example_list = []
+    for i in range(generate_cnt):
+        try:
+            tmp_dict = await SqlGenerateService.generate_sql_base_on_data(database_url, table_name, sql_var)
+        except Exception as e:
+            logging.error(f'sql案例生成失败由于{e}')
+            continue
+        if tmp_dict is None:
+            continue
+        question = tmp_dict['question']
+        question_vector = await Vectorize.vectorize_embedding(question)
+        sql = tmp_dict['sql']
+        
await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector) + tmp_dict['database_id'] = database_id + tmp_dict['table_id'] = table_id + sql_example_list.append(tmp_dict) + return ResponseData( + code=status.HTTP_200_OK, + message="sql案例生成成功", + result={ + 'sql_example_list': sql_example_list + } + ) diff --git a/chat2db/app/router/sql_generate.py b/chat2db/app/router/sql_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..2a09e52de54c108eb2f73aa2b418f7744f256fb3 --- /dev/null +++ b/chat2db/app/router/sql_generate.py @@ -0,0 +1,129 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +import logging +from fastapi import APIRouter, status + +from chat2db.manager.database_info_manager import DatabaseInfoManager +from chat2db.manager.table_info_manager import TableInfoManager +from chat2db.manager.column_info_manager import ColumnInfoManager +from chat2db.model.request import SqlGenerateRequest, SqlRepairRequest, SqlExcuteRequest +from chat2db.model.response import ResponseData +from chat2db.app.service.sql_generate_service import SqlGenerateService +from chat2db.app.service.keyword_service import keyword_service +from chat2db.app.service.diff_database_service import DiffDatabaseService +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + +router = APIRouter( + prefix="/sql" +) + + +@router.post("/generate", response_model=ResponseData) +async def generate_sql(request: SqlGenerateRequest): + database_id = request.database_id + database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + table_id_list = request.table_id_list + question = request.question + use_llm_enhancements = request.use_llm_enhancements + if table_id_list == []: + table_id_list = None + results = {} + sql_list = await SqlGenerateService.generate_sql_base_on_exmpale( + database_id=database_id, 
question=question, table_id_list=table_id_list, + use_llm_enhancements=use_llm_enhancements) + try: + sql_list += await keyword_service.generate_sql(question, database_id, table_id_list) + results['sql_list'] = sql_list[:request.topk] + results['database_url'] = database_url + except Exception as e: + logging.error(f'sql生成失败由于{e}') + return ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="sql生成失败", + result={} + ) + return ResponseData( + code=status.HTTP_200_OK, message="success", + result=results + ) + + +@router.post("/repair", response_model=ResponseData) +async def repair_sql(request: SqlRepairRequest): + database_id = request.database_id + table_id = request.table_id + database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + if database_url is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="当前数据库配置不存在", + result={} + ) + database_type = 'postgres' + if 'mysql' in database_url: + database_type = 'mysql' + table_info = await TableInfoManager.get_table_info_by_table_id(table_id) + if table_info is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="表格不存在", + result={} + ) + if table_info['database_id'] != database_id: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="表格不属于当前数据库", + result={} + ) + column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id) + sql = request.sql + message = request.message + question = request.question + try: + sql = await SqlGenerateService.repair_sql(database_type, table_info, column_info_list, sql, message, question) + except Exception as e: + logging.error(f'sql修复失败由于{e}') + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="sql修复失败", + result={} + ) + return ResponseData( + code=status.HTTP_200_OK, + message="sql修复成功", + result={'database_id': database_id, + 'table_id': table_id, + 'sql': sql} + ) + + +@router.post("/execute", 
response_model=ResponseData) +async def execute_sql(request: SqlExcuteRequest): + database_id = request.database_id + sql = request.sql + database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + if database_url is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="当前数据库配置不存在", + result={} + ) + database_type = 'postgres' + if 'mysql' in database_url: + database_type = 'mysql' + try: + results = await DiffDatabaseService.database_map[database_type].try_excute(database_url, sql) + results = str(results) + except Exception as e: + logging.error(f'sql执行失败由于{e}') + return ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="sql执行失败", + result={'Error':str(e)} + ) + return ResponseData( + code=status.HTTP_200_OK, + message="sql执行成功", + result={'results':results} + ) diff --git a/chat2db/app/router/table.py b/chat2db/app/router/table.py new file mode 100644 index 0000000000000000000000000000000000000000..85ad02b143913ca75bddd0a677bb9792663ec582 --- /dev/null +++ b/chat2db/app/router/table.py @@ -0,0 +1,148 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +import logging +import uuid +from fastapi import APIRouter, status + +from chat2db.model.request import TableAddRequest, TableDelRequest, EnableColumnRequest +from chat2db.model.response import ResponseData +from chat2db.manager.database_info_manager import DatabaseInfoManager +from chat2db.manager.table_info_manager import TableInfoManager +from chat2db.manager.column_info_manager import ColumnInfoManager +from chat2db.app.service.diff_database_service import DiffDatabaseService +from chat2db.app.base.vectorize import Vectorize +from chat2db.app.service.keyword_service import keyword_service +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + +router = APIRouter( + prefix="/table" +) + + +@router.post("/add", response_model=ResponseData) +async def add_database_info(request: TableAddRequest): + database_id = request.database_id + database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + if database_url is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="当前数据库配置不存在", + result={} + ) + database_type = 'postgres' + if 'mysql' in database_url: + database_type = 'mysql' + flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="无法连接当前数据库", + result={} + ) + table_name = request.table_name + table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url) + if table_name not in table_name_list: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="表格不存在", + result={} + ) + tmp_dict = await DiffDatabaseService.get_database_service(database_type).get_table_info(database_url, table_name) + table_note = tmp_dict['table_note'] + table_note_vector = await 
Vectorize.vectorize_embedding(table_note) + table_id = await TableInfoManager.add_table_info(database_id, table_name, table_note, table_note_vector) + if table_id is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="表格添加失败,当前存在重复表格", + result={} + ) + column_info_list = await DiffDatabaseService.get_database_service(database_type).get_column_info(database_url, table_name) + for column_info in column_info_list: + await ColumnInfoManager.add_column_info_with_table_id( + table_id, column_info['column_name'], + column_info['column_type'], + column_info['column_note']) + return ResponseData( + code=status.HTTP_200_OK, + message="success", + result={'table_id': table_id} + ) + + +@router.post("/del", response_model=ResponseData) +async def del_table_info(request: TableDelRequest): + table_id = request.table_id + flag = await TableInfoManager.del_table_by_id(table_id) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="表格不存在", + result={} + ) + return ResponseData( + code=status.HTTP_200_OK, + message="删除表格成功", + result={} + ) + + +@router.get("/query", response_model=ResponseData) +async def query_table_info(database_id: uuid.UUID): + database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + if database_url is None: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="当前数据库配置不存在", + result={} + ) + table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) + return ResponseData( + code=status.HTTP_200_OK, + message="查询表格成功", + result={'table_info_list': table_info_list} + ) + + +@router.get("/column/query", response_model=ResponseData) +async def query_column(table_id: uuid.UUID): + column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id) + return ResponseData( + code=status.HTTP_200_OK, + message="", + result={'column_info_list': column_info_list} + ) + + +@router.post("/column/enable", 
response_model=ResponseData) +async def enable_column(request: EnableColumnRequest): + column_id = request.column_id + enable = request.enable + flag = await ColumnInfoManager.update_column_info_enable(column_id, enable) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="列不存在", + result={} + ) + column_info = await ColumnInfoManager.get_column_info_by_column_id(column_id) + column_name = column_info['column_name'] + table_id = column_info['table_id'] + table_info = await TableInfoManager.get_table_info_by_table_id(table_id) + database_id = table_info['database_id'] + if enable: + flag = await keyword_service.add(database_id, table_id, column_name) + else: + flag = await keyword_service.del_by_column_name(database_id, table_id, column_name) + if not flag: + return ResponseData( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + message="列关键字功能开启/关闭失败", + result={} + ) + return ResponseData( + code=status.HTTP_200_OK, + message="列关键字功能开启/关闭成功", + result={} + ) diff --git a/chat2db/app/service/diff_database_service.py b/chat2db/app/service/diff_database_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe77b0b2eac045d3645f6d143ed1619060127e3 --- /dev/null +++ b/chat2db/app/service/diff_database_service.py @@ -0,0 +1,10 @@ +from chat2db.app.base.mysql import Mysql +from chat2db.app.base.postgres import Postgres + + +class DiffDatabaseService(): + database_map = {"mysql": Mysql, "postgres": Postgres} + + @staticmethod + def get_database_service(database_type): + return DiffDatabaseService.database_map[database_type] diff --git a/chat2db/app/service/keyword_service.py b/chat2db/app/service/keyword_service.py new file mode 100644 index 0000000000000000000000000000000000000000..335410a8f3ff0155c7eeee7a3b54b77699fe2e68 --- /dev/null +++ b/chat2db/app/service/keyword_service.py @@ -0,0 +1,134 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+import asyncio +import copy +import uuid +import threading +from concurrent.futures import ThreadPoolExecutor +from chat2db.app.service.diff_database_service import DiffDatabaseService +from chat2db.app.base.ac_automation import DictTree +from chat2db.manager.database_info_manager import DatabaseInfoManager +from chat2db.manager.table_info_manager import TableInfoManager +from chat2db.manager.column_info_manager import ColumnInfoManager +import logging + +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + + +class KeywordManager(): + def __init__(self): + self.keyword_asset_dict = {} + self.lock = threading.Lock() + self.data_frame_dict = {} + + async def load_keywords(self): + database_info_list = await DatabaseInfoManager.get_all_database_info() + for database_info in database_info_list: + database_id = database_info['database_id'] + table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) + cnt=0 + for table_info in table_info_list: + table_id = table_info['table_id'] + column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id, True) + for i in range(len(column_info_list)): + column_info = column_info_list[i] + cnt+=1 + try: + column_name = column_info['column_name'] + await self.add(database_id, table_id, column_name) + except Exception as e: + logging.error('关键字数据结构生成失败') + def add_excutor(self, rd_id, database_id, table_id, table_info, column_info_list, column_name): + tmp_dict = self.data_frame_dict[rd_id] + tmp_dict_tree = DictTree() + tmp_dict_tree.load_data(tmp_dict['keyword_value_dict']) + if database_id not in self.keyword_asset_dict.keys(): + self.keyword_asset_dict[database_id] = {} + with self.lock: + if table_id not in self.keyword_asset_dict[database_id].keys(): + self.keyword_asset_dict[database_id][table_id] = {} + self.keyword_asset_dict[database_id][table_id]['table_info'] = table_info + 
self.keyword_asset_dict[database_id][table_id]['column_info_list'] = column_info_list
+                self.keyword_asset_dict[database_id][table_id]['primary_key_list'] = copy.deepcopy(
+                    tmp_dict['primary_key_list'])
+                self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'] = {}
+            self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name] = tmp_dict_tree
+            del self.data_frame_dict[rd_id]
+
+    async def add(self, database_id, table_id, column_name):
+        database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+        database_type = 'postgres'
+        if 'mysql' in database_url:
+            database_type = 'mysql'
+        table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
+        table_name = table_info['table_name']
+        tmp_dict = await DiffDatabaseService.get_database_service(
+            database_type).select_primary_key_and_keyword_from_table(database_url, table_name, column_name)
+        if tmp_dict is None:
+            return
+        # BUGFIX: uuid.uuid4 must be CALLED; str(uuid.uuid4) stringifies the function
+        # object itself, so every invocation produced the same key and concurrent
+        # add() calls collided on self.data_frame_dict.
+        rd_id = str(uuid.uuid4())
+        self.data_frame_dict[rd_id] = tmp_dict
+        del database_url
+        column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
+        try:
+            thread = threading.Thread(target=self.add_excutor, args=(rd_id, database_id, table_id,
+                                                                     table_info, column_info_list, column_name,))
+            thread.start()
+        except Exception as e:
+            logging.error(f'创建增加线程失败由于{e}')
+            return False
+        return True
+
+    async def update_keyword_asset(self):
+        # BUGFIX: the three manager calls below are coroutine functions (they are
+        # awaited everywhere else in this class); without `await` they return
+        # coroutine objects and the for-loops raise TypeError at runtime.
+        database_info_list = await DatabaseInfoManager.get_all_database_info()
+        for database_info in database_info_list:
+            database_id = database_info['database_id']
+            table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
+            for table_info in table_info_list:
+                table_id = table_info['table_id']
+                column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id, True)
+                for column_info in column_info_list:
+                    await self.add(database_id, table_id, column_info['column_name'])
+
+    async def del_by_column_name(self, database_id, table_id, column_name):
+        try:
+            with self.lock:
+                if database_id in self.keyword_asset_dict.keys():
+                    if table_id in self.keyword_asset_dict[database_id].keys():
+                        if column_name in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].keys():
+                            del self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name]
+        except Exception as e:
+            logging.error(f'字典树删除失败由于{e}')
+            return False
+        return True
+
+    async def generate_sql(self, question, database_id, table_id_list=None):
+        with self.lock:
+            results = []
+            if database_id in self.keyword_asset_dict.keys():
+                database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+                database_type = 'postgres'
+                if 'mysql' in database_url:
+                    database_type = 'mysql'
+                for table_id in self.keyword_asset_dict[database_id].keys():
+                    if table_id_list is None or table_id in table_id_list:
+                        table_info = self.keyword_asset_dict[database_id][table_id]['table_info']
+                        primary_key_list = self.keyword_asset_dict[database_id][table_id]['primary_key_list']
+                        primary_key_value_list = []
+                        try:
+                            for dict_tree in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].values():
+                                primary_key_value_list += dict_tree.get_results(question)
+                        except Exception as e:
+                            logging.error(f'从字典树中获取结果失败由于{e}')
+                            continue
+                        for i in range(len(primary_key_value_list)):
+                            sql_str = await DiffDatabaseService.get_database_service(database_type).assemble_sql_query_base_on_primary_key(
+                                table_info['table_name'], primary_key_list, primary_key_value_list[i])
+                            tmp_dict = {'database_id': database_id, 'table_id': table_id, 'sql': sql_str}
+                            results.append(tmp_dict)
+                del database_url
+        return results
+
+
+keyword_service = KeywordManager()
+asyncio.run(keyword_service.load_keywords())
diff --git a/chat2db/app/service/sql_generate_service.py b/chat2db/app/service/sql_generate_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..e02e430ecf968f7e2d050b9f68c94030d0c83f61
--- /dev/null
+++ b/chat2db/app/service/sql_generate_service.py
@@ -0,0
+1,351 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import yaml +import re +import json +import random +import uuid +import logging +from pandas.core.api import DataFrame as DataFrame + +from chat2db.manager.database_info_manager import DatabaseInfoManager +from chat2db.manager.table_info_manager import TableInfoManager +from chat2db.manager.column_info_manager import ColumnInfoManager +from chat2db.manager.sql_example_manager import SqlExampleManager +from chat2db.app.service.diff_database_service import DiffDatabaseService +from chat2db.llm.chat_with_model import LLM +from chat2db.config.config import config +from chat2db.app.base.vectorize import Vectorize + + +logging.basicConfig(filename='app.log', level=logging.INFO, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + + +class SqlGenerateService(): + + @staticmethod + async def merge_table_and_column_info(table_info, column_info_list): + table_name = table_info.get('table_name', '') + table_note = table_info.get('table_note', '') + note = '\n' + note += '\n'+'\n'+'\n' + note += '\n'+f'\n'+'\n' + note += '\n'+'\n'+'\n' + note += '\n'+f'\n'+'\n' + note += '\n'+' \n\n\n'+'\n' + for column_info in column_info_list: + column_name = column_info.get('column_name', '') + column_type = column_info.get('column_type', '') + column_note = column_info.get('column_note', '') + note += '\n'+f' \n\n\n'+'\n' + note += '
表名
{table_name}
表的注释
{table_note}
字段字段类型字段注释
{column_name}{column_type}{column_note}
' + return note + + @staticmethod + def extract_list_statements(list_string): + pattern = r'\[.*?\]' + matches = re.findall(pattern, list_string) + if len(matches) == 0: + return '' + tmp = matches[0] + tmp.replace('\'', '\"') + tmp.replace(',', ',') + return tmp + + @staticmethod + async def get_most_similar_table_id_list(database_id, question, table_choose_cnt): + table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) + random.shuffle(table_info_list) + table_id_set = set() + for table_info in table_info_list: + table_id = table_info['table_id'] + table_id_set.add(str(table_id)) + try: + with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + prompt = prompt_dict.get('table_choose_prompt', '') + table_entries = '\n' + table_entries += '\n'+' \n\n'+'\n' + token_upper = 2048 + for table_info in table_info_list: + table_id = table_info['table_id'] + table_note = table_info['table_note'] + if len(table_entries) + len( + '\n' + f' \n\n' + '\n') > token_upper: + break + table_entries += '\n'+f' \n\n'+'\n' + table_entries += '
主键表注释
{table_id}{table_note}
{table_id}{table_note}
' + prompt = prompt.format(table_cnt=table_choose_cnt, table_entries=table_entries, question=question) + logging.info(f'在大模型增强模式下,选择表的prompt构造成功:{prompt}') + except Exception as e: + logging.error(f'在大模型增强模式下,选择表的prompt构造失败由于:{e}') + return [] + try: + llm = LLM(model_name=config['LLM_MODEL'], + openai_api_base=config['LLM_URL'], + openai_api_key=config['LLM_KEY'], + max_tokens=config['LLM_MAX_TOKENS'], + request_timeout=60, + temperature=0.5) + except Exception as e: + llm = None + logging.error(f'在大模型增强模式下,选择表的过程中,与大模型建立连接失败由于:{e}') + table_id_list = [] + if llm is not None: + for i in range(2): + content = await llm.chat_with_model(prompt, '请输包含选择表主键的列表') + try: + sub_table_id_list = json.loads(SqlGenerateService.extract_list_statements(content)) + except: + sub_table_id_list = [] + for j in range(len(sub_table_id_list)): + if sub_table_id_list[j] in table_id_set and uuid.UUID(sub_table_id_list[j]) not in table_id_list: + table_id_list.append(uuid.UUID(sub_table_id_list[j])) + if len(table_id_list) < table_choose_cnt: + table_choose_cnt -= len(table_id_list) + for i in range(min(table_choose_cnt, len(table_info_list))): + table_id = table_info_list[i]['table_id'] + if table_id is not None and table_id not in table_id_list: + table_id_list.append(table_id) + return table_id_list + + @staticmethod + async def find_most_similar_sql_example( + database_id, table_id_list, question, use_llm_enhancements=False, table_choose_cnt=2, sql_example_choose_cnt=10, + topk=5): + try: + database_url = await DatabaseInfoManager.get_database_url_by_id(database_id) + except Exception as e: + logging.error(f'数据库{database_id}信息获取失败由于{e}') + return [] + database_type = 'psotgres' + if 'mysql' in database_url: + database_type = 'mysql' + del database_url + try: + question_vector = await Vectorize.vectorize_embedding(question) + except Exception as e: + logging.error(f'问题向量化失败由于:{e}') + return {} + sql_example = [] + data_frame_list = [] + if table_id_list is None: + if 
use_llm_enhancements: + table_id_list = await SqlGenerateService.get_most_similar_table_id_list(database_id, question, table_choose_cnt) + else: + try: + table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id) + table_id_list = [] + for table_info in table_info_list: + table_id_list.append(table_info['table_id']) + sql_example_list = await SqlExampleManager.get_topk_sql_example_by_cos_dis( + question_vector=question_vector, + table_id_list=table_id_list, topk=table_choose_cnt * 2) + table_id_list = [] + for sql_example in sql_example_list: + table_id_list.append(sql_example['table_id']) + except Exception as e: + logging.error(f'非增强模式下,表id获取失败由于:{e}') + return [] + table_id_list = list(set(table_id_list)) + if len(table_id_list) < table_choose_cnt: + try: + expand_table_id_list = await TableInfoManager.get_topk_table_by_cos_dis( + database_id, question_vector, table_choose_cnt - len(table_id_list)) + table_id_list += expand_table_id_list + except Exception as e: + logging.error(f'非增强模式下,表id补充失败由于:{e}') + exist_table_id = set() + note_list = [] + for i in range(min(2, len(table_id_list))): + table_id = table_id_list[i] + if table_id in exist_table_id: + continue + exist_table_id.add(table_id) + try: + table_info = await TableInfoManager.get_table_info_by_table_id(table_id) + column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id) + except Exception as e: + logging.error(f'表{table_id}注释获取失败由于{e}') + note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list) + note_list.append(note) + try: + sql_example_list = await SqlExampleManager.get_topk_sql_example_by_cos_dis( + question_vector, table_id_list=[table_id], topk=sql_example_choose_cnt) + except Exception as e: + sql_example_list = [] + logging.error(f'id为{table_id}的表的最相近的{topk}条sql案例获取失败由于:{e}') + question_sql_list = [] + for i in range(len(sql_example_list)): + question_sql_list.append( + {'question': 
sql_example_list[i]['question'],
+                 'sql': sql_example_list[i]['sql']})
+            data_frame_list.append({'table_id': table_id, 'table_info': table_info,
+                                    'column_info_list': column_info_list, 'sql_example_list': question_sql_list})
+        return data_frame_list
+
+    @staticmethod
+    async def megre_sql_example(sql_example_list):
+        sql_example = ''
+        for i in range(len(sql_example_list)):
+            sql_example += '问题'+str(i)+':\n'+sql_example_list[i].get('question',
+                                                                     '')+'\nsql'+str(i)+':\n'+sql_example_list[i].get('sql', '')+'\n'
+        return sql_example
+
+    @staticmethod
+    async def extract_select_statements(sql_string):
+        pattern = r"(?i)select[^;]*;"
+        matches = re.findall(pattern, sql_string)
+        if len(matches) == 0:
+            return ''
+        sql = matches[0]
+        sql = sql.strip()
+        # BUGFIX: str.replace returns a new string (strings are immutable); the
+        # original discarded the result, so full-width commas were never normalized.
+        sql = sql.replace(',', ',')
+        return sql
+
+    @staticmethod
+    async def generate_sql_base_on_exmpale(
+            database_id, question, table_id_list=None, sql_generate_cnt=1, use_llm_enhancements=False):
+        try:
+            database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+        except Exception as e:
+            logging.error(f'数据库{database_id}信息获取失败由于{e}')
+            return {}
+        if database_url is None:
+            raise Exception('数据库配置不存在')
+        # BUGFIX: was 'psotgres' — a nonexistent key in DiffDatabaseService.database_map.
+        database_type = 'postgres'
+        if 'mysql' in database_url:
+            database_type = 'mysql'
+        data_frame_list = await SqlGenerateService.find_most_similar_sql_example(database_id, table_id_list, question, use_llm_enhancements)
+        logging.info(f'问题{question}关联到的表信息如下{str(data_frame_list)}')
+        try:
+            with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f:
+                prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
+            llm = LLM(model_name=config['LLM_MODEL'],
+                      openai_api_base=config['LLM_URL'],
+                      openai_api_key=config['LLM_KEY'],
+                      max_tokens=config['LLM_MAX_TOKENS'],
+                      request_timeout=60,
+                      temperature=0.5)
+            results = []
+            for data_frame in data_frame_list:
+                prompt = prompt_dict.get('sql_generate_base_on_example_prompt', '')
+                table_info = data_frame.get('table_info', '')
+                table_id = table_info['table_id']
column_info_list = data_frame.get('column_info_list', '') + note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list) + sql_example = await SqlGenerateService.megre_sql_example(data_frame.get('sql_example_list', [])) + try: + prompt = prompt.format( + database_url=database_url, note=note, k=len(data_frame.get('sql_example_list', [])), + sql_example=sql_example, question=question) + except Exception as e: + logging.info(f'sql生成失败{e}') + return [] + ge_cnt=0 + ge_sql_cnt=0 + while ge_cnt<10*sql_generate_cnt and ge_sql_cnt 1 or count_char(question, '?') > 1: + continue + except Exception as e: + logging.error(f'问题生成失败由于{e}') + continue + try: + with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + prompt = prompt_dict['sql_generate_base_on_data_prompt'].format( + database_type=database_type, + note=note, data_frame=data_frame, question=question) + sql = await llm.chat_with_model(prompt, f'请输出一条可以用于查询{database_type}的sql,要以分号结尾') + sql = await SqlGenerateService.extract_select_statements(sql) + if not sql: + continue + except Exception as e: + logging.error(f'sql生成失败由于{e}') + continue + try: + if sql_var: + await DiffDatabaseService.get_database_service(database_type).try_excute(database_url,sql) + except Exception as e: + logging.error(f'生成的sql执行失败由于{e}') + continue + return { + 'question': question, + 'sql': sql + } + return None + + @staticmethod + async def repair_sql(database_type, table_info, column_info_list, sql_failed, sql_failed_message, question): + try: + with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + llm = LLM(model_name=config['LLM_MODEL'], + openai_api_base=config['LLM_URL'], + openai_api_key=config['LLM_KEY'], + max_tokens=config['LLM_MAX_TOKENS'], + request_timeout=60, + temperature=0.5) + try: + note = await SqlGenerateService.merge_table_and_column_info(table_info, 
column_info_list) + prompt = prompt_dict.get('sql_expand_prompt', '') + prompt = prompt.format( + database_type=database_type, note=note, sql_failed=sql_failed, + sql_failed_message=sql_failed_message, + question=question) + except Exception as e: + logging.error(f'sql修复失败由于{e}') + return '' + sql = await llm.chat_with_model(prompt, f'请输出一条在与{database_type}下能运行的sql,要以分号结尾') + sql = await SqlGenerateService.extract_select_statements(sql) + logging.info(f"修复前的sql为{sql_failed}修复后的sql为{sql}") + except Exception as e: + logging.error(f'sql生成失败由于:{e}') + return '' + return sql diff --git a/chat2db/common/.env.example b/chat2db/common/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..9011268f306402ccb853ecca76d219888d11f03c --- /dev/null +++ b/chat2db/common/.env.example @@ -0,0 +1,26 @@ +# FastAPI +UVICORN_IP= +UVICORN_PORT= +SSL_CERTFILE= +SSL_KEYFILE= +SSL_ENABLE= + +# Postgres +DATABASE_URL= + +# QWEN +LLM_KEY= +LLM_URL= +LLM_MAX_TOKENS= +LLM_MODEL= + +# Vectorize +EMBEDDING_TYPE= +EMBEDDING_API_KEY= +EMBEDDING_ENDPOINT= +EMBEDDING_MODEL_NAME= + +# security +HALF_KEY1= +HALF_KEY2= +HALF_KEY3= \ No newline at end of file diff --git a/chat2db/common/init_sql_example.py b/chat2db/common/init_sql_example.py new file mode 100644 index 0000000000000000000000000000000000000000..4288dfa8256098db5bb69204e4f80b03d11678f5 --- /dev/null +++ b/chat2db/common/init_sql_example.py @@ -0,0 +1,29 @@ +import yaml +import requests +chat2db_url='http://0.0.0.0:9015' +with open('table_name_id.yaml') as f: + table_name_id=yaml.load(f,Loader=yaml.SafeLoader) +with open('table_name_sql_exmple.yaml') as f: + table_name_sql_example_list=yaml.load(f,Loader=yaml.SafeLoader) +for table_name_sql_example in table_name_sql_example_list: + table_name=table_name_sql_example['table_name'] + if table_name not in table_name_id: + continue + table_id=table_name_id[table_name] + sql_example_list=table_name_sql_example['sql_example_list'] + for sql_example in sql_example_list: 
+ request_data = { + "table_id": str(table_id), + "question": sql_example['question'], + "sql": sql_example['sql'] + } + url = f"{chat2db_url}/sql/example/add" # 请替换为实际的 API 域名 + + try: + response = requests.post(url, json=request_data) + if response.status_code!=200: + print(f'添加sql案例失败{response.text}') + else: + print(f'添加sql案例成功{response.text}') + except Exception as e: + print(f'添加sql案例失败由于{e}') \ No newline at end of file diff --git a/chat2db/common/table_name_id.yaml b/chat2db/common/table_name_id.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22d04d50455fb44439d9bc34dac84ffe409dae8a --- /dev/null +++ b/chat2db/common/table_name_id.yaml @@ -0,0 +1,10 @@ +oe_community_openeuler_version: 5ffcd194-b895-4d3c-a3ad-c4af4a270d6a +oe_community_organization_structure: 89cd2fd0-a5ca-4ee8-8bfe-44da86a32208 +oe_compatibility_card: fb4dbdda-88c5-482a-bec8-28d51669284e +oe_compatibility_commercial_software: 86ab7dad-4848-48da-8667-72be6c99780f +oe_compatibility_cve_database: 984d1c82-c6d5-4d3d-93d9-8d5bc11254ba +oe_compatibility_oepkgs: bb2698c7-f715-487f-95c4-1061a9c33851 +oe_compatibility_osv: 8c9f6608-e2e2-475d-8f67-b3bdab3e9234 +oe_compatibility_overall_unit: 82b9521a-c924-4c52-aedc-229bca5ea4c0 +oe_compatibility_security_notice: f0fa7cc5-4e5d-4c69-b202-39d522f18383 +oe_compatibility_solution: 7bdaf15c-af9f-4cb8-adec-e7e45f1de6ca diff --git a/chat2db/common/table_name_sql_exmple.yaml b/chat2db/common/table_name_sql_exmple.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e87a1100ebebd05e579c244395ec6e97698f094 --- /dev/null +++ b/chat2db/common/table_name_sql_exmple.yaml @@ -0,0 +1,490 @@ +- keyword_list: + - test_organization + - product_name + - company_name + sql_example_list: + - question: openEuler支持的哪些商业软件在江苏鲲鹏&欧拉生态创新中心测试通过 + sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software + WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%'; + - question: 
哪个版本的openEuler支持的商业软件最多 + sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP + BY openeuler_version ORDER BY software_count DESC LIMIT 1; + - question: openEuler支持测试商业软件的机构有哪些? + sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software; + - question: openEuler支持的商业软件有哪些类别 + sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software; + - question: openEuler有哪些虚拟化类别的商业软件 + sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE + "type" ILIKE '%虚拟化%'; + - question: openEuler支持哪些ISV商业软件呢,请列出10个 + sql: SELECT product_name FROM public.oe_compatibility_commercial_software; + - question: openEuler支持的适配Kunpeng 920的互联网商业软件有哪些? + sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM + public.oe_compatibility_commercial_software WHERE platform_type_and_server_model + ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30; + - question: openEuler-22.03版本支持哪些商业软件? 
+ sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software + WHERE openeuler_version ILIKE '%22.03%'; + - question: openEuler支持的数字政府类型的商业软件有哪些 + sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software + WHERE type ILIKE '%数字政府%'; + - question: 有哪些商业软件支持超过一种服务器平台 + sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE + platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model + ILIKE '%Kunpeng%'; + - question: 每个openEuler版本有多少种类型的商业软件支持 + sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP + BY openeuler_version; + - question: openEuler支持的哪些商业ISV在江苏鲲鹏&欧拉生态创新中心测试通过 + sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software + WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%'; + - question: 哪个版本的openEuler支持的商业ISV最多 + sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP + BY openeuler_version ORDER BY software_count DESC LIMIT 1; + - question: openEuler支持测试商业ISV的机构有哪些? + sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software; + - question: openEuler支持的商业ISV有哪些类别 + sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software; + - question: openEuler有哪些虚拟化类别的商业ISV + sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE + "type" ILIKE '%虚拟化%'; + - question: openEuler支持哪些ISV商业ISV呢,请列出10个 + sql: SELECT product_name FROM public.oe_compatibility_commercial_software; + - question: openEuler支持的适配Kunpeng 920的互联网商业ISV有哪些? + sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM + public.oe_compatibility_commercial_software WHERE platform_type_and_server_model + ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30; + - question: openEuler-22.03版本支持哪些商业ISV? 
+ sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software + WHERE openeuler_version ILIKE '%22.03%'; + - question: openEuler支持的数字政府类型的商业ISV有哪些 + sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software + WHERE type ILIKE '%数字政府%'; + - question: 有哪些商业ISV支持超过一种服务器平台 + sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE + platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model + ILIKE '%Kunpeng%'; + - question: 每个openEuler版本有多少种类型的商业ISV支持 + sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP + BY openeuler_version; + - question: 卓智校园网接入门户系统基于openeuelr的什么版本? + sql: select * from oe_compatibility_commercial_software where product_name ilike + '%卓智校园网接入门户系统%'; + table_name: oe_compatibility_commercial_software +- keyword_list: + - softwareName + sql_example_list: + - question: openEuler-20.03-LTS-SP1支持哪些开源软件? + sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE + openeuler_version ILIKE '%20.03-LTS-SP1%'; + - question: openEuler的aarch64下支持开源软件 + sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE + "arch" ILIKE '%aarch64%'; + - question: openEuler支持开源软件使用了GPLv2+许可证 + sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE + "license" ILIKE '%GPLv2+%'; + - question: tcplay支持的架构是什么 + sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName" + ILIKE '%tcplay%'; + - question: openEuler支持哪些开源软件,请列出10个 + sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT + 10; + - question: openEuler支持开源软件支持哪些结构 + sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by + "arch"; + - question: openEuler支持多少个开源软件? 
+ sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from + (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software) + as tmp_table group by tmp_table.openeuler_version; + - question: openEuler-20.03-LTS-SP1支持哪些开源ISV? + sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE + openeuler_version ILIKE '%20.03-LTS-SP1%'; + - question: openEuler的aarch64下支持开源ISV + sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE + "arch" ILIKE '%aarch64%'; + - question: openEuler支持开源ISV使用了GPLv2+许可证 + sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE + "license" ILIKE '%GPLv2+%'; + - question: tcplay支持的架构是什么 + sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName" + ILIKE '%tcplay%'; + - question: openEuler支持哪些开源ISV,请列出10个 + sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT + 10; + - question: openEuler支持开源ISV支持哪些结构 + sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by + "arch"; + - question: openEuler-20.03-LTS-SP1支持多少个开源ISV? + sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from + (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software + where openeuler_version ilike 'openEuler-20.03-LTS-SP1') as tmp_table group + by tmp_table.openeuler_version; + - question: openEuler支持多少个开源ISV? 
+ sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from + (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software) + as tmp_table group by tmp_table.openeuler_version; + table_name: oe_compatibility_open_source_software +- keyword_list: [] + sql_example_list: + - question: 在openEuler技术委员会担任委员的人有哪些 + sql: SELECT name FROM oe_community_organization_structure WHERE committee_name + ILIKE '%技术委员会%' AND role = '委员'; + - question: openEuler的委员会中哪些人是教授 + sql: SELECT name FROM oe_community_organization_structure WHERE personal_message + ILIKE '%教授%'; + - question: openEuler各委员会中担任主席有多少个? + sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure + WHERE role = '主席' GROUP BY committee_name; + - question: openEuler 用户委员会中有多少位成员 + sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name + ILIKE '%用户委员会%'; + - question: openEuler 技术委员会有多少位成员 + sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name + ILIKE '%技术委员会%'; + - question: openEuler委员会的委员常务委员会委员有哪些人 + sql: SELECT name FROM oe_community_organization_structure WHERE committee_name + ILIKE '%委员会%' AND role ILIKE '%常务委员会委员%'; + - question: openEuler委员会有哪些人属于华为技术有限公司? + sql: SELECT DISTINCT name FROM oe_community_organization_structure WHERE personal_message + ILIKE '%华为技术有限公司%'; + - question: openEuler每个委员会有多少人? + sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure + GROUP BY committee_name; + - question: openEuler的执行总监是谁 + sql: SELECT name FROM oe_community_organization_structure WHERE role = '执行总监'; + - question: openEuler委员会有哪些组织? + sql: SELECT DISTINCT committee_name from oe_community_organization_structure; + - question: openEuler技术委员会的主席是谁? + sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE + role = '主席' and committee_name ilike '%技术委员会%'; + - question: openEuler品牌委员会的主席是谁? 
+ sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE + role = '主席' and committee_name ilike '%品牌委员会%'; + - question: openEuler委员会的主席是谁? + sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE + role = '主席' and committee_name ilike '%openEuler 委员会%'; + - question: openEuler委员会的执行总监是谁? + sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE + role = '执行总监' and committee_name ilike '%openEuler 委员会%'; + - question: openEuler委员会的执行秘书是谁? + sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE + role = '执行秘书' and committee_name ilike '%openEuler 委员会%'; + table_name: oe_community_organization_structure +- keyword_list: + - cve_id + sql_example_list: + - question: 安全公告openEuler-SA-2024-2059的详细信息在哪里? + sql: select DISTINCT security_notice_no,details from oe_compatibility_security_notice + where security_notice_no='openEuler-SA-2024-2059'; + table_name: oe_compatibility_security_notice +- keyword_list: + - hardware_model + sql_example_list: + - question: openEuler-22.03 LTS支持哪些整机? + sql: SELECT main_board_model, cpu, ram FROM oe_compatibility_overall_unit WHERE + openeuler_version ILIKE '%openEuler-22.03-LTS%'; + - question: 查询所有支持`openEuler-22.09`,并且提供详细产品介绍链接的整机型号和它们的内存配置? + sql: SELECT hardware_model, ram FROM oe_compatibility_overall_unit WHERE openeuler_version + ILIKE '%openEuler-22.09%' AND product_information IS NOT NULL; + - question: 显示所有由新华三生产,支持`openEuler-20.03 LTS SP2`版本的整机,列出它们的型号和架构类型 + sql: SELECT hardware_model, architecture FROM oe_compatibility_overall_unit WHERE + hardware_factory = '新华三' AND openeuler_version ILIKE '%openEuler-20.03 LTS SP2%'; + - question: openEuler支持多少种整机? + sql: SELECT count(DISTINCT main_board_model) FROM oe_compatibility_overall_unit; + - question: openEuler每个版本支持多少种整机? 
+ sql: select openeuler_version,count(*) from (SELECT DISTINCT openeuler_version,main_board_model + FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version; + - question: openEuler每个版本多少种架构的整机? + sql: select openeuler_version,architecture,count(*) from (SELECT DISTINCT openeuler_version,architecture,main_board_model + FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version,architecture; + table_name: oe_compatibility_overall_unit +- keyword_list: + - osv_name + - os_version + sql_example_list: + - question: 深圳开鸿数字产业发展有限公司基于openEuler的什么版本发行了什么商用版本? + sql: select os_version,openeuler_version,os_download_link from oe_compatibility_osv + where osv_name='深圳开鸿数字产业发展有限公司'; + - question: 统计各个openEuler版本下的商用操作系统数量 + sql: SELECT openeuler_version, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP + BY openeuler_version; + - question: 哪个OS厂商基于openEuler发布的商用操作系统最多 + sql: SELECT osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP + BY osv_name ORDER BY os_count DESC LIMIT 1; + - question: 不同OS厂商基于openEuler发布不同架构的商用操作系统数量是多少? 
+ sql: SELECT arch, osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP + BY arch, osv_name ORDER BY arch, os_count DESC; + - question: 深圳开鸿数字产业发展有限公司的商用操作系统是基于什么openEuler版本发布的 + sql: SELECT os_version, openeuler_version FROM public.oe_compatibility_osv WHERE + osv_name ILIKE '%深圳开鸿数字产业发展有限公司%'; + - question: openEuler有哪些OSV伙伴 + sql: SELECT DISTINCT osv_name FROM public.oe_compatibility_osv; + - question: 有哪些OSV友商的操作系统是x86_64架构的 + sql: SELECT osv_name, os_version FROM public.oe_compatibility_osv WHERE arch ILIKE + '%x86_64%'; + - question: 哪些OSV友商操作系统是嵌入式类型的 + sql: SELECT osv_name, os_version,openeuler_version FROM public.oe_compatibility_osv + WHERE type ILIKE '%嵌入式%'; + - question: 成都鼎桥的商用操作系统版本是基于openEuler 22.03的版本吗 + sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE + osv_name ILIKE '%成都鼎桥通信技术有限公司%' AND openeuler_version ILIKE '%22.03%'; + - question: 最近发布的基于openEuler 23.09的商用系统有哪些 + sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE + openeuler_version ILIKE '%23.09%' ORDER BY date DESC limit 10; + - question: 帮我查下成都智明达发布的所有嵌入式系统 + sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE + osv_name ILIKE '%成都智明达电子股份有限公司%' AND type = '嵌入式'; + - question: 基于openEuler发布的商用操作系统有哪些类型 + sql: SELECT DISTINCT type FROM public.oe_compatibility_osv; + - question: 江苏润和系统版本HopeOS-V22-x86_64-dvd.iso基于openEuler哪个版本 + sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv + WHERE "osv_name" ILIKE '%江苏润和%' AND os_version ILIKE '%HopeOS-V22-x86_64-dvd.iso%' + ; + - question: 浙江大华DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso系统版本基于openEuler哪个版本 + sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv + WHERE "osv_name" ILIKE '%浙江大华%' AND os_version ILIKE '%DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso%' + ; + table_name: oe_compatibility_osv +- keyword_list: + - board_model + 
- chip_model + - chip_vendor + - product + sql_example_list: + - question: openEuler 22.03支持哪些网络接口卡型号? + sql: SELECT board_model, chip_model,type FROM oe_compatibility_card WHERE type + ILIKE '%NIC%' AND openeuler_version ILIKE '%22.03%' limit 30; + - question: 请列出openEuler支持的所有Renesas公司的密码卡 + sql: SELECT * FROM oe_compatibility_card WHERE chip_vendor ILIKE '%Renesas%' AND + type ILIKE '%密码卡%' limit 30; + - question: openEuler各种架构支持的板卡数量是多少 + sql: SELECT architecture, COUNT(*) AS total_cards FROM oe_compatibility_card GROUP + BY architecture limit 30; + - question: 每个openEuler版本支持了多少种板卡 + sql: SELECT openeuler_version, COUNT(*) AS number_of_cards FROM oe_compatibility_card + GROUP BY openeuler_version limit 30; + - question: openEuler总共支持多少种不同的板卡型号 + sql: SELECT COUNT(DISTINCT board_model) AS board_model_cnt FROM oe_compatibility_card + limit 30; + - question: openEuler支持的GPU型号有哪些? + sql: SELECT chip_model, openeuler_version,type FROM public.oe_compatibility_card WHERE + type ILIKE '%GPU%' ORDER BY driver_date DESC limit 30; + - question: openEuler 20.03 LTS-SP4版本支持哪些类型的设备 + sql: SELECT DISTINCT openeuler_version,type FROM public.oe_compatibility_card WHERE + openeuler_version ILIKE '%20.03-LTS-SP4%' limit 30; + - question: openEuler支持的板卡驱动在2023年后发布 + sql: SELECT board_model, driver_date, driver_name FROM oe_compatibility_card WHERE + driver_date >= '2023-01-01' limit 30; + - question: 给些支持openEuler的aarch64架构下支持的的板卡的驱动下载链接 + sql: SELECT openeuler_version,board_model, download_link FROM oe_compatibility_card + WHERE architecture ILIKE '%aarch64%' AND download_link IS NOT NULL limit 30; + - question: openEuler-22.03-LTS-SP1支持的存储卡有哪些? + sql: SELECT openeuler_version,board_model, chip_model,type FROM oe_compatibility_card + WHERE type ILIKE '%SSD%' AND openeuler_version ILIKE '%openEuler-22.03-LTS-SP1%' + limit 30; + table_name: oe_compatibility_card +- keyword_list: + - cve_id + sql_example_list: + - question: CVE-2024-41053的详细信息在哪里可以看到? 
+ sql: select DISTINCT cve_id,details from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; + - question: CVE-2024-41053是个怎么样的漏洞? + sql: select DISTINCT cve_id,summary from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; + - question: CVE-2024-41053影响了哪些包? + sql: select DISTINCT cve_id,package_name from oe_compatibility_cve_database where + cve_id='CVE-2024-41053'; + - question: CVE-2024-41053的cvss评分是多少? + sql: select DISTINCT cve_id,cvsss_core_nvd from oe_compatibility_cve_database + where cve_id='CVE-2024-41053'; + - question: CVE-2024-41053现在修复了么? + sql: select DISTINCT cve_id, status from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; + - question: CVE-2024-41053影响了openEuler哪些版本? + sql: select DISTINCT cve_id, affected_product from oe_compatibility_cve_database + where cve_id='CVE-2024-41053'; + - question: CVE-2024-41053发布时间是? + sql: select DISTINCT cve_id, announcement_time from oe_compatibility_cve_database + where cve_id='CVE-2024-41053'; + - question: openEuler-20.03-LTS-SP4在2024年8月发布哪些漏洞? + sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database + where cve_id='CVE-2024-41053' and affected_product='openEuler-20.03-LTS-SP4' + and EXTRACT(MONTH FROM announcement_time)=8; + - question: openEuler-20.03-LTS-SP4在2024年发布哪些漏洞? + sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database + where cve_id='CVE-2024-41053' and affected_product='openEuler-20.03-LTS-SP4' + and EXTRACT(YEAR FROM announcement_time)=2024; + - question: CVE-2024-41053的威胁程度是怎样的? + sql: select DISTINCT affected_product,cve_id,cvsss_core_nvd,attack_complexity_nvd,attack_complexity_oe,attack_vector_nvd,attack_vector_oe + from oe_compatibility_cve_database where cve_id='CVE-2024-41053'; + table_name: oe_compatibility_cve_database +- keyword_list: + - name + sql_example_list: + - question: openEuler-20.03-LTS的非官方软件包有多少个? 
+ sql: SELECT COUNT(*) FROM oe_compatibility_oepkgs WHERE repotype = 'openeuler_compatible' + AND openeuler_version ILIKE '%openEuler-20.03-LTS%'; + - question: openEuler支持的nginx版本有哪些? + sql: SELECT DISTINCT name,version, srcrpmpackurl FROM oe_compatibility_oepkgs + WHERE name ILIKE 'nginx'; + - question: openEuler的支持哪些架构的glibc? + sql: SELECT DISTINCT name,arch FROM oe_compatibility_oepkgs WHERE name ILIKE 'glibc'; + - question: openEuler-22.03-LTS带GPLv2许可的软件包有哪些 + sql: SELECT name,rpmlicense FROM oe_compatibility_oepkgs WHERE openeuler_version + ILIKE '%openEuler-22.03-LTS%' AND rpmlicense = 'GPLv2'; + - question: openEuler支持的python3这个软件包是用来干什么的? + sql: SELECT DISTINCT name,summary FROM oe_compatibility_oepkgs WHERE name ILIKE + 'python3'; + - question: 哪些版本的openEuler的zlib中有官方源的? + sql: SELECT DISTINCT openeuler_version,name,version FROM oe_compatibility_oepkgs + WHERE name ILIKE '%zlib%' AND repotype = 'openeuler_official'; + - question: 请以表格的形式提供openEuler-20.09的gcc软件包的下载链接 + sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; + - question: 请以表格的形式提供openEuler-20.09的glibc软件包的下载链接 + sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; + - question: 请以表格的形式提供openEuler-20.09的redis软件包的下载链接 + sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis'; + - question: openEuler-20.09的支持多少个软件包? + sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT + openeuler_version,name from oe_compatibility_oepkgs WHERE openeuler_version + ILIKE '%openEuler-20.09') as tmp_table group by tmp_table.openeuler_version; + - question: openEuler支持多少个软件包? 
+ sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT + openeuler_version,name from oe_compatibility_oepkgs) as tmp_table group by tmp_table.openeuler_version; + - question: 请以表格的形式提供openEuler-20.09的gcc的版本 + sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; + - question: 请以表格的形式提供openEuler-20.09的glibc的版本 + sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; + - question: 请以表格的形式提供openEuler-20.09的redis的版本 + sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis'; + - question: openEuler-20.09支持哪些gcc的版本 + sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; + - question: openEuler-20.09支持哪些glibc的版本 + sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; + - question: openEuler-20.09支持哪些redis的版本 + sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis'; + - question: '' + sql: openEuler-20.09支持的gcc版本有哪些 + - question: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc'; + sql: openEuler-20.09支持的glibc版本有哪些 + - question: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc'; + sql: openEuler-20.09支持的redis版本有哪些 + - question: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE 
'%openEuler-20.09%' AND name ilike 'redis'; + sql: '' + - question: openEuler-20.09支持gcc 9.3.1么? + sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs + WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc' AND version + ilike '9.3.1'; + table_name: oe_compatibility_oepkgs +- keyword_list: [] + sql_example_list: + - question: openEuler社区创新版本有哪些 + sql: SELECT DISTINCT openeuler_version,version_type FROM oe_community_openeuler_version + where version_type ILIKE '%社区创新版本%'; + - question: openEuler有哪些版本 + sql: SELECT openeuler_version FROM public.oe_community_openeuler_version; + - question: 查询openeuler各版本对应的内核版本 + sql: SELECT DISTINCT openeuler_version, kernel_version FROM public.oe_community_openeuler_version; + - question: openEuler有多少个长期支持版本(LTS) + sql: SELECT COUNT(*) as publish_version_count FROM public.oe_community_openeuler_version + WHERE version_type ILIKE '%长期支持版本%'; + - question: 查询openEuler-20.03的所有SP版本 + sql: SELECT openeuler_version FROM public.oe_community_openeuler_version WHERE + openeuler_version ILIKE '%openEuler-20.03-LTS-SP%'; + - question: openEuler最新的社区创新版本内核是啥 + sql: SELECT kernel_version FROM public.oe_community_openeuler_version WHERE version_type + ILIKE '%社区创新版本%' ORDER BY publish_time DESC LIMIT 1; + - question: 最早的openEuler版本是什么时候发布的 + sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version + ORDER BY publish_time ASC LIMIT 1; + - question: 最新的openEuler版本是哪个 + sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version + ORDER BY publish_time LIMIT 1; + - question: openEuler有哪些版本使用了Linux 5.10.0内核 + sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version + WHERE kernel_version ILIKE '5.10.0%'; + - question: 哪个openEuler版本是最近更新的长期支持版本 + sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version + WHERE version_type ILIKE '%长期支持版本%' ORDER BY publish_time DESC LIMIT 1; + - question: 
openEuler每个年份发布了多少个版本
    sql: SELECT EXTRACT(YEAR FROM publish_time) AS year, COUNT(*) AS publish_version_count
      FROM oe_community_openeuler_version group by EXTRACT(YEAR FROM publish_time);
  - question: openEuler-20.03-LTS版本的linux内核是多少?
    sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
      WHERE openeuler_version = 'openEuler-20.03-LTS';
  - question: openEuler-24.09版本的linux内核是多少?
    sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
      WHERE openeuler_version = 'openEuler-24.09';
  table_name: oe_community_openeuler_version
- keyword_list:
  - product
  sql_example_list:
  - question: 哪些openEuler版本支持使用至强6338N的解决方案
    sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE cpu
      ILIKE '%6338N%';
  - question: 使用intel XXV710作为网卡的解决方案对应的是哪些服务器型号
    sql: SELECT DISTINCT server_model FROM oe_compatibility_solution WHERE network_card
      ILIKE '%intel XXV710%';
  - question: 哪些解决方案的硬盘驱动为SATA-SSD Skhynix
    sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE hard_disk_drive
      ILIKE 'SATA-SSD Skhynix';
  - question: 查询所有使用6230R系列CPU且支持磁盘阵列支持PERC H740P Adapter的解决方案的产品名
    sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%'
      AND raid ILIKE '%PERC H740P Adapter%';
  - question: R4900-G3有哪些驱动版本
    sql: SELECT DISTINCT driver FROM oe_compatibility_solution WHERE product ILIKE
      '%R4900-G3%';
  - question: DL380 Gen10支持哪些架构
    sql: SELECT DISTINCT architecture FROM oe_compatibility_solution WHERE server_model
      ILIKE '%DL380 Gen10%';
  - question: 列出所有使用Intel(R) Xeon(R)系列cpu且磁盘冗余阵列为LSI SAS3408的解决方案的服务器厂家
    sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE cpu ILIKE
      '%Intel(R) Xeon(R)%' AND raid ILIKE '%LSI SAS3408%';
  - question: 哪些解决方案提供了针对SEAGATE ST4000NM0025硬盘驱动的支持
    sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive ILIKE '%SEAGATE
      ST4000NM0025%';
  - question: 查询所有使用4316系列CPU的解决方案
    sql:
SELECT * FROM oe_compatibility_solution WHERE cpu ILIKE '%4316%'; + - question: 支持openEuler-22.03-LTS-SP2版本的解决方案中,哪款服务器型号出现次数最多 + sql: SELECT server_model, COUNT(*) as count FROM oe_compatibility_solution WHERE + openeuler_version ILIKE '%openEuler-22.03-LTS-SP2%' GROUP BY server_model ORDER + BY count DESC LIMIT 1; + - question: HPE提供的解决方案的介绍链接是什么 + sql: SELECT DISTINCT introduce_link FROM oe_compatibility_solution WHERE server_vendor + ILIKE '%HPE%'; + - question: 列出所有使用intel XXV710网络卡接口的解决方案的CPU型号 + sql: SELECT DISTINCT cpu FROM oe_compatibility_solution WHERE network_card ILIKE + '%intel XXV710%'; + - question: 服务器型号为2288H V5的解决方案支持哪些不同的openEuler版本 + sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE server_model + ILIKE '%NF5180M5%'; + - question: 使用6230R系列CPU的解决方案内存最小是多少GB + sql: SELECT MIN(ram) FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%'; + - question: 哪些解决方案的磁盘驱动为MegaRAID 9560-8i + sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive LIKE '%MegaRAID + 9560-8i%'; + - question: 列出所有使用6330N系列CPU且服务器厂家为Dell的解决方案的产品名 + sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6330N%' + AND server_vendor ILIKE '%Dell%'; + - question: R4900-G3的驱动版本是多少 + sql: SELECT driver FROM oe_compatibility_solution WHERE product ILIKE '%R4900-G3%'; + - question: 哪些解决方案的服务器型号为2288H V7 + sql: SELECT * FROM oe_compatibility_solution WHERE server_model ILIKE '%2288H + V7%'; + - question: 使用Intel i350网卡且硬盘驱动为ST4000NM0025的解决方案的服务器厂家有哪些 + sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE network_card + ILIKE '%Intel i350%' AND hard_disk_drive ILIKE '%ST4000NM0025%'; + - question: 有多少种不同的驱动版本被用于支持openEuler-22.03-LTS-SP2版本的解决方案 + sql: SELECT COUNT(DISTINCT driver) FROM oe_compatibility_solution WHERE openeuler_version + ILIKE '%openEuler-22.03-LTS-SP2%'; + table_name: oe_compatibility_solution diff --git a/chat2db/config/config.py b/chat2db/config/config.py new file mode 100644 index 
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import os

from dotenv import dotenv_values
from pydantic import BaseModel, Field


class ConfigModel(BaseModel):
    """Typed view of every setting loaded from the service's .env file."""

    # FastAPI service binding
    UVICORN_IP: str = Field(None, description="FastAPI 服务的IP地址")
    UVICORN_PORT: int = Field(None, description="FastAPI 服务的端口号")
    SSL_CERTFILE: str = Field(None, description="SSL证书文件的路径")
    SSL_KEYFILE: str = Field(None, description="SSL密钥文件的路径")
    SSL_ENABLE: str = Field(None, description="是否启用SSL连接")

    # Postgres backing store
    DATABASE_URL: str = Field(None, description="数据库url")

    # LLM (QWEN) access
    LLM_KEY: str = Field(None, description="语言模型访问密钥")
    LLM_URL: str = Field(None, description="语言模型服务的基础URL")
    LLM_MAX_TOKENS: int = Field(None, description="单次请求中允许的最大Token数")
    LLM_MODEL: str = Field(None, description="使用的语言模型名称或版本")

    # Embedding / vectorize service
    EMBEDDING_TYPE: str = Field("openai", description="embedding 服务的类型")
    EMBEDDING_API_KEY: str = Field(None, description="embedding服务api key")
    EMBEDDING_ENDPOINT: str = Field(None, description="embedding服务url地址")
    EMBEDDING_MODEL_NAME: str = Field(None, description="embedding模型名称")

    # Security key material (split into components)
    HALF_KEY1: str = Field(None, description='加密的密钥组件1')
    HALF_KEY2: str = Field(None, description='加密的密钥组件2')
    HALF_KEY3: str = Field(None, description='加密的密钥组件3')


class Config:
    """Loads the env file once at import time and exposes values via config[key]."""

    config: ConfigModel

    def __init__(self):
        # The CONFIG env var overrides the default .env location.
        config_file = os.getenv("CONFIG") or "./chat2db/common/.env"
        self.config = ConfigModel(**(dotenv_values(config_file)))
        # In production the plaintext env file is deleted after being read.
        if os.getenv("PROD"):
            os.remove(config_file)

    def __getitem__(self, key):
        # Unknown keys yield None rather than raising KeyError.
        return self.config.__dict__.get(key)


config = Config()
import logging
from uuid import uuid4
from pgvector.sqlalchemy import Vector
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import TIMESTAMP, UUID, Column, String, Boolean, ForeignKey, create_engine, func, Index
from chat2db.config.config import config

logging.basicConfig(filename='app.log', level=logging.INFO,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
Base = declarative_base()


class DatabaseInfo(Base):
    """A registered target database; the connection url is stored encrypted."""
    __tablename__ = 'database_info_table'
    id = Column(UUID(), default=uuid4, primary_key=True)
    encrypted_database_url = Column(String())
    encrypted_config = Column(String())
    # Digest of the plaintext url, used for duplicate detection without decrypting.
    hashmac = Column(String())
    created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())


class TableInfo(Base):
    """A table of a registered database, with an embedded note vector for retrieval."""
    __tablename__ = 'table_info_table'
    id = Column(UUID(), default=uuid4, primary_key=True)
    database_id = Column(UUID(), ForeignKey('database_info_table.id', ondelete='CASCADE'))
    table_name = Column(String())
    table_note = Column(String())
    table_note_vector = Column(Vector(1024))
    enable = Column(Boolean, default=False)
    created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
    updated_at = Column(
        TIMESTAMP(timezone=True),
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp())
    # HNSW cosine index so table notes can be matched by vector similarity.
    __table_args__ = (
        Index(
            'table_note_vector_index',
            table_note_vector,
            postgresql_using='hnsw',
            postgresql_with={'m': 16, 'ef_construction': 200},
            postgresql_ops={'table_note_vector': 'vector_cosine_ops'}
        ),
    )


class ColumnInfo(Base):
    """A column of a registered table (name, type, human note, enable flag)."""
    __tablename__ = 'column_info_table'
    id = Column(UUID(), default=uuid4, primary_key=True)
    table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE'))
    column_name = Column(String)
    column_type = Column(String)
    column_note = Column(String)
    enable = Column(Boolean, default=False)


class SqlExample(Base):
    """A question/sql few-shot pair with an embedded question vector."""
    __tablename__ = 'sql_example_table'
    id = Column(UUID(), default=uuid4, primary_key=True)
    table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE'))
    question = Column(String())
    sql = Column(String())
    question_vector = Column(Vector(1024))
    created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
    updated_at = Column(
        TIMESTAMP(timezone=True),
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp())
    __table_args__ = (
        Index(
            'question_vector_index',
            question_vector,
            postgresql_using='hnsw',
            postgresql_with={'m': 16, 'ef_construction': 200},
            postgresql_ops={'question_vector': 'vector_cosine_ops'}
        ),
    )


class PostgresDB:
    """Process-wide lazily-initialized engine plus a session context manager."""

    _engine = None

    @classmethod
    def get_mysql_engine(cls):
        """Return the shared engine, creating it and the schema on first use.

        BUG FIX: the original tested and returned ``cls._engine`` but assigned
        ``cls.engine``, so the cache never took effect — a new engine was
        created (and ``create_all`` re-run) on every call, and the method
        always returned None.
        """
        if cls._engine is None:
            cls._engine = create_engine(
                config['DATABASE_URL'],
                hide_parameters=True,
                echo=False,
                pool_recycle=300,
                pool_pre_ping=True)
            Base.metadata.create_all(cls._engine)
        return cls._engine

    @classmethod
    def get_session(cls):
        """Return a context manager yielding an ORM session, or None on failure."""
        try:
            session = sessionmaker(bind=cls.get_mysql_engine())()
        except Exception as e:
            logging.error(f"Error creating a postgres session due to error: {e}")
            return None
        return cls._ConnectionManager(session)

    class _ConnectionManager:
        """Closes the wrapped session on scope exit, swallowing close errors."""

        def __init__(self, connection):
            self.connection = connection

        def __enter__(self):
            return self.connection

        def __exit__(self, exc_type, exc_val, exc_tb):
            try:
                self.connection.close()
            except Exception as e:
                logging.error(f"Postgres connection close failed due to error: {e}")


# Eagerly initialize the engine (and create tables) at import time.
PostgresDB.get_mysql_engine()
# One-off helper: serializes the inline prompt-template dict below into
# ./prompt.yaml (path relative to the CWD — run from chat2db/docs).
#
# NOTE(review): the committed docs/prompt.yaml contains extra rules (e.g.
# "不要对生成的sql进行解释。") that are NOT present in this dict; regenerating
# will overwrite them — confirm the dict is up to date before running.
#
# BUG FIX: removed the leftover `print(text)` debug dump of the whole dict.
import yaml

# Prompt templates keyed by usage; the {note}/{question}/... placeholders are
# filled in by the sql-generation services at runtime.
text = {
    'sql_generate_base_on_example_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释和示例,基于给出的问题生成一条在{database_url}连接下可进行查询的sql语句。
注意:
#01 sql语句中,特殊字段需要带上双引号。
#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
#03 sql语句中,查询字段必须使用`distinct`关键字去重。
#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
#05 sql语句中,参考问题,对查询字段进行冗余。
#06 sql语句中,需要以分号结尾。

以下是表结构以及表注释:
{note}
以下是{k}个示例:
{sql_example}
以下是问题:
{question}
''',
    'question_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是根据给出的表结构和表内数据,输出一个用户可能针对这张表内的信息提出的问题。
注意:
#01 问题内容和形式需要多样化,例如要用到统计、排序、模糊匹配等相关问题。
#02 要以口语化的方式输出问题,不要机械的使用表内字段输出问题。
#03 不要输出问题之外多余的内容!
#04 要基于用户的角度取提出问题,问题内容需要口语化、拟人化。
#05 优先生成有注释的字段相关的sql语句。

以下是表结构和注释:
{note}
以下是表内数据
{data_frame}
''',
    'sql_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是参考给出的表结构以及表注释和表内数据,基于给出的问题生成一条查询{database_type}的sql语句。
注意:
#01 sql语句中,特殊字段需要带上双引号。
#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
#03 sql语句中,查询字段必须使用`distinct`关键字去重。
#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
#05 sql语句中,参考问题,对查询字段进行冗余。
#06 sql语句中,需要以分号结尾。

以下是表结构以及表注释:
{note}
以下是表内的数据:
{data_frame}
以下是问题:
{question}
''',
    'sql_expand_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释、执行失败的sql和执行失败的报错,基于给出的问题修改执行失败的sql生成一条在{database_type}连接下可进行查询的sql语句。

 注意:

 #01 假设sql中有特殊字符干扰了sql的执行,请优先替换这些特殊字符保证sql可执行。

 #02 假设sql用于检索或者过滤的字段导致了sql执行的失败,请尝试替换这些字段保证sql可执行。

 #03 假设sql检索结果为空,请尝试将 = 的匹配方式替换为 ilike \'\%\%\' 保证sql执行给出结果。

 #04 假设sql检索结果为空,可以使用问题中的关键字的子集作为sql的过滤条件保证sql执行给出结果。

 以下是表结构以及表注释:

 {note}

 以下是执行失败的sql:

 {sql_failed}

 以下是执行失败的报错:

 {sql_failed_message}

 以下是问题:

 {question}
''',
    'table_choose_prompt': '''你是一个数据库专家,你的任务是参考给出的表名以及表的条目(主键,表名、表注释),输出最适配于问题回答检索的{table_cnt}张表,并返回表对应的主键。
注意:
#01 输出的表名用python的list格式返回,下面是list的一个示例:
[\"prime_key1\",\"prime_key2\"]。
#02 只输出包含主键的list即可不要输出其他内容!!!
#03 list重主键的顺序,按表与问题的适配程度从高到底排列。
#04 若无任何一张表适用于问题的回答,请返回空列表。

以下是表的条目:
{table_entries}
以下是问题:
{question}
'''
}
with open('./prompt.yaml', 'w', encoding='utf-8') as f:
    yaml.dump(text, f, allow_unicode=True)
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
from langchain_openai import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
import re

class LLM:
    """Thin async wrapper around a LangChain ChatOpenAI client."""

    def __init__(self, model_name, openai_api_base, openai_api_key, request_timeout, max_tokens, temperature):
        # One client per LLM instance; all generation parameters are fixed here.
        self.client = ChatOpenAI(model_name=model_name,
                                 openai_api_base=openai_api_base,
                                 openai_api_key=openai_api_key,
                                 request_timeout=request_timeout,
                                 max_tokens=max_tokens,
                                 temperature=temperature)

    def assemble_chat(self, system_call, user_call):
        """Build a two-message chat: system prompt followed by the user turn."""
        return [SystemMessage(content=system_call), HumanMessage(content=user_call)]

    async def chat_with_model(self, system_call, user_call):
        """Send one system+user exchange and return the cleaned reply text."""
        messages = self.assemble_chat(system_call, user_call)
        reply = await self.client.ainvoke(messages)
        # Strip everything up to and including the first blank line, if any
        # (drops a leading preamble/reasoning section from the reply).
        return re.sub(r'.*?\n\n', '', reply.content, flags=re.DOTALL)
from sqlalchemy import and_

from chat2db.database.postgres import ColumnInfo, PostgresDB


class ColumnInfoManager():
    """CRUD helpers for column_info_table rows."""

    @staticmethod
    async def add_column_info_with_table_id(table_id, column_name, column_type, column_note):
        """Insert one column description attached to the given table."""
        column_info_entry = ColumnInfo(table_id=table_id, column_name=column_name,
                                       column_type=column_type, column_note=column_note)
        with PostgresDB.get_session() as session:
            session.add(column_info_entry)
            session.commit()

    @staticmethod
    async def del_column_info_by_column_id(column_id):
        """Delete the column row with the given id; no-op when it does not exist.

        BUG FIX: the original passed the Query object itself to
        ``session.delete`` (missing ``.first()``), which raises at runtime.
        """
        with PostgresDB.get_session() as session:
            column_info_to_delete = session.query(ColumnInfo).filter(
                ColumnInfo.id == column_id).first()
            if column_info_to_delete is not None:
                session.delete(column_info_to_delete)
                session.commit()

    @staticmethod
    async def get_column_info_by_column_id(column_id):
        """Return a dict view of one column row, or None when missing."""
        with PostgresDB.get_session() as session:
            result = session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first()
            session.commit()
            if not result:
                return None
            return {
                'column_id': result.id,
                'table_id': result.table_id,
                'column_name': result.column_name,
                'column_type': result.column_type,
                'column_note': result.column_note,
                'enable': result.enable
            }

    @staticmethod
    async def update_column_info_enable(column_id, enable=True):
        """Set the enable flag of a column; return False when the row is absent.

        BUG FIX: the original ignored the ``enable`` argument and always
        wrote True.
        """
        with PostgresDB.get_session() as session:
            column_info = session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first()
            if column_info is None:
                return False
            column_info.enable = enable
            session.commit()
        return True

    @staticmethod
    async def get_column_info_by_table_id(table_id, enable=None):
        """List column dicts of a table; optionally filter by the enable flag."""
        column_info_list = []
        with PostgresDB.get_session() as session:
            if enable is None:
                results = session.query(ColumnInfo).filter(ColumnInfo.table_id == table_id).all()
            else:
                results = session.query(ColumnInfo).filter(
                    and_(ColumnInfo.table_id == table_id, ColumnInfo.enable == enable)).all()
            for result in results:
                column_info_list.append({
                    'column_id': result.id,
                    'column_name': result.column_name,
                    'column_type': result.column_type,
                    'column_note': result.column_note,
                    'enable': result.enable
                })
        return column_info_list
class DatabaseInfoManager():
    """CRUD helpers for stored (encrypted) database connection urls."""

    @staticmethod
    async def add_database(database_url: str):
        """Encrypt and store *database_url*; returns the new row id, or None when
        the same url (deduplicated via SHA-256 digest) is already registered."""
        id = None
        with PostgresDB.get_session() as session:
            encrypted_database_url, encrypted_config = Security.encrypt(database_url)
            # Deterministic digest of the plaintext url, used only for dedup lookup.
            hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest()
            counter = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first()
            if counter:
                return id
            encrypted_config = json.dumps(encrypted_config)
            database_info_entry = DatabaseInfo(encrypted_database_url=encrypted_database_url,
                                               encrypted_config=encrypted_config, hashmac=hashmac)
            session.add(database_info_entry)
            session.commit()
            id = database_info_entry.id
        return id

    @staticmethod
    async def del_database_by_id(id):
        """Delete one database row; returns False when the id is unknown."""
        with PostgresDB.get_session() as session:
            database_info_to_delete = session.query(DatabaseInfo).filter(DatabaseInfo.id == id).first()
            if database_info_to_delete is None:
                return False
            session.delete(database_info_to_delete)
            session.commit()
        return True

    @staticmethod
    async def get_database_url_by_id(id):
        """Decrypt and return the url for *id*; None when missing or undecodable."""
        with PostgresDB.get_session() as session:
            result = session.query(
                DatabaseInfo.encrypted_database_url, DatabaseInfo.encrypted_config).filter(
                DatabaseInfo.id == id).first()
            if result is None:
                return None
            try:
                encrypted_database_url, encrypted_config = result
                encrypted_config = json.loads(encrypted_config)
            except Exception as e:
                logging.error(f'数据库url解密失败由于{e}')
                return None
            if not encrypted_database_url:
                return None
            return Security.decrypt(encrypted_database_url, encrypted_config)

    @staticmethod
    async def get_all_database_info():
        """Return every stored database as
        {'database_id', 'database_url', 'created_at'}, oldest first.

        Bug fix: the old loop read `database_url` even when
        `encrypted_database_url` was empty — NameError on the first such row,
        or silently reusing the previous row's url; those rows are now skipped.
        """
        with PostgresDB.get_session() as session:
            results = session.query(DatabaseInfo).order_by(DatabaseInfo.created_at).all()
            database_info_list = []
            for result in results:
                if not result.encrypted_database_url:
                    continue
                encrypted_config = json.loads(result.encrypted_config)
                database_url = Security.decrypt(result.encrypted_database_url, encrypted_config)
                database_info_list.append({
                    'database_id': result.id,
                    'database_url': database_url,
                    'created_at': result.created_at,
                })
        return database_info_list
import json
from sqlalchemy import and_
from chat2db.database.postgres import SqlExample, PostgresDB


class SqlExampleManager():
    """Persistence helpers for question/SQL example pairs."""

    @staticmethod
    async def add_sql_example(question, sql, table_id, question_vector):
        """Store one question/SQL pair and return its generated id."""
        entry = SqlExample(question=question, sql=sql,
                           table_id=table_id, question_vector=question_vector)
        with PostgresDB.get_session() as session:
            session.add(entry)
            session.commit()
            new_id = entry.id
        return new_id

    @staticmethod
    async def del_sql_example_by_id(id):
        """Remove the example with *id*; returns False when it does not exist."""
        with PostgresDB.get_session() as session:
            target = session.query(SqlExample).filter(SqlExample.id == id).first()
            if target is None:
                return False
            session.delete(target)
            session.commit()
        return True

    @staticmethod
    async def update_sql_example_by_id(id, question, sql, question_vector):
        """Overwrite question/sql/vector of an existing example; False when missing."""
        with PostgresDB.get_session() as session:
            target = session.query(SqlExample).filter(SqlExample.id == id).first()
            if target is None:
                return False
            target.sql = sql
            target.question = question
            target.question_vector = question_vector
            session.commit()
        return True

    @staticmethod
    async def query_sql_example_by_table_id(table_id):
        """List all examples of *table_id* as {'sql_example_id', 'question', 'sql'}."""
        with PostgresDB.get_session() as session:
            rows = session.query(SqlExample).filter(SqlExample.table_id == table_id).all()
            return [
                {'sql_example_id': row.id, 'question': row.question, 'sql': row.sql}
                for row in rows
            ]

    @staticmethod
    async def get_topk_sql_example_by_cos_dis(question_vector, table_id_list=None, topk=3):
        """Return the *topk* examples closest to *question_vector* by cosine
        distance, restricted to *table_id_list* unless it is None (no filter)."""
        with PostgresDB.get_session() as session:
            query = session.query(SqlExample)
            if table_id_list is not None:
                query = query.filter(SqlExample.table_id.in_(table_id_list))
            rows = query.order_by(
                SqlExample.question_vector.cosine_distance(question_vector)
            ).limit(topk).all()
            return [
                {'table_id': row.table_id, 'question': row.question, 'sql': row.sql}
                for row in rows
            ]
from sqlalchemy import and_

from chat2db.database.postgres import TableInfo, PostgresDB


class TableInfoManager():
    """CRUD helpers for the table_info table."""

    @staticmethod
    async def add_table_info(database_id, table_name, table_note, table_note_vector):
        """Insert a table record; returns the new id, or None when the
        (database_id, table_name) pair already exists."""
        id = None
        with PostgresDB.get_session() as session:
            counter = session.query(TableInfo).filter(
                and_(TableInfo.database_id == database_id, TableInfo.table_name == table_name)).first()
            if counter:
                return id
            table_info_entry = TableInfo(database_id=database_id, table_name=table_name,
                                         table_note=table_note, table_note_vector=table_note_vector)
            session.add(table_info_entry)
            session.commit()
            id = table_info_entry.id
        return id

    @staticmethod
    async def del_table_by_id(id):
        """Delete one table record; returns False when the id is unknown."""
        with PostgresDB.get_session() as session:
            table_info_to_delete = session.query(TableInfo).filter(TableInfo.id == id).first()
            if table_info_to_delete is None:
                return False
            session.delete(table_info_to_delete)
            session.commit()
        return True

    @staticmethod
    async def get_table_info_by_table_id(table_id):
        """Return {'table_id', 'database_id', 'table_name', 'table_note'} or None.

        Bug fix: .first() returns None for an unknown id; the old code
        tuple-unpacked the row before its None check, raising TypeError
        instead of returning None.
        """
        with PostgresDB.get_session() as session:
            row = session.query(
                TableInfo.id, TableInfo.database_id, TableInfo.table_name, TableInfo.table_note).filter(
                TableInfo.id == table_id).first()
            if row is None:
                return None
            table_id, database_id, table_name, table_note = row
            return {
                'table_id': table_id,
                'database_id': database_id,
                'table_name': table_name,
                'table_note': table_note
            }

    @staticmethod
    async def get_table_info_by_database_id(database_id, enable=None):
        """List tables of *database_id*; filter on the enable flag unless None."""
        with PostgresDB.get_session() as session:
            if enable is None:
                results = session.query(
                    TableInfo).filter(TableInfo.database_id == database_id).all()
            else:
                results = session.query(
                    TableInfo).filter(
                    and_(TableInfo.database_id == database_id,
                         TableInfo.enable == enable
                         )).all()
            table_info_list = []
            for result in results:
                table_info_list.append({'table_id': result.id, 'table_name': result.table_name,
                                        'table_note': result.table_note, 'created_at': result.created_at})
        return table_info_list

    @staticmethod
    async def get_topk_table_by_cos_dis(database_id, tmp_vector, topk=3):
        """Return the ids of the *topk* tables in *database_id* whose note
        vector is closest (cosine distance) to *tmp_vector*."""
        with PostgresDB.get_session() as session:
            results = session.query(
                TableInfo.id
            ).filter(TableInfo.database_id == database_id).order_by(
                TableInfo.table_note_vector.cosine_distance(tmp_vector)
            ).limit(topk).all()
            table_id_list = [result[0] for result in results]
        return table_id_list
import uuid
from pydantic import BaseModel, Field
from enum import Enum


class QueryRequest(BaseModel):
    """Natural-language question to answer against a stored database."""
    question: str
    topk_sql: int = 5                    # number of similar SQL examples to retrieve
    topk_answer: int = 15                # number of answer rows to return
    use_llm_enhancements: bool = False   # whether to use LLM-based refinement


class DatabaseAddRequest(BaseModel):
    """Register a database by its connection url."""
    database_url: str


class DatabaseDelRequest(BaseModel):
    """Remove a registered database by id."""
    database_id: uuid.UUID


class TableAddRequest(BaseModel):
    """Register a table of an existing database."""
    database_id: uuid.UUID
    table_name: str


class TableDelRequest(BaseModel):
    """Remove a registered table by id."""
    table_id: uuid.UUID


class TableQueryRequest(BaseModel):
    """List the registered tables of a database."""
    database_id: uuid.UUID


class EnableColumnRequest(BaseModel):
    """Toggle whether a column participates in SQL generation."""
    column_id: uuid.UUID
    enable: bool


class SqlExampleAddRequest(BaseModel):
    """Attach one question/SQL example pair to a table."""
    table_id: uuid.UUID
    question: str
    sql: str


class SqlExampleDelRequest(BaseModel):
    """Remove a question/SQL example pair by id."""
    sql_example_id: uuid.UUID


class SqlExampleQueryRequest(BaseModel):
    """List the question/SQL example pairs of a table."""
    table_id: uuid.UUID


class SqlExampleUpdateRequest(BaseModel):
    """Overwrite the question and SQL of an existing example pair."""
    sql_example_id: uuid.UUID
    question: str
    sql: str


class SqlGenerateRequest(BaseModel):
    """Generate SQL for a question against a database."""
    database_id: uuid.UUID
    table_id_list: list[uuid.UUID] = []  # candidate tables; empty means auto-select
    question: str
    topk: int = 5                        # number of examples to feed the prompt
    use_llm_enhancements: bool = True    # whether to use LLM-based refinement


class SqlRepairRequest(BaseModel):
    """Repair a failing SQL statement given the database error message."""
    database_id: uuid.UUID
    table_id: uuid.UUID
    sql: str
    message: str = Field(..., max_length=2048)  # error message from executing the sql
    question: str


class SqlExcuteRequest(BaseModel):
    """Execute a SQL statement on a registered database."""
    database_id: uuid.UUID
    sql: str


class SqlExampleGenerateRequest(BaseModel):
    """Auto-generate question/SQL example pairs for a table."""
    table_id: uuid.UUID
    generate_cnt: int = 1   # number of pairs to generate
    sql_var: bool = False   # whether to validate generated SQL by executing it
from pydantic import BaseModel


class ResponseData(BaseModel):
    """Uniform envelope returned by the chat2db API endpoints."""
    code: int      # HTTP-style status code (200 on success)
    message: str   # human-readable outcome description
    result: dict   # endpoint-specific payload
# Rewrite config.yaml with a new service address.
def update_config(uvicorn_ip, uvicorn_port):
    """Persist the uvicorn ip/port into CONFIG_YAML_PATH and report the outcome."""
    try:
        yml = {'UVICORN_IP': uvicorn_ip, 'UVICORN_PORT': uvicorn_port}
        with open(CONFIG_YAML_PATH, 'w') as file:
            yaml.dump(yml, file)
        return {"message": "修改成功"}
    except Exception as e:
        return {"message": f"修改失败,由于:{e}"}


def _service_url(path):
    """Build the chat2db service url for *path* from the loaded config."""
    return f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}{path}"


def _post_json(path, request_body):
    """POST *request_body* as JSON to the service and decode the reply."""
    response = requests.post(_service_url(path), json=request_body)
    return response.json()


def _get_json(path, params=None):
    """GET *path* with optional query params and decode the reply."""
    response = requests.get(_service_url(path), params=params)
    return response.json()


# Register a database connection
def call_add_database_info(database_url):
    return _post_json("/database/add", {"database_url": database_url})


# Delete a registered database
def call_del_database_info(database_id):
    return _post_json("/database/del", {"database_id": database_id})


# Query registered database configurations
def call_query_database_info():
    return _get_json("/database/query")


# List the tables inside a database.
# Bug fix: removed a leftover debug print of the query params.
def call_list_table_in_database(database_id, table_filter=''):
    return _get_json("/database/list", {"database_id": database_id, "table_filter": table_filter})


# Register a data table
def call_add_table_info(database_id, table_name):
    return _post_json("/table/add", {"database_id": database_id, "table_name": table_name})


# Delete a data table
def call_del_table_info(table_id):
    return _post_json("/table/del", {"table_id": table_id})


# Query table configurations of a database
def call_query_table_info(database_id):
    return _get_json("/table/query", {"database_id": database_id})


# Query column details of a table
def call_query_column(table_id):
    return _get_json("/table/column/query", {"table_id": table_id})


# Enable or disable a column
def call_enable_column(column_id, enable):
    return _post_json("/table/column/enable", {"column_id": column_id, "enable": enable})


# Add a question/SQL example pair
def call_add_sql_example(table_id, question, sql):
    return _post_json("/sql/example/add", {"table_id": table_id, "question": question, "sql": sql})


# Delete a question/SQL example pair
def call_del_sql_example(sql_example_id):
    return _post_json("/sql/example/del", {"sql_example_id": sql_example_id})


# Query the question/SQL example pairs of a table
def call_query_sql_example(table_id):
    return _get_json("/sql/example/query", {"table_id": table_id})


# Update an existing question/SQL example pair
def call_update_sql_example(sql_example_id, question, sql):
    return _post_json("/sql/example/update",
                      {"sql_example_id": sql_example_id, "question": question, "sql": sql})


# Ask the service to generate question/SQL example pairs
def call_generate_sql_example(table_id, generate_cnt=1, sql_var=False):
    return _post_json("/sql/example/generate",
                      {"table_id": table_id, "generate_cnt": generate_cnt, "sql_var": sql_var})
def write_sql_example_to_excel(dir, sql_example_list):
    """Dump [{'question': ..., 'sql': ...}, ...] to the .xlsx file at *dir*.

    Bug fix: os.path.dirname() of a bare filename is '', and os.makedirs('')
    raises; the parent directory is now only created when it is non-empty.
    """
    try:
        parent = os.path.dirname(dir)
        if parent and not os.path.exists(parent):
            os.makedirs(parent)
        data = {
            'question': [sql_example['question'] for sql_example in sql_example_list],
            'sql': [sql_example['sql'] for sql_example in sql_example_list],
        }
        pd.DataFrame(data).to_excel(dir, index=False)
        print("Data written to Excel file successfully.")
    except Exception as e:
        print("Error writing data to Excel file:", str(e))
    # Add a data table to a registered database
    parser_add_table = subparsers.add_parser("add_tb", help="增加指定数据库内的数据表")
    parser_add_table.add_argument("--database_id", type=str, required=True, help="数据库id")
    parser_add_table.add_argument("--table_name", type=str, required=True, help="数据表名称")

    # Delete a data table
    parser_del_table = subparsers.add_parser("del_tb", help="删除指定数据表")
    parser_del_table.add_argument("--table_id", type=str, required=True, help="数据表id")

    # Query table configurations of a database
    parser_query_table = subparsers.add_parser("query_tb", help="查询指定数据表配置")
    parser_query_table.add_argument("--database_id", type=str, required=True, help="数据库id")

    # Query detailed column information of a table
    parser_query_column = subparsers.add_parser("query_col", help="查询指定数据表详细列信息")
    parser_query_column.add_argument("--table_id", type=str, required=True, help="数据表id")

    # Enable / disable a column
    # NOTE(review): argparse type=bool converts any non-empty string to True
    # (e.g. "--enable False" yields True) — consider a str-to-bool converter.
    parser_enable_column = subparsers.add_parser("enable_col", help="启用禁用指定列")
    parser_enable_column.add_argument("--column_id", type=str, required=True, help="列id")
    parser_enable_column.add_argument("--enable", type=bool, required=True, help="是否启用")

    # Add SQL example(s): a single --question/--sql pair, or bulk import via --dir
    parser_add_sql_example = subparsers.add_parser("add_sql_exp", help="增加指定数据表sql案例")
    parser_add_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
    parser_add_sql_example.add_argument("--question", type=str, required=False, help="问题")
    parser_add_sql_example.add_argument("--sql", type=str, required=False, help="sql")
    parser_add_sql_example.add_argument("--dir", type=str, required=False, help="输入路径")

    # Delete a SQL example
    parser_del_sql_example = subparsers.add_parser("del_sql_exp", help="删除指定sql案例")
    parser_del_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id")

    # Query SQL examples of a table
    parser_query_sql_example = subparsers.add_parser("query_sql_exp", help="查询指定数据表sql对案例")
    parser_query_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
help="更新sql对案例") + parser_update_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id") + parser_update_sql_example.add_argument("--question", type=str, required=True, help="sql语句对应的问题") + parser_update_sql_example.add_argument("--sql", type=str, required=True, help="sql语句") + + # 生成sql案例 + parser_generate_sql_example = subparsers.add_parser("generate_sql_exp", help="生成指定数据表sql对案例") + parser_generate_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id") + parser_generate_sql_example.add_argument("--generate_cnt", type=int, required=False, help="生成sql对数量", + default=1) + parser_generate_sql_example.add_argument("--sql_var", type=bool, required=False, + help="是否验证生成的sql对,True为验证,False不验证", + default=False) + parser_generate_sql_example.add_argument("--dir", type=str, required=False, help="生成的sql对输出路径", + default="docs/output_examples.xlsx") + + args = parser.parse_args() + + if os.path.exists(CONFIG_YAML_PATH): + exist = True + with open(CONFIG_YAML_PATH, 'r') as file: + yml = yaml.safe_load(file) + config = { + 'UVICORN_IP': yml.get('UVICORN_IP'), + 'UVICORN_PORT': yml.get('UVICORN_PORT'), + } + else: + exist = False + + if args.command == "config": + if not exist: + os.makedirs(CHAT2DB_CONFIG_PATH, exist_ok=True) + with open(CONFIG_YAML_PATH, 'w') as file: + yaml.dump(DEFAULT_CHAT2DB_CONFIG, file, default_flow_style=False) + response = update_config(args.ip, args.port) + with open(CONFIG_YAML_PATH, 'r') as file: + yml = yaml.safe_load(file) + config = { + 'UVICORN_IP': yml.get('UVICORN_IP'), + 'UVICORN_PORT': yml.get('UVICORN_PORT'), + } + print(response.get("message")) + elif not exist: + print("please update_config first") + + elif args.command == "add_db": + response = call_add_database_info(args.database_url) + database_id = response.get("result")['database_id'] + print(response.get("message")) + if response.get("code") == 200: + print(f'database_id: ', database_id) + + elif args.command == "del_db": + 
response = call_del_sql_example(args.database_id) + print(response.get("message")) + + elif args.command == "query_db": + response = call_query_database_info() + print(response.get("message")) + if response.get("code") == 200: + database_info = response.get("result")['database_info_list'] + for database in database_info: + print('-' * terminal_width) + print("database_id:", database["database_id"]) + print("database_url:", database["database_url"]) + print("created_at:", database["created_at"]) + print('-' * terminal_width) + + elif args.command == "list_tb_in_db": + response = call_list_table_in_database(args.database_id, args.table_filter) + print(response.get("message")) + if response.get("code") == 200: + table_name_list = response.get("result")['table_name_list'] + for table_name in table_name_list: + print(table_name) + + elif args.command == "add_tb": + response = call_add_table_info(args.database_id, args.table_name) + print(response.get("message")) + table_id = response.get("result")['table_id'] + if response.get("code") == 200: + print('table_id: ', table_id) + + elif args.command == "del_tb": + response = call_del_table_info(args.table_id) + print(response.get("message")) + + elif args.command == "query_tb": + response = call_query_table_info(args.database_id) + print(response.get("message")) + if response.get("code") == 200: + table_list = response.get("result")['table_info_list'] + for table in table_list: + print('-' * terminal_width) + print("table_id:", table['table_id']) + print("table_name:", table['table_name']) + print("table_note:", table['table_note']) + print("created_at:", table['created_at']) + print('-' * terminal_width) + + elif args.command == "query_col": + response = call_query_column(args.table_id) + print(response.get("message")) + if response.get("code") == 200: + column_list = response.get("result")['column_info_list'] + for column in column_list: + print('-' * terminal_width) + print("column_id:", column['column_id']) + 
print("column_name:", column['column_name']) + print("column_note:", column['column_note']) + print("column_type:", column['column_type']) + print("enable:", column['enable']) + print('-' * terminal_width) + + elif args.command == "enable_col": + response = call_enable_column(args.column_id, args.enable) + print(response.get("message")) + + elif args.command == "add_sql_exp": + def get_sql_exp(dir): + if not os.path.exists(os.path.dirname(dir)): + return None + # 读取 xlsx 文件 + df = pd.read_excel(dir) + + # 遍历每一行数据 + for index, row in df.iterrows(): + question = row['question'] + sql = row['sql'] + + # 调用 call_add_sql_example 函数 + response = call_add_sql_example(args.table_id, question, sql) + print(response.get("message")) + sql_example_id = response.get("result")['sql_example_id'] + print('sql_example_id: ', sql_example_id) + print(question, sql) + + + if args.dir: + get_sql_exp(args.dir) + else: + response = call_add_sql_example(args.table_id, args.question, args.sql) + print(response.get("message")) + sql_example_id = response.get("result")['sql_example_id'] + print('sql_example_id: ', sql_example_id) + + elif args.command == "del_sql_exp": + response = call_del_sql_example(args.sql_example_id) + print(response.get("message")) + + elif args.command == "query_sql_exp": + response = call_query_sql_example(args.table_id) + print(response.get("message")) + if response.get("code") == 200: + sql_example_list = response.get("result")['sql_example_list'] + for sql_example in sql_example_list: + print('-' * terminal_width) + print("sql_example_id:", sql_example['sql_example_id']) + print("question:", sql_example['question']) + print("sql:", sql_example['sql']) + print('-' * terminal_width) + + elif args.command == "update_sql_exp": + response = call_update_sql_example(args.sql_example_id, args.question, args.sql) + print(response.get("message")) + + elif args.command == "generate_sql_exp": + response = call_generate_sql_example(args.table_id, args.generate_cnt, 
class Security:
    """Envelope encryption helper: a root key derived from three configured
    half-keys wraps a random per-secret work key, which encrypts the payload."""

    @staticmethod
    def encrypt(plaintext: str) -> tuple[str, dict]:
        """
        Public encryption entry point.

        :param plaintext: secret string to protect
        :return: (base64 ciphertext, dict with the material decrypt() needs)
        """
        half_key1 = config['HALF_KEY1']

        encrypted_work_key, encrypted_work_key_iv = Security._generate_encrypted_work_key(
            half_key1)
        encrypted_plaintext, encrypted_iv = Security._encrypt_plaintext(half_key1, encrypted_work_key,
                                                                        encrypted_work_key_iv, plaintext)
        # Drop the clear-text reference as soon as it is no longer needed.
        del plaintext
        secret_dict = {
            "encrypted_work_key": encrypted_work_key,
            "encrypted_work_key_iv": encrypted_work_key_iv,
            "encrypted_iv": encrypted_iv,
            "half_key1": half_key1
        }
        return encrypted_plaintext, secret_dict

    @staticmethod
    def decrypt(encrypted_plaintext: str, secret_dict: dict):
        """
        Public decryption entry point.

        :param encrypted_plaintext: string to decrypt
        :param secret_dict: dict carrying the work-key material from encrypt()
        :return: the recovered plaintext
        """
        plaintext = Security._decrypt_plaintext(half_key1=secret_dict.get("half_key1"),
                                                encrypted_work_key=secret_dict.get(
                                                    "encrypted_work_key"),
                                                encrypted_work_key_iv=secret_dict.get(
                                                    "encrypted_work_key_iv"),
                                                encrypted_iv=secret_dict.get(
                                                    "encrypted_iv"),
                                                encrypted_plaintext=encrypted_plaintext)
        return plaintext

    @staticmethod
    def _get_root_key(half_key1: str) -> bytes:
        """Derive the root key via PBKDF2-HMAC-SHA256 over half_key1+HALF_KEY2
        with HALF_KEY3 as salt; a 32-character slice of the hex digest is used
        as the AES key bytes."""
        half_key2 = config['HALF_KEY2']
        key = (half_key1 + half_key2).encode("utf-8")
        half_key3 = config['HALF_KEY3'].encode("utf-8")
        hash_key = hashlib.pbkdf2_hmac("sha256", key, half_key3, 10000)
        # 32 hex characters -> 32 bytes, i.e. an AES-256 key.
        return binascii.hexlify(hash_key)[13:45]

    @staticmethod
    def _generate_encrypted_work_key(half_key1: str) -> tuple[str, str]:
        """Create a random 32-byte work key and return it wrapped by the root
        key, as (base64 wrapped key, base64 wrap IV)."""
        bin_root_key = Security._get_root_key(half_key1)
        bin_work_key = secrets.token_bytes(32)
        bin_encrypted_work_key_iv = secrets.token_bytes(16)
        bin_encrypted_work_key = Security._root_encrypt(bin_root_key, bin_encrypted_work_key_iv, bin_work_key)
        encrypted_work_key = base64.b64encode(bin_encrypted_work_key).decode("ascii")
        encrypted_work_key_iv = base64.b64encode(bin_encrypted_work_key_iv).decode("ascii")
        return encrypted_work_key, encrypted_work_key_iv

    @staticmethod
    def _get_work_key(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str) -> bytes:
        """Unwrap the base64-encoded work key using the derived root key."""
        bin_root_key = Security._get_root_key(half_key1)
        bin_encrypted_work_key = base64.b64decode(encrypted_work_key.encode("ascii"))
        bin_encrypted_work_key_iv = base64.b64decode(encrypted_work_key_iv.encode("ascii"))
        return Security._root_decrypt(bin_root_key, bin_encrypted_work_key_iv, bin_encrypted_work_key)

    @staticmethod
    def _root_encrypt(key: bytes, encrypted_iv: bytes, plaintext: bytes) -> bytes:
        """Encrypt *plaintext* with AES-GCM under (key, iv).

        NOTE(review): no GCM auth tag is emitted or verified anywhere in this
        class, so this provides confidentiality only (effectively a CTR
        keystream) with no integrity protection — confirm this is intentional.
        """
        encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor()
        encrypted = encryptor.update(plaintext)
        encryptor.finalize()
        return encrypted

    @staticmethod
    def _root_decrypt(key: bytes, encrypted_iv: bytes, encrypted: bytes) -> bytes:
        """Invert _root_encrypt.

        NOTE(review): uses .encryptor() for decryption — this only works
        because the GCM keystream XOR is symmetric and no tag is checked;
        confirm this is intentional.
        """
        encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor()
        plaintext = encryptor.update(encrypted)
        return plaintext

    @staticmethod
    def _encrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str,
                           plaintext: str) -> tuple[str, str]:
        """Encrypt *plaintext* (prefixed with half_key1 as a salt, stripped
        again on decryption) under the unwrapped work key."""
        bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv)
        salt = f"{half_key1}{plaintext}"
        plaintext_temp = salt.encode("utf-8")
        # Drop clear-text references as soon as they are no longer needed.
        del plaintext
        del salt
        bin_encrypted_iv = secrets.token_bytes(16)
        bin_encrypted_plaintext = Security._root_encrypt(bin_work_key, bin_encrypted_iv, plaintext_temp)
        encrypted_plaintext = base64.b64encode(bin_encrypted_plaintext).decode("ascii")
        encrypted_iv = base64.b64encode(bin_encrypted_iv).decode("ascii")
        return encrypted_plaintext, encrypted_iv

    @staticmethod
    def _decrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str,
                           encrypted_plaintext: str, encrypted_iv: str) -> str:
        """Invert _encrypt_plaintext and strip the half_key1 salt prefix."""
        bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv)
        bin_encrypted_plaintext = base64.b64decode(encrypted_plaintext.encode("ascii"))
        bin_encrypted_iv = base64.b64decode(encrypted_iv.encode("ascii"))
        plaintext_temp = Security._root_decrypt(bin_work_key, bin_encrypted_iv, bin_encrypted_plaintext)
        plaintext_salt = plaintext_temp.decode("utf-8")
        plaintext = plaintext_salt[len(half_key1):]
        return plaintext
+EMBEDDING_TYPE= +EMBEDDING_API_KEY= +EMBEDDING_ENDPOINT= +EMBEDDING_MODEL_NAME= +# Token SESSION_TTL= - -# PROMPT_PATH -PROMPT_PATH= -# Stop Words PATH -STOP_WORDS_PATH= - -#Security +CSRF_KEY= +# Security HALF_KEY1= HALF_KEY2= HALF_KEY3= - -#LLM config -MODEL_NAME= -OPENAI_API_BASE= -OPENAI_API_KEY= -REQUEST_TIMEOUT= -MAX_TOKENS= +# Prompt file +PROMPT_PATH= +# Stop Words PATH +STOP_WORDS_PATH= +# LLM config +MODEL_1_MODEL_ID= +MODEL_1_MODEL_NAME= +MODEL_1_MODEL_TYPE= +MODEL_1_OPENAI_API_BASE= +MODEL_1_OPENAI_API_KEY= +MODEL_1_MAX_TOKENS= MODEL_ENH= - -#DEFAULT USER +# DEFAULT USER DEFAULT_USER_ACCOUNT= DEFAULT_USER_PASSWD= DEFAULT_USER_NAME= -DEFAULT_USER_LANGUAGE= \ No newline at end of file +DEFAULT_USER_LANGUAGE= +# DOCUMENT PARSER +DOCUMENT_PARSE_USE_CPU_LIMIT= \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c7a0ecefe3ca8617fa904c2225b64674700e165f..e9363212fe1ca04367f3f7840c4d0e20d8a77492 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,40 +1,48 @@ -python-docx==1.1.0 +aiofiles==24.1.0 +aiomysql==0.2.0 +apscheduler==3.10.0 asgi-correlation-id==4.3.1 +asyncpg==0.29.0 +asyncpg==0.29.0 +beautifulsoup4==4.12.3 +chardet==5.2.0 +cryptography==42.0.2 fastapi==0.110.2 -starlette==0.37.2 -sqlalchemy==2.0.23 -numpy==1.26.4 -scikit-learn==1.5.0 +fastapi-pagination==0.12.19 +httpx==0.27.0 +itsdangerous==2.1.2 +jieba==0.42.1 langchain==0.3.7 -redis==5.0.3 -pgvector==0.2.5 +langchain-community == 0.0.34 +langchain-openai == 0.0.8 +langchain-openai==0.2.5 +langchain-text-splitters==0.0.1 +minio==7.2.4 +more-itertools==10.1.0 +numpy==1.26.4 +openai==1.53.0 +opencv-python==4.9.0.80 +openpyxl==3.1.2 +paddleocr==2.9.1 +paddlepaddle==2.6.2 pandas==2.2.2 -jieba==0.42.1 -cryptography==42.0.2 -httpx==0.27.0 +pgvector==0.2.5 +Pillow==10.3.0 +psycopg2-binary==2.9.9 pydantic==2.7.4 -aiofiles==24.1.0 +PyMuPDF==1.24.13 +python-docx==1.1.0 +python-dotenv==1.0.1 +python-multipart==0.0.9 +python-pptx==0.6.23 pytz==2020.5 pyyaml==6.0.1 
+redis==5.0.3 +requests==2.32.2 +scikit-learn==1.5.0 +sqlalchemy==2.0.23 +starlette==0.37.2 tika==2.6.0 +tiktoken==0.8.0 urllib3==2.2.1 -paddlepaddle==2.6.2 -paddleocr==2.9.1 -apscheduler==3.10.0 uvicorn==0.21.0 -requests==2.32.2 -Pillow==10.3.0 -minio==7.2.4 -python-dotenv==1.0.1 -langchain-openai==0.2.5 -openai==1.53.0 -opencv-python==4.9.0.80 -chardet==5.2.0 -PyMuPDF==1.24.13 -python-multipart==0.0.9 -asyncpg==0.29.0 -psycopg2-binary==2.9.9 -openpyxl==3.1.2 -beautifulsoup4==4.12.3 -tiktoken==0.8.0 -python-pptx==0.6.23 \ No newline at end of file diff --git a/run.sh b/run.sh index 6cc9be0ad973788a67c40166e99aef173538e21c..df3d5bcce1839f2466c944af41f18d281b24b1f4 100644 --- a/run.sh +++ b/run.sh @@ -1,4 +1,4 @@ #!/usr/bin/env sh nohup java -jar tika-server-standard-2.9.2.jar & python3 /rag-service/data_chain/apps/app.py - +python3 /rag-service/chat2db/apps/app.py diff --git a/spider/oe_spider/config/pg_info.yaml b/spider/oe_spider/config/pg_info.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9013b1a707330af8699caf10a0dd86bcf918c286 --- /dev/null +++ b/spider/oe_spider/config/pg_info.yaml @@ -0,0 +1,5 @@ +pg_database: postgres +pg_host: 0.0.0.0 +pg_port: '5438' +pg_pwd: '123456' +pg_user: postgres diff --git a/spider/oe_spider/docs/openeuler_version.txt b/spider/oe_spider/docs/openeuler_version.txt new file mode 100644 index 0000000000000000000000000000000000000000..c70bca346d1395467bb3d2432e9c70177e278e2c --- /dev/null +++ b/spider/oe_spider/docs/openeuler_version.txt @@ -0,0 +1,26 @@ +openEuler-22.03-LTS 5.10.0-60.137.0 2022-03-31 长期支持版本 +openEuler-22.03-LTS-LoongArch 5.10.0-60.18.0 2022-03-31 其他版本 +openEuler-22.03-LTS-Next 5.10.0-199.0.0 2022-03-31 其他版本 +openEuler-22.09 5.10.0-106.19.0 2022-09-30 社区创新版本 +openEuler-22.09-HeXin 5.10.0-106.18.0 2022-09-30 其他版本 +openEuler-23.03 6.1.19 2023-03-31 社区创新版本 +openEuler-23.09 6.4.0-10.1.0 2023-09-30 社区创新版本 +openEuler-24.03-LTS 6.6.0-26.0.0 2024-06-06 长期支持版本 +openEuler-24.03-LTS-Next 6.6.0-16.0.0 
2024-06-06 其他版本 +sync-pr1486-master-to-openEuler-24.03-LTS-Next 6.6.0-10.0.0 2024-06-06 其他版本 +sync-pr1519-openEuler-24.03-LTS-to-openEuler-24.03-LTS-Next 6.6.0-17.0.0 2024-06-06 其他版本 +openEuler-20.03-LTS 4.19.90-2202.1.0 2020-03-27 长期支持版本 +openEuler-20.03-LTS-Next 4.19.194-2106.1.0 2020-03-27 其他版本 +openEuler-20.09 4.19.140-2104.1.0 2020-09-30 社区创新版本 +openEuler-21.03 5.10.0-4.22.0 2021-03-30 社区创新版本 +openEuler-21.09 5.10.0-5.10.1 2021-09-30 社区创新版本 +openEuler-22.03-LTS-SP1 5.10.0-136.75.0 2022-12-30 其他版本 +openEuler-22.03-LTS-SP2 5.10.0-153.54.0 2023-06-30 其他版本 +openEuler-22.03-LTS-SP3 5.10.0-199.0.0 2023-12-30 其他版本 +openEuler-22.03-LTS-SP4 5.10.0-198.0.0 2024-06-30 其他版本 +sync-pr1314-openEuler-22.03-LTS-SP3-to-openEuler-22.03-LTS-Next 5.10.0-166.0.0 2023-12-30 其他版本 +openEuler-20.03-LTS-SP1 4.19.90-2405.3.0 2020-03-30 其他版本 +openEuler-20.03-LTS-SP1-testing 4.19.90-2012.4.0 2020-03-30 其他版本 +openEuler-20.03-LTS-SP2 4.19.90-2205.1.0 2021-06-30 其他版本 +openEuler-20.03-LTS-SP3 4.19.90-2401.1.0 2021-12-30 其他版本 +openEuler-20.03-LTS-SP4 4.19.90-2405.3.0 2023-11-30 其他版本 diff --git a/spider/oe_spider/docs/organize.txt b/spider/oe_spider/docs/organize.txt new file mode 100644 index 0000000000000000000000000000000000000000..321dd6f928bb3ea281a44f0d45902131cc9970c3 --- /dev/null +++ b/spider/oe_spider/docs/organize.txt @@ -0,0 +1,67 @@ +openEuler 委员会顾问专家委员会 +成员 陆首群 中国开源软件联盟名誉主席_原国务院信息办常务副主任 +成员 倪光南 中国工程院院士 +成员 廖湘科 中国工程院院士 +成员 王怀民 中国科学院院士 +成员 邬贺铨 中国科学院院士 +成员 周明辉 北京大学计算机系教授_北大博雅特聘教授 +成员 霍太稳 极客邦科技创始人兼CEO + +2023-2024 年 openEuler 委员会 +主席 江大勇 华为技术有限公司 +常务委员会委员 韩乃平 麒麟软件有限公司 +常务委员会委员 刘文清 湖南麒麟信安科技股份有限公司 +常务委员会委员 武延军 中国科学院软件研究所 +常务委员会委员 朱建忠 统信软件技术有限公司 +委员 蔡志旻 江苏润和软件股份有限公司 +委员 高培 软通动力信息技术(集团)股份有限公司 +委员 姜振华 超聚变数字技术有限公司 +委员 李培源 天翼云科技有限公司 +委员 张胜举 中国移动云能力中心 +委员 钟忻 联通数字科技有限公司 +执行总监 邱成锋 华为技术有限公司 +执行秘书 刘彦飞 开放原子开源基金会 + +openEuler 技术委员会 +主席 胡欣蔚 +委员 卞乃猛 +委员 曹志 +委员 陈棋德 +委员 侯健 +委员 胡峰 +委员 胡亚弟 +委员 李永强 +委员 刘寿永 +委员 吕从庆 +委员 任慰 +委员 石勇 +委员 田俊 +委员 王建民 +委员 王伶卓 +委员 王志钢 +委员 魏刚 +委员 吴峰光 +委员 谢秀奇 +委员 熊伟 +委员 赵川峰 + 
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""Best-effort loaders that mirror crawled openEuler community data into Postgres.

Every public method takes the Postgres connection URL as its first argument.
Each ``clear_*`` drops the backing table and re-creates the schema; each
``add_*`` inserts one crawled record.  All database failures are deliberately
swallowed so that a single bad record cannot abort a whole crawl run.
"""
import time

from sqlalchemy import text
from pg import PostgresDB, OeSigGroupRepo, OeSigGroupMember, OeSigRepoMember
from pg import (OeCompatibilityOverallUnit, OeCompatibilityCard, OeCompatibilitySolution,
                OeCompatibilityOpenSourceSoftware, OeCompatibilityCommercialSoftware, OeCompatibilityOepkgs,
                OeCompatibilityOsv, OeCompatibilitySecurityNotice, OeCompatibilityCveDatabase,
                OeCommunityOrganizationStructure, OeCommunityOpenEulerVersion, OeOpeneulerSigGroup,
                OeOpeneulerSigMembers, OeOpeneulerSigRepos)


class OeMessageManager:
    """Thin data-access layer for the crawler; see the module docstring for conventions."""

    @staticmethod
    def _recreate_table(pg_url, table_name, cascade=False):
        """Drop *table_name* (best effort) and re-create all mapped tables.

        ``table_name`` is always an internal constant, never user input, so
        interpolating it into the DDL string is safe here.
        """
        try:
            with PostgresDB(pg_url).get_session() as session:
                session.execute(text(f"DROP TABLE IF EXISTS {table_name}{' CASCADE' if cascade else ''};"))
                session.commit()
            PostgresDB(pg_url).create_table()
        except Exception:
            return

    @staticmethod
    def _insert(pg_url, row):
        """Add one ORM object in its own short-lived session; swallow failures (best effort)."""
        try:
            with PostgresDB(pg_url).get_session() as session:
                session.add(row)
                session.commit()
        except Exception:
            return

    @staticmethod
    def clear_oe_compatibility_overall_unit(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_overall_unit')

    @staticmethod
    def add_oe_compatibility_overall_unit(pg_url, info):
        """Insert one certified whole-machine record (keys follow the upstream JSON)."""
        row = OeCompatibilityOverallUnit(
            id=info.get("id", ''),
            architecture=info.get("architecture", ''),
            bios_uefi=info.get("biosUefi", ''),
            certification_addr=info.get("certificationAddr", ''),
            certification_time=info.get("certificationTime", ''),
            commitid=info.get("commitID", ''),
            computer_type=info.get("computerType", ''),
            cpu=info.get("cpu", ''),
            date=info.get("date", ''),
            friendly_link=info.get("friendlyLink", ''),
            hard_disk_drive=info.get("hardDiskDrive", ''),
            hardware_factory=info.get("hardwareFactory", ''),
            hardware_model=info.get("hardwareModel", ''),
            host_bus_adapter=info.get("hostBusAdapter", ''),
            lang=info.get("lang", ''),
            main_board_model=info.get("mainboardModel", ''),
            # normalize "openEuler 22.03 LTS" style names to dash form
            openeuler_version=info.get("osVersion", '').replace(' ', '-'),
            ports_bus_types=info.get("portsBusTypes", ''),
            product_information=info.get("productInformation", ''),
            ram=info.get("ram", ''),
            video_adapter=info.get("videoAdapter", ''),
            compatibility_configuration=info.get("compatibilityConfiguration", ''),
            boardCards=info.get("boardCards", '')
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_card(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_card')

    @staticmethod
    def add_oe_compatibility_card(pg_url, info):
        """Insert one compatible board/card record."""
        row = OeCompatibilityCard(
            id=info.get("id", ''),
            architecture=info.get("architecture", ''),
            board_model=info.get("boardModel", ''),
            chip_model=info.get("chipModel", ''),
            chip_vendor=info.get("chipVendor", ''),
            device_id=info.get("deviceID", ''),
            download_link=info.get("downloadLink", ''),
            driver_date=info.get("driverDate", ''),
            driver_name=info.get("driverName", ''),
            driver_size=info.get("driverSize", ''),
            item=info.get("item", ''),
            lang=info.get("lang", ''),
            openeuler_version=info.get("os", '').replace(' ', '-'),
            sha256=info.get("sha256", ''),
            ss_id=info.get("ssID", ''),
            sv_id=info.get("svID", ''),
            type=info.get("type", ''),
            vendor_id=info.get("vendorID", ''),
            version=info.get("version", '')
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_open_source_software(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_open_source_software')

    @staticmethod
    def add_oe_compatibility_open_source_software(pg_url, info):
        """Insert one compatible open-source-software record."""
        row = OeCompatibilityOpenSourceSoftware(
            openeuler_version=info.get("os", '').replace(' ', '-'),
            arch=info.get("arch", ''),
            property=info.get("property", ''),
            result_url=info.get("result_url", ''),
            result_root=info.get("result_root", ''),
            bin=info.get("bin", ''),
            uninstall=info.get("uninstall", ''),
            license=info.get("license", ''),
            libs=info.get("libs", ''),
            install=info.get("install", ''),
            src_location=info.get("src_location", ''),
            group=info.get("group", ''),
            cmds=info.get("cmds", ''),
            type=info.get("type", ''),
            softwareName=info.get("softwareName", ''),
            category=info.get("category", ''),
            version=info.get("version", ''),
            downloadLink=info.get("downloadLink", '')
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_commercial_software(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_commercial_software')

    @staticmethod
    def add_oe_compatibility_commercial_software(pg_url, info):
        """Insert one certified commercial-software record."""
        row = OeCompatibilityCommercialSoftware(
            id=info.get("certId", ''),
            data_id=info.get("dataId", ''),
            type=info.get("type", ''),
            test_organization=info.get("testOrganization", ''),
            product_name=info.get("productName", ''),
            product_version=info.get("productVersion", ''),
            company_name=info.get("companyName", ''),
            platform_type_and_server_model=info.get("platformTypeAndServerModel", ''),
            authenticate_link=info.get("authenticateLink", ''),
            openeuler_version=(info.get("osName", '') + info.get("osVersion", '')).replace(' ', '-'),
            region=info.get("region", '')
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_solution(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_solution')

    @staticmethod
    def add_oe_compatibility_solution(pg_url, info):
        """Insert one certified solution record."""
        row = OeCompatibilitySolution(
            id=info.get("id", ''),
            architecture=info.get("architecture", ''),
            bios_uefi=info.get("biosUefi", ''),
            certification_type=info.get("certificationType", ''),
            cpu=info.get("cpu", ''),
            date=info.get("date", ''),
            driver=info.get("driver", ''),
            hard_disk_drive=info.get("hardDiskDrive", ''),
            introduce_link=info.get("introduceLink", ''),
            lang=info.get("lang", ''),
            libvirt_version=info.get("libvirtVersion", ''),
            network_card=info.get("networkCard", ''),
            openeuler_version=info.get("os", '').replace(' ', '-'),
            ovs_version=info.get("OVSVersion", ''),
            product=info.get("product", ''),
            qemu_version=info.get("qemuVersion", ''),
            raid=info.get("raid", ''),
            ram=info.get("ram", ''),
            server_model=info.get("serverModel", ''),
            server_vendor=info.get("serverVendor", ''),
            solution=info.get("solution", ''),
            stratovirt_version=info.get("stratovirtVersion", '')
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_oepkgs(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_oepkgs')

    @staticmethod
    def add_oe_compatibility_oepkgs(pg_url, info):
        """Insert one software-package (oepkgs) record."""
        row = OeCompatibilityOepkgs(
            id=info.get("id", ''),
            name=info.get("name", ''),
            summary=info.get("summary", ''),
            repotype=info.get("repoType", ''),
            openeuler_version=info.get("os", '') + '-' + info.get("osVer", ''),
            rpmpackurl=info.get("rpmPackUrl", ''),
            srcrpmpackurl=info.get("srcRpmPackUrl", ''),
            arch=info.get("arch", ''),
            rpmlicense=info.get("rpmLicense", ''),
            version=info.get("version", ''),
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_osv(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_osv')

    @staticmethod
    def add_oe_compatibility_osv(pg_url, info):
        """Insert one OSV (OS vendor) record."""
        row = OeCompatibilityOsv(
            id=info.get("id", ''),
            arch=info.get("arch", ''),
            os_version=info.get("osVersion", ''),
            osv_name=info.get("osvName", ''),
            date=info.get("date", ''),
            os_download_link=info.get("osDownloadLink", ''),
            type=info.get("type", ''),
            details=info.get("details", ''),
            friendly_link=info.get("friendlyLink", ''),
            total_result=info.get("totalResult", ''),
            checksum=info.get("checksum", ''),
            openeuler_version=info.get('baseOpeneulerVersion', '').replace(' ', '-')
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_security_notice(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_security_notice')

    @staticmethod
    def add_oe_compatibility_security_notice(pg_url, info):
        """Insert one security-notice record."""
        row = OeCompatibilitySecurityNotice(
            id=info.get("id", ''),
            affected_component=info.get("affectedComponent", ''),
            openeuler_version=info.get("openeuler_version", ''),
            announcement_time=info.get("announcementTime", ''),
            cve_id=info.get("cveId", ''),
            description=info.get("description", ''),
            introduction=info.get("introduction", ''),
            package_name=info.get("packageName", ''),
            reference_documents=info.get("referenceDocuments", ''),
            revision_history=info.get("revisionHistory", ''),
            security_notice_no=info.get("securityNoticeNo", ''),
            subject=info.get("subject", ''),
            summary=info.get("summary", ''),
            type=info.get("type", ''),
            notice_type=info.get("notice_type", ''),
            cvrf=info.get("cvrf", ''),
            package_helper_list=info.get("packageHelperList", ''),
            package_hotpatch_list=info.get("packageHotpatchList", ''),
            package_list=info.get("packageList", ''),
            reference_list=info.get("referenceList", ''),
            cve_list=info.get("cveList", ''),
            details=info.get("details", ''),
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_compatibility_cve_database(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_compatibility_cve_database')

    @staticmethod
    def add_oe_compatibility_cve_database(pg_url, info):
        """Insert one CVE record.

        BUGFIX: the complexity/vector mappings were crossed in the original
        (``attack_complexity_oe`` read ``attackVectorNVD`` and
        ``attack_vector_nvd`` read ``attackVectorOE``); each column now reads
        its matching source key.
        """
        row = OeCompatibilityCveDatabase(
            id=info.get("id", ''),
            affected_product=info.get("productName", ''),
            announcement_time=info.get("announcementTime", ''),
            attack_complexity_nvd=info.get("attackComplexityNVD", ''),
            attack_complexity_oe=info.get("attackComplexityOE", ''),
            attack_vector_nvd=info.get("attackVectorNVD", ''),
            attack_vector_oe=info.get("attackVectorOE", ''),
            availability_nvd=info.get("availabilityNVD", ''),
            availability_oe=info.get("availabilityOE", ''),
            confidentiality_nvd=info.get("confidentialityNVD", ''),
            confidentiality_oe=info.get("confidentialityOE", ''),
            cve_id=info.get("cveId", ''),
            cvsss_core_nvd=info.get("cvsssCoreNVD", ''),
            cvsss_core_oe=info.get("cvsssCoreOE", ''),
            integrity_nvd=info.get("integrityNVD", ''),
            integrity_oe=info.get("integrityOE", ''),
            national_cyberAwareness_system=info.get("nationalCyberAwarenessSystem", ''),
            package_name=info.get("packageName", ''),
            privileges_required_nvd=info.get("privilegesRequiredNVD", ''),
            privileges_required_oe=info.get("privilegesRequiredOE", ''),
            scope_nvd=info.get("scopeNVD", ''),
            scope_oe=info.get("scopeOE", ''),
            status=info.get("status", ''),
            summary=info.get("summary", ''),
            type=info.get("type", ''),
            user_interaction_nvd=info.get("userInteractionNVD", ''),
            user_interactio_oe=info.get("userInteractionOE", ''),
            update_time=info.get("updateTime", ''),
            create_time=info.get("createTime", ''),
            security_notice_no=info.get("securityNoticeNo", ''),
            parser_bean=info.get("parserBean", ''),
            cvrf=info.get("cvrf", ''),
            package_list=info.get("packageList", ''),
            details=info.get("details", ''),
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_openeuler_sig_members(pg_url):
        # CASCADE: mapping tables reference this one
        OeMessageManager._recreate_table(pg_url, 'oe_openeuler_sig_members', cascade=True)

    @staticmethod
    def add_oe_openeuler_sig_members(pg_url, info):
        """Insert one SIG member record."""
        row = OeOpeneulerSigMembers(
            name=info.get('name', ''),
            gitee_id=info.get('gitee_id', ''),
            organization=info.get('organization', ''),
            email=info.get('email', ''),
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_openeuler_sig_group(pg_url):
        # CASCADE: mapping tables reference this one
        OeMessageManager._recreate_table(pg_url, 'oe_openeuler_sig_groups', cascade=True)

    @staticmethod
    def add_oe_openeuler_sig_group(pg_url, info):
        """Insert one SIG group record."""
        row = OeOpeneulerSigGroup(
            sig_name=info.get("sig_name", ''),
            description=info.get("description", ''),
            mailing_list=info.get("mailing_list", ''),
            created_at=info.get("created_at", ''),
            is_sig_original=info.get("is_sig_original", ''),
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_openeuler_sig_repos(pg_url):
        # CASCADE: mapping tables reference this one
        OeMessageManager._recreate_table(pg_url, 'oe_openeuler_sig_repos', cascade=True)

    @staticmethod
    def add_oe_openeuler_sig_repos(pg_url, info):
        """Insert one SIG repository record."""
        row = OeOpeneulerSigRepos(
            repo=info.get("repo", ''),
            url=info.get("url", ''),
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_sig_group_to_repos(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_sig_group_to_repos')

    @staticmethod
    def add_oe_sig_group_to_repos(pg_url, group_name, repo_name):
        """Link a SIG group to a repo in the m:n mapping table (best effort)."""
        try:
            with PostgresDB(pg_url).get_session() as session:
                repo = session.query(OeOpeneulerSigRepos).filter_by(repo=repo_name).first()
                group = session.query(OeOpeneulerSigGroup).filter_by(sig_name=group_name).first()
                # a missing repo/group raises AttributeError here and is swallowed below
                session.add(OeSigGroupRepo(group_id=group.id, repo_id=repo.id))
                session.commit()
        except Exception:
            return

    @staticmethod
    def clear_oe_sig_group_to_members(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_sig_group_to_members')

    @staticmethod
    def add_oe_sig_group_to_members(pg_url, group_name, member_name, role):
        """Link a SIG member (by gitee id) to a group with the given role (best effort)."""
        try:
            with PostgresDB(pg_url).get_session() as session:
                member = session.query(OeOpeneulerSigMembers).filter_by(gitee_id=member_name).first()
                group = session.query(OeOpeneulerSigGroup).filter_by(sig_name=group_name).first()
                session.add(OeSigGroupMember(member_id=member.id, group_id=group.id, member_role=role))
                session.commit()
        except Exception:
            return

    @staticmethod
    def clear_oe_sig_repos_to_members(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_sig_repos_to_members')

    @staticmethod
    def add_oe_sig_repos_to_members(pg_url, repo_name, member_name, role):
        """Link a SIG member (by display name) to a repo with the given role (best effort)."""
        try:
            with PostgresDB(pg_url).get_session() as session:
                repo = session.query(OeOpeneulerSigRepos).filter_by(repo=repo_name).first()
                member = session.query(OeOpeneulerSigMembers).filter_by(name=member_name).first()
                session.add(OeSigRepoMember(repo_id=repo.id, member_id=member.id, member_role=role))
                session.commit()
        except Exception:
            return

    @staticmethod
    def clear_oe_community_organization_structure(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_community_organization_structure')

    @staticmethod
    def add_oe_community_organization_structure(pg_url, info):
        """Insert one community organization-structure record."""
        row = OeCommunityOrganizationStructure(
            committee_name=info.get('committee_name', ''),
            role=info.get('role', ''),
            name=info.get('name', ''),
            personal_message=info.get('personal_message', '')
        )
        OeMessageManager._insert(pg_url, row)

    @staticmethod
    def clear_oe_community_openEuler_version(pg_url):
        OeMessageManager._recreate_table(pg_url, 'oe_community_openEuler_version')

    @staticmethod
    def add_oe_community_openEuler_version(pg_url, info):
        """Insert one openEuler release record.

        Consistency fix: this was the only ``add_*`` without the best-effort
        error handling every sibling uses; it now swallows failures too.
        """
        row = OeCommunityOpenEulerVersion(
            openeuler_version=info.get('openeuler_version', ''),
            kernel_version=info.get('kernel_version', ''),
            publish_time=info.get('publish_time', ''),
            version_type=info.get('version_type', '')
        )
        OeMessageManager._insert(pg_url, row)
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""SQLAlchemy ORM models for the openEuler spider's PostgreSQL store.

Each model mirrors one dataset scraped from the openEuler official website
(compatibility lists, security notices, CVEs, SIG metadata, community info).
`PostgresDB` is a thread-safe singleton wrapper around the engine.
"""
from threading import Lock

from sqlalchemy import Column, String, BigInteger, DateTime, create_engine, MetaData, Sequence, ForeignKey
from sqlalchemy.orm import declarative_base, sessionmaker, relationship


Base = declarative_base()


class OeCompatibilityOverallUnit(Base):
    # Whole-machine (server) compatibility records.
    __tablename__ = 'oe_compatibility_overall_unit'
    __table_args__ = {'comment': 'openEuler支持的整机信息表,存储了openEuler支持的整机的架构、CPU型号、硬件厂家、硬件型号和相关的openEuler版本'}
    id = Column(BigInteger, primary_key=True)
    architecture = Column(String(), comment='openEuler支持的整机信息的架构')
    bios_uefi = Column(String())
    certification_addr = Column(String())
    certification_time = Column(String())
    commitid = Column(String())
    computer_type = Column(String())
    cpu = Column(String(), comment='openEuler支持的整机信息的CPU型号')
    date = Column(String())
    friendly_link = Column(String())
    hard_disk_drive = Column(String())
    hardware_factory = Column(String(), comment='openEuler支持的整机信息的硬件厂家')
    hardware_model = Column(String(), comment='openEuler支持的整机信息的硬件型号')
    host_bus_adapter = Column(String())
    lang = Column(String())
    main_board_model = Column(String())
    openeuler_version = Column(String(), comment='openEuler支持的整机信息的相关的openEuler版本')
    ports_bus_types = Column(String())
    product_information = Column(String())
    ram = Column(String())
    update_time = Column(String())
    video_adapter = Column(String())
    compatibility_configuration = Column(String())
    # camelCase kept to match the upstream API field name.
    boardCards = Column(String())


class OeCompatibilityCard(Base):
    # Expansion-card (driver) compatibility records.
    __tablename__ = 'oe_compatibility_card'
    __table_args__ = {'comment': 'openEuler支持的板卡信息表,存储了openEuler支持的板卡信息的支持架构、板卡型号、芯片信号、芯片厂家、驱动名称、相关的openEuler版本、类型和版本'}
    id = Column(BigInteger, primary_key=True)
    architecture = Column(String(), comment='openEuler支持的板卡信息的支持架构')
    board_model = Column(String(), comment='openEuler支持的板卡信息的板卡型号')
    chip_model = Column(String(), comment='openEuler支持的板卡信息的芯片型号')
    chip_vendor = Column(String(), comment='openEuler支持的板卡信息的芯片厂家')
    device_id = Column(String())
    download_link = Column(String())
    driver_date = Column(String())
    driver_name = Column(String(), comment='openEuler支持的板卡信息的驱动名称')
    driver_size = Column(String())
    item = Column(String())
    lang = Column(String())
    openeuler_version = Column(String(), comment='openEuler支持的板卡信息的相关的openEuler版本')
    sha256 = Column(String())
    ss_id = Column(String())
    sv_id = Column(String())
    type = Column(String(), comment='openEuler支持的板卡信息的类型')
    vendor_id = Column(String())
    version = Column(String(), comment='openEuler支持的板卡信息的版本')


class OeCompatibilityOpenSourceSoftware(Base):
    # Open-source software compatibility records.
    __tablename__ = 'oe_compatibility_open_source_software'
    __table_args__ = {'comment': 'openEuler支持的开源软件信息表,存储了openEuler支持的开源软件的相关openEuler版本、支持的架构、软件属性、开源协议、软件类型、软件名称和软件版本'}
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    openeuler_version = Column(String(), comment='openEuler支持的开源软件信息表的相关openEuler版本')
    arch = Column(String(), comment='openEuler支持的开源软件信息的支持的架构')
    property = Column(String(), comment='openEuler支持的开源软件信息的软件属性')
    result_url = Column(String())
    result_root = Column(String())
    bin = Column(String())
    uninstall = Column(String())
    license = Column(String(), comment='openEuler支持的开源软件信息的开源协议')
    libs = Column(String())
    install = Column(String())
    src_location = Column(String())
    group = Column(String())
    cmds = Column(String())
    type = Column(String(), comment='openEuler支持的开源软件信息的软件类型')
    # NOTE(review): upstream API uses camelCase for these two fields; the
    # '软件名称' comment sits on `category` — looks swapped with `softwareName`;
    # kept as-is pending confirmation against the API payload.
    softwareName = Column(String())
    category = Column(String(), comment='openEuler支持的开源软件信息的软件名称')
    version = Column(String(), comment='openEuler支持的开源软件信息的软件版本')
    downloadLink = Column(String())


class OeCompatibilityCommercialSoftware(Base):
    # Commercial software certification records.
    __tablename__ = 'oe_compatibility_commercial_software'
    __table_args__ = {'comment': 'openEuler支持的商业软件信息表,存储了openEuler支持的商业软件的软件类型、测试机构、软件名称、厂家名称和相关的openEuler版本'}
    id = Column(BigInteger, primary_key=True)
    data_id = Column(String())
    type = Column(String(), comment='openEuler支持的商业软件的软件类型')
    test_organization = Column(String(), comment='openEuler支持的商业软件的测试机构')
    product_name = Column(String(), comment='openEuler支持的商业软件的软件名称')
    product_version = Column(String(), comment='openEuler支持的商业软件的软件版本')
    company_name = Column(String(), comment='openEuler支持的商业软件的厂家名称')
    platform_type_and_server_model = Column(String())
    authenticate_link = Column(String())
    openeuler_version = Column(String(), comment='openEuler支持的商业软件的openEuler版本')
    region = Column(String())


class OeCompatibilityOepkgs(Base):
    # RPM/source package records pulled from oepkgs.net.
    __tablename__ = 'oe_compatibility_oepkgs'
    __table_args__ = {
        'comment': 'openEuler支持的软件包信息表,存储了openEuler支持软件包的名称、简介、是否为官方版本标识、相关的openEuler版本、rpm包下载链接、源码包下载链接、支持的架构、版本'}
    id = Column(String(), primary_key=True)
    name = Column(String(), comment='openEuler支持的软件包的名称')
    summary = Column(String(), comment='openEuler支持的软件包的简介')
    repotype = Column(String(), comment='openEuler支持的软件包的是否为官方版本标识')
    openeuler_version = Column(String(), comment='openEuler支持的软件包的相关的openEuler版本')
    rpmpackurl = Column(String(), comment='openEuler支持的软件包的rpm包下载链接')
    srcrpmpackurl = Column(String(), comment='openEuler支持的软件包的源码包下载链接')
    arch = Column(String(), comment='openEuler支持的软件包的支持的架构')
    rpmlicense = Column(String())
    version = Column(String(), comment='openEuler支持的软件包的版本')


class OeCompatibilitySolution(Base):
    # Certified solution records.
    __tablename__ = 'oe_compatibility_solution'
    __table_args__ = {'comment': 'openeuler支持的解决方案表,存储了openeuler支持的解决方案的架构、硬件类别、cpu型号、相关的openEuler版本、硬件型号、厂家和解决方案类型'}
    id = Column(String(), primary_key=True)
    architecture = Column(String(), comment='openeuler支持的解决方案的架构')
    bios_uefi = Column(String())
    certification_type = Column(String(), comment='openeuler支持的解决方案的硬件')
    cpu = Column(String(), comment='openeuler支持的解决方案的硬件的cpu型号')
    date = Column(String())
    driver = Column(String())
    hard_disk_drive = Column(String())
    introduce_link = Column(String())
    lang = Column(String())
    libvirt_version = Column(String())
    network_card = Column(String())
    openeuler_version = Column(String(), comment='openeuler支持的解决方案的相关openEuler版本')
    ovs_version = Column(String())
    product = Column(String(), comment='openeuler支持的解决方案的硬件型号')
    qemu_version = Column(String())
    raid = Column(String())
    ram = Column(String())
    server_model = Column(String(), comment='openeuler支持的解决方案的厂家')
    server_vendor = Column(String())
    solution = Column(String(), comment='openeuler支持的解决方案的解决方案类型')
    stratovirt_version = Column(String())


class OeCompatibilityOsv(Base):
    # Downstream OS-vendor (OSV) distribution records.
    __tablename__ = 'oe_compatibility_osv'
    __table_args__ = {
        'comment': 'openEuler相关的OSV(Operating System Vendor,操作系统供应商)信息表,存储了openEuler相关的OSV(Operating System Vendor,操作系统供应商)的支持的架构、版本、名称、下载链接、详细信息的链接和相关的openEuler版本'}
    id = Column(String(), primary_key=True)
    arch = Column(String())
    os_version = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的版本')
    osv_name = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的名称')
    date = Column(String())
    os_download_link = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的下载链接')
    type = Column(String())
    details = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的详细信息的链接')
    friendly_link = Column(String())
    total_result = Column(String())
    checksum = Column(String())
    openeuler_version = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的相关的openEuler版本')


class OeCompatibilitySecurityNotice(Base):
    # Security notices (one row per affected version × CVE after expansion).
    __tablename__ = 'oe_compatibility_security_notice'
    __table_args__ = {'comment': 'openEuler社区组安全公告信息表,存储了安全公告影响的openEuler版本、关联的cve漏洞的id、编号和详细信息的链接'}
    id = Column(String(), primary_key=True)
    affected_component = Column(String())
    openeuler_version = Column(String(), comment='openEuler社区组安全公告信息影响的openEuler版本')
    announcement_time = Column(String())
    cve_id = Column(String(), comment='openEuler社区组安全公告关联的cve漏洞的id')
    description = Column(String())
    introduction = Column(String())
    package_name = Column(String())
    reference_documents = Column(String())
    revision_history = Column(String())
    security_notice_no = Column(String(), comment='openEuler社区组安全公告的编号')
    subject = Column(String())
    summary = Column(String())
    type = Column(String())
    notice_type = Column(String())
    cvrf = Column(String())
    package_helper_list = Column(String())
    package_hotpatch_list = Column(String())
    package_list = Column(String())
    reference_list = Column(String())
    cve_list = Column(String())
    details = Column(String(), comment='openEuler社区组安全公告的详细信息的链接')


class OeCompatibilityCveDatabase(Base):
    # CVE records with NVD/openEuler CVSS vector components side by side.
    __tablename__ = 'oe_compatibility_cve_database'
    __table_args__ = {'comment': 'openEuler社区组cve漏洞信息表,存储了cve漏洞的公告时间、id、关联的软件包名称、简介、cvss评分'}

    id = Column(String(), primary_key=True)
    affected_product = Column(String())
    announcement_time = Column(DateTime(), comment='openEuler社区组cve漏洞的公告时间')
    attack_complexity_nvd = Column(String())
    attack_complexity_oe = Column(String())
    attack_vector_nvd = Column(String())
    attack_vector_oe = Column(String())
    availability_nvd = Column(String())
    availability_oe = Column(String())
    confidentiality_nvd = Column(String())
    confidentiality_oe = Column(String())
    cve_id = Column(String(), comment='openEuler社区组cve漏洞的id')
    cvsss_core_nvd = Column(String(), comment='openEuler社区组cve漏洞的cvss评分')
    cvsss_core_oe = Column(String())
    integrity_nvd = Column(String())
    integrity_oe = Column(String())
    national_cyberAwareness_system = Column(String())
    package_name = Column(String(), comment='openEuler社区组cve漏洞的关联的软件包名称')
    privileges_required_nvd = Column(String())
    privileges_required_oe = Column(String())
    scope_nvd = Column(String())
    scope_oe = Column(String())
    status = Column(String())
    summary = Column(String(), comment='openEuler社区组cve漏洞的简介')
    type = Column(String())
    user_interaction_nvd = Column(String())
    # NOTE(review): field name is missing an 'n' ('user_interactio_oe');
    # kept because renaming would change the DB column.
    user_interactio_oe = Column(String())
    update_time = Column(DateTime())
    create_time = Column(DateTime())
    security_notice_no = Column(String())
    parser_bean = Column(String())
    cvrf = Column(String())
    package_list = Column(String())
    details = Column(String())


class OeOpeneulerSigGroup(Base):
    # Special Interest Group (SIG) master records.
    __tablename__ = 'oe_openeuler_sig_groups'
    __table_args__ = {'comment': 'openEuler社区特别兴趣小组信息表,存储了SIG的名称、描述等信息'}

    id = Column(BigInteger(), Sequence('sig_group_id'), primary_key=True)
    sig_name = Column(String(), comment='SIG的名称')
    description = Column(String(), comment='SIG的描述')
    created_at = Column(DateTime())
    is_sig_original = Column(String(), comment='是否为原始SIG')
    mailing_list = Column(String(), comment='SIG的邮件列表')
    repos_group = relationship("OeSigGroupRepo", back_populates="group")
    members_group = relationship("OeSigGroupMember", back_populates="group")


class OeOpeneulerSigMembers(Base):
    # SIG member records.
    __tablename__ = 'oe_openeuler_sig_members'
    # Fix: was `__table_args` (missing trailing underscores), so the table
    # comment was silently ignored by SQLAlchemy.
    __table_args__ = {'comment': 'openEuler社区特别兴趣小组成员信息表,储存了小组成员信息,如gitee用户名、所属组织等'}

    id = Column(BigInteger(), Sequence('sig_members_id'), primary_key=True)
    name = Column(String(), comment='成员名称')
    gitee_id = Column(String(), comment='成员gitee用户名')
    organization = Column(String(), comment='成员所属组织')
    email = Column(String(), comment='成员邮箱地址')

    groups_member = relationship("OeSigGroupMember", back_populates="member")
    repos_member = relationship("OeSigRepoMember", back_populates="member")


class OeOpeneulerSigRepos(Base):
    # SIG repository records.
    __tablename__ = 'oe_openeuler_sig_repos'
    # Fix: comment was a copy-paste of the SIG-group table's comment; this
    # table stores repositories, not groups.
    __table_args__ = {'comment': 'openEuler社区特别兴趣小组仓库信息表,存储了SIG仓库的地址和URL等信息'}

    id = Column(BigInteger(), Sequence('sig_repos_id'), primary_key=True)
    repo = Column(String(), comment='repo地址')
    url = Column(String(), comment='repo URL')
    groups_repo = relationship("OeSigGroupRepo", back_populates="repo")
    members_repo = relationship("OeSigRepoMember", back_populates="repo")


class OeSigGroupRepo(Base):
    # Association table: SIG group <-> repository (composite PK).
    __tablename__ = 'oe_sig_group_to_repos'
    __table_args__ = {'comment': 'SIG group id包含repos'}

    group_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_groups.id'), primary_key=True)
    repo_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_repos.id'), primary_key=True)
    group = relationship("OeOpeneulerSigGroup", back_populates="repos_group")
    repo = relationship("OeOpeneulerSigRepos", back_populates="groups_repo")


class OeSigRepoMember(Base):
    # Association table: repository <-> maintaining member (composite PK).
    __tablename__ = 'oe_sig_repos_to_members'
    __table_args__ = {'comment': 'SIG repo id对应维护成员'}

    member_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_members.id'), primary_key=True)
    repo_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_repos.id'), primary_key=True)
    member_role = Column(String(), comment='成员属性maintainer or committer')
    member = relationship("OeOpeneulerSigMembers", back_populates="repos_member")
    repo = relationship("OeOpeneulerSigRepos", back_populates="members_repo")


class OeSigGroupMember(Base):
    # Association table: SIG group <-> member (composite PK).
    __tablename__ = 'oe_sig_group_to_members'
    __table_args__ = {'comment': 'SIG group id对应包含成员id'}

    group_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_groups.id'), primary_key=True)
    member_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_members.id'), primary_key=True)
    member_role = Column(String(), comment='成员属性maintainer or committer')
    group = relationship("OeOpeneulerSigGroup", back_populates="members_group")
    member = relationship("OeOpeneulerSigMembers", back_populates="groups_member")


class OeCommunityOrganizationStructure(Base):
    # Community organization chart (committee / role / person).
    __tablename__ = 'oe_community_organization_structure'
    __table_args__ = {'comment': 'openEuler社区组织架构信息表,存储了openEuler社区成员所属的委员会、职位、姓名和个人信息'}
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    committee_name = Column(String(), comment='openEuler社区成员所属的委员会')
    role = Column(String(), comment='openEuler社区成员的职位')
    name = Column(String(), comment='openEuler社区成员的姓名')
    personal_message = Column(String(), comment='openEuler社区成员的个人信息')


class OeCommunityOpenEulerVersion(Base):
    # openEuler release records (version, kernel, release date, type).
    __tablename__ = 'oe_community_openeuler_version'
    __table_args__ = {'comment': 'openEuler社区openEuler版本信息表,存储了openEuler版本的版本、内核版本、发布日期和版本类型'}
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    openeuler_version = Column(String(), comment='openEuler版本的版本号')
    kernel_version = Column(String(), comment='openEuler版本的内核版本')
    publish_time = Column(DateTime(), comment='openEuler版本的发布日期')
    version_type = Column(String(), comment='openEuler社区成员的版本类型')


class PostgresDBMeta(type):
    """Singleton metaclass: at most one instance per class, thread-safe.

    Fix: `_lock` was declared but never acquired, so two threads could race
    through the membership test and build two engines. Double-checked
    locking keeps the fast path lock-free once the instance exists.
    """
    _instances = {}
    _lock: Lock = Lock()

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            with cls._lock:
                # Re-check inside the lock: another thread may have won.
                if cls not in cls._instances:
                    cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class PostgresDB(metaclass=PostgresDBMeta):
    """Singleton holder for the SQLAlchemy engine; creates all tables on init.

    Args:
        pg_url: full SQLAlchemy URL, e.g. postgresql+psycopg2://user:pwd@host/db
    """

    def __init__(self, pg_url):
        # pool_pre_ping transparently replaces stale pooled connections.
        self.engine = create_engine(
            pg_url,
            echo=False,
            pool_pre_ping=True)
        self.create_table()

    def create_table(self):
        """Create every table declared on Base (no-op for existing tables)."""
        Base.metadata.create_all(self.engine)

    def get_session(self):
        """Return a brand-new Session bound to the shared engine."""
        return sessionmaker(bind=self.engine)()

    def close(self):
        """Dispose the engine's connection pool."""
        self.engine.dispose()
PullMessageFromOeWeb: + @staticmethod + def add_to_DB(pg_url, results, dataset_name): + dataset_name_map = { + 'overall_unit': OeMessageManager.add_oe_compatibility_overall_unit, + 'card': OeMessageManager.add_oe_compatibility_card, + 'open_source_software': OeMessageManager.add_oe_compatibility_open_source_software, + 'commercial_software': OeMessageManager.add_oe_compatibility_commercial_software, + 'solution': OeMessageManager.add_oe_compatibility_solution, + 'oepkgs': OeMessageManager.add_oe_compatibility_oepkgs, + 'osv': OeMessageManager.add_oe_compatibility_osv, + 'security_notice': OeMessageManager.add_oe_compatibility_security_notice, + 'cve_database': OeMessageManager.add_oe_compatibility_cve_database, + 'sig_group': OeMessageManager.add_oe_openeuler_sig_group, + 'sig_members': OeMessageManager.add_oe_openeuler_sig_members, + 'sig_repos': OeMessageManager.add_oe_openeuler_sig_repos, + } + + for i in range(len(results)): + for key in results[i]: + if isinstance(results[i][key], (dict, list, tuple)): + results[i][key] = json.dumps(results[i][key]) + dataset_name_map[dataset_name](pg_url, results[i]) + + + @staticmethod + def save_result(pg_url, results, dataset_name): + prompt_map = {'card': 'openEuler支持的板卡信息', + 'commercial_software': 'openEuler支持的商业软件', + 'cve_database': 'openEuler的cve信息', + 'oepkgs': 'openEuler支持的软件包信息', + 'solution': 'openEuler支持的解决方案', + 'open_source_software': 'openEuler支持的开源软件信息', + 'overall_unit': 'openEuler支持的整机信息', + 'security_notice': 'openEuler官网的安全公告', + 'osv': 'openEuler相关的osv厂商', + 'openeuler_version_message': 'openEuler的版本信息', + 'organize_message': 'openEuler社区成员组织架构', + 'openeuler_sig': 'openeuler_sig(openEuler SIG组成员信息)', + 'sig_group': 'openEuler SIG组信息', + 'sig_members': 'openEuler SIG组成员信息', + 'sig_repos': 'openEuler SIG组仓库信息'} + dataset_name_map = { + 'overall_unit': OeMessageManager.clear_oe_compatibility_overall_unit, + 'card': OeMessageManager.clear_oe_compatibility_card, + 'open_source_software': 
OeMessageManager.clear_oe_compatibility_open_source_software, + 'commercial_software': OeMessageManager.clear_oe_compatibility_commercial_software, + 'solution': OeMessageManager.clear_oe_compatibility_solution, + 'oepkgs': OeMessageManager.clear_oe_compatibility_oepkgs, + 'osv': OeMessageManager.clear_oe_compatibility_osv, + 'security_notice': OeMessageManager.clear_oe_compatibility_security_notice, + 'cve_database': OeMessageManager.clear_oe_compatibility_cve_database, + 'sig_group': OeMessageManager.clear_oe_openeuler_sig_group, + 'sig_members': OeMessageManager.clear_oe_openeuler_sig_members, + 'sig_repos': OeMessageManager.clear_oe_openeuler_sig_repos, + } + process = [] + num_cores = (os.cpu_count()+1) // 2 + if num_cores > 8: + num_cores = 8 + + print(f"{prompt_map[dataset_name]}录入中") + dataset_name_map[dataset_name](pg_url) + step = len(results) // num_cores + process = [] + for i in range(num_cores): + begin = i * step + end = begin + step + if i == num_cores - 1: + end = len(results) + sub_result = results[begin:end] + p = Process(target=PullMessageFromOeWeb.add_to_DB, args=(pg_url, sub_result, dataset_name)) + process.append(p) + p.start() + + for p in process: + p.join() + print('数据插入完成') + + @staticmethod + def pull_oe_compatibility_overall_unit(pg_url, dataset_name): + url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/hardwarecomp/findAll' + headers = { + 'Content-Type': 'application/json' + } + data = { + "pages": { + "page": 1, + "size": 10 + }, + "architecture": "", + "keyword": "", + "cpu": "", + "os": "", + "testOrganization": "", + "type": "", + "cardType": "", + "lang": "zh", + "dataSource": "assessment", + "solution": "", + "certificationType": "" + } + + response = requests.post( + url, + headers=headers, + json=data + ) + results = [] + total_num = response.json()["result"]["totalCount"] + for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "pages": { + "page": i + 1, + "size": 10 + }, + "architecture": 
"", + "keyword": "", + "cpu": "", + "os": "", + "testOrganization": "", + "type": "", + "cardType": "", + "lang": "zh", + "dataSource": "assessment", + "solution": "", + "certificationType": "" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results += response.json()["result"]["hardwareCompList"] + + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_compatibility_card(pg_url, dataset_name): + url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/drivercomp/findAll' + headers = { + 'Content-Type': 'application/json' + } + data = { + "pages": { + "page": 1, + "size": 10 + }, + "architecture": "", + "keyword": "", + "cpu": "", + "os": "", + "testOrganization": "", + "type": "", + "cardType": "", + "lang": "zh", + "dataSource": "assessment", + "solution": "", + "certificationType": "" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results = [] + total_num = response.json()["result"]["totalCount"] + for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "pages": { + "page": i + 1, + "size": 10 + }, + "architecture": "", + "keyword": "", + "cpu": "", + "os": "", + "testOrganization": "", + "type": "", + "cardType": "", + "lang": "zh", + "dataSource": "", + "solution": "", + "certificationType": "" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results += response.json()["result"]["driverCompList"] + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + + @staticmethod + def pull_oe_compatibility_commercial_software(pg_url, dataset_name): + url = 'https://www.openeuler.org/api-certification/server/certification/software/communityChecklist' + headers = { + 'Content-Type': 'application/json' + } + data = { + "pageSize": 1, + "pageNo": 1, + "testOrganization": "", + "osName": "", + "keyword": "", + "dataSource": [ + "assessment" + ], + "productType": [ + "软件" + ] + } + response = requests.post( + url, 
+ headers=headers, + json=data + ) + results = [] + total_num = response.json()["result"]["totalNum"] + for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "pageSize": i + 1, + "pageNo": 10, + "testOrganization": "", + "osName": "", + "keyword": "", + "dataSource": [ + "assessment" + ], + "productType": [ + "软件" + ] + } + response = requests.post( + url, + headers=headers, + json=data + ) + results += response.json()["result"]["data"] + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_compatibility_open_source_software(pg_url, dataset_name): + url = 'https://www.openeuler.org/compatibility/api/web_backend/compat_software_info' + headers = { + 'Content-Type': 'application/json' + } + data = { + "page_size": 1, + "page_num": 1 + } + response = requests.get( + url, + headers=headers, + data=data + ) + results = [] + total_num = response.json()["total"] + for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "page_size": i + 1, + "page_num": 10 + } + response = requests.get( + url, + headers=headers, + data=data + ) + results += response.json()["info"] + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_compatibility_oepkgs(pg_url, oepkgs_info, dataset_name): + headers = { + 'Content-Type': 'application/json' + } + params = {'user': oepkgs_info.get('oepkgs_user', ''), + 'password': oepkgs_info.get('oepkgs_pwd', ''), + 'expireInSeconds': 36000 + } + url = 'https://search.oepkgs.net/api/search/openEuler/genToken' + response = requests.post(url, headers=headers, params=params) + token = response.json()['data']['Authorization'] + headers = { + 'Content-Type': 'application/json', + 'Authorization': token + } + url = 'https://search.oepkgs.net/api/search/openEuler/scroll?scrollId=' + response = requests.get(url, headers=headers) + data = response.json()['data'] + scrollId = data['scrollId'] + totalHits = data['totalHits'] + data_list = data['list'] + 
results = data_list + totalHits -= 1000 + + while totalHits > 0: + url = 'https://search.oepkgs.net/api/search/openEuler/scroll?scrollId=' + scrollId + response = requests.get(url, headers=headers) + data = response.json()['data'] + data_list = data['list'] + results += data_list + totalHits -= 1000 + + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_compatibility_solution(pg_url, dataset_name): + url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/solutioncomp/findAll' + headers = { + 'Content-Type': 'application/json' + } + data = { + "pages": { + "page": 1, + "size": 1 + }, + "architecture": "", + "keyword": "", + "cpu": "", + "os": "", + "testOrganization": "", + "type": "", + "cardType": "", + "lang": "zh", + "dataSource": "assessment", + "solution": "", + "certificationType": "" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results = [] + total_num = response.json()["result"]["totalCount"] + for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "pages": { + "page": i + 1, + "size": total_num + }, + "architecture": "", + "keyword": "", + "cpu": "", + "os": "", + "testOrganization": "", + "type": "", + "cardType": "", + "lang": "zh", + "dataSource": "assessment", + "solution": "", + "certificationType": "" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results += response.json()["result"]["solutionCompList"] + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_compatibility_osv(pg_url, dataset_name): + url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/osv/findAll' + headers = { + 'Content-Type': 'application/json' + } + data = { + "pages": { + "page": 1, + "size": 10 + }, + "keyword": "", + "type": "", + "osvName": "" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results = [] + total_num = response.json()["result"]["totalCount"] + 
for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "pages": { + "page": i + 1, + "size": 10 + }, + "keyword": "", + "type": "", + "osvName": "" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results += response.json()["result"]["osvList"] + for i in range(len(results)): + results[i]['details'] = 'https://www.openeuler.org/zh/approve/approve-info/?id=' + str(results[i]['id']) + + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_compatibility_security_notice(pg_url, dataset_name): + url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/securitynotice/findAll' + headers = { + 'Content-Type': 'application/json' + } + data = { + "pages": { + "page": 1, + "size": 1 + }, + "keyword": "", + "type": [], + "date": [], + "affectedProduct": [], + "affectedComponent": "", + "noticeType": "cve" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results = [] + total_num = response.json()["result"]["totalCount"] + for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "pages": { + "page": i + 1, + "size": 10 + }, + "keyword": "", + "type": [], + "date": [], + "affectedProduct": [], + "affectedComponent": "", + "noticeType": "cve" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results += response.json()["result"]["securityNoticeList"] + + new_results = [] + for i in range(len(results)): + openeuler_version_list = results[i]['affectedProduct'].split(';') + cve_id_list = results[i]['cveId'].split(';') + cnt = 0 + for openeuler_version in openeuler_version_list: + if len(openeuler_version) > 0: + for cve_id in cve_id_list: + if len(cve_id) > 0: + tmp_dict = copy.deepcopy(results[i]) + del tmp_dict['affectedProduct'] + tmp_dict['id'] = int(str(tmp_dict['id']) + str(cnt)) + tmp_dict['openeuler_version'] = openeuler_version + tmp_dict['cveId'] = cve_id + new_results.append(tmp_dict) + tmp_dict[ + 'details'] = 
'https://www.openeuler.org/zh/security/security-bulletins/detail/?id=' + \ + tmp_dict['securityNoticeNo'] + cnt += 1 + results = new_results + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_compatibility_cve_database(pg_url, dataset_name): + url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/cvedatabase/findAll' + headers = { + 'Content-Type': 'application/json' + } + data = { + "pages": { + "page": 1, + "size": 1 + }, + "keyword": "", + "status": "", + "year": "", + "score": "", + "noticeType": "cve" + } + response = requests.post( + url, + headers=headers, + json=data + ) + js = response.json() + results = [] + total_num = response.json()["result"]["totalCount"] + for i in range(total_num // 10 + (total_num % 10 != 0)): + data = { + "pages": { + "page": i + 1, + "size": 10 + }, + "keyword": "", + "status": "", + "year": "", + "score": "", + "noticeType": "cve" + } + response = requests.post( + url, + headers=headers, + json=data + ) + results += response.json()["result"]["cveDatabaseList"] + + data_cnt=0 + new_results=[] + for i in range(len(results)): + id1=results[i]['cveId'] + results[i]['details'] = 'https://www.openeuler.org/zh/security/cve/detail/?cveId=' + \ + results[i]['cveId'] + '&packageName=' + results[i]['packageName'] + url='https://www.openeuler.org/api-cve/cve-security-notice-server/cvedatabase/getCVEProductPackageList?cveId=' + \ + results[i]['cveId'] + '&packageName=' + results[i]['packageName'] + try: + response = requests.get( + url, + headers=headers + ) + for key in response.json()["result"].keys(): + if key not in results[i].keys() or results[i][key]=='': + results[i][key] = response.json()["result"][key] + url='https://www.openeuler.org/api-cve/cve-security-notice-server/cvedatabase/getCVEProductPackageList?cveId='+ \ + results[i]['cveId'] + '&packageName=' + results[i]['packageName'] + response = requests.get( + url, + headers=headers + ) + except: + pass + try: + for j in 
range(len(response.json()["result"])): + tmp=copy.deepcopy(results[i]) + for key in response.json()["result"][j].keys(): + if key not in tmp.keys() or tmp[key]=='': + tmp[key] = response.json()["result"][j][key] + try: + date_str = tmp['updateTime'] + date_format = "%Y-%m-%d %H:%M:%S" + dt = datetime.strptime(date_str, date_format) + year_month_day = dt.strftime("%Y-%m-%d") + tmp['announcementTime']= year_month_day + except: + pass + tmp['id']=str(data_cnt) + data_cnt+=1 + new_results.append(tmp) + except: + pass + results=new_results + + PullMessageFromOeWeb.save_result(pg_url, results, dataset_name) + + @staticmethod + def pull_oe_openeuler_sig(pg_url): + sig_url_base = 'https://www.openeuler.org/api-dsapi/query/sig/info' + headers = { + 'Content-Type': 'application/json' + } + data = { + "community": "openeuler", + "page": 1, + "pageSize": 12, + "search": "fuzzy" + } + sig_url = sig_url_base + '?' + urllib.parse.urlencode(data) + response = requests.post( + sig_url, + headers=headers, + json=data + ) + results_all = [] + total_num = response.json()["data"][0]["total"] + for i in range(total_num // 12 + (total_num % 12 != 0)): + data = { + "community": "openeuler", + "page": i + 1, + "pageSize": 12, + "search": "fuzzy" + } + sig_url = sig_url_base + '?' 
+ urllib.parse.urlencode(data) + response = requests.post( + sig_url, + headers=headers, + json=data + ) + results_all += response.json()["data"][0]["data"] + + results_members = [] + results_repos = [] + for i in range(len(results_all)): + temp_members = [] + sig_name = results_all[i]['sig_name'] + repos_url_base = 'https://www.openeuler.org/api-dsapi/query/sig/repo/committers' + repos_url = repos_url_base + '?community=openeuler&sig=' + sig_name + try: + response = requests.get(repos_url, timeout=2) + except Exception as e: + k = 0 + while response.status_code != 200 and k < 5: + k = k + 1 + try: + response = requests.get(repos_url, timeout=2) + except Exception as e: + print(e) + # continue + results = response.json()["data"] + committers = results['committers'] + maintainers = results['maintainers'] + if committers is None: + committers = [] + if maintainers is None: + maintainers = [] + results_all[i]['committers'] = committers + results_all[i]['maintainers'] = maintainers + for key in results_all[i]: + if key == "committer_info" or key == "maintainer_info": # solve with members + if results_all[i][key] is not None: + for member in results_all[i][key]: + member['sig_name'] = sig_name + temp_members.append(member) + elif key == "repos": # solve with repos + if results_all[i][key] is not None: + for repo in results_all[i][key]: + results_repos.append({'repo': repo, 'url': 'https://gitee.com/' + repo, + 'sig_name': sig_name, + 'committers': committers, + 'maintainers': maintainers}) + else: # solve others + if isinstance(results_all[i][key], (dict, list, tuple)): + results_all[i][key] = json.dumps(results_all[i][key]) + + results_members += temp_members + + temp_results_members = [] + seen = set() + for d in results_members: + if d['gitee_id'] not in seen: + seen.add(d['gitee_id']) + temp_results_members.append(d) + results_members = temp_results_members + + PullMessageFromOeWeb.save_result(pg_url, results_all, 'sig_group') + 
PullMessageFromOeWeb.save_result(pg_url, results_members, 'sig_members') + PullMessageFromOeWeb.save_result(pg_url, results_repos, 'sig_repos') + + OeMessageManager.clear_oe_sig_group_to_repos(pg_url) + for i in range(len(results_repos)): + repo_name = results_repos[i]['repo'] + group_name = results_repos[i]['sig_name'] + OeMessageManager.add_oe_sig_group_to_repos(pg_url, group_name, repo_name) + + OeMessageManager.clear_oe_sig_group_to_members(pg_url) + for i in range(len(results_all)): + committers = set(json.loads(results_all[i]['committers'])) + maintainers = set(json.loads(results_all[i]['maintainers'])) + group_name = results_all[i]['sig_name'] + + all_members = committers.union(maintainers) + + for member_name in all_members: + if member_name in committers and member_name in maintainers: + role = 'committer & maintainer' + elif member_name in committers: + role = 'committer' + elif member_name in maintainers: + role = 'maintainer' + + OeMessageManager.add_oe_sig_group_to_members(pg_url, group_name, member_name, role=role) + + OeMessageManager.clear_oe_sig_repos_to_members(pg_url) + for i in range(len(results_repos)): + repo_name = results_repos[i]['repo'] + committers = set(results_repos[i]['committers']) + maintainers = set(results_repos[i]['maintainers']) + + all_members = committers.union(maintainers) + + for member_name in all_members: + if member_name in committers and member_name in maintainers: + role = 'committer & maintainer' + elif member_name in committers: + role = 'committer' + elif member_name in maintainers: + role = 'maintainer' + + OeMessageManager.add_oe_sig_repos_to_members(pg_url, repo_name, member_name, role=role) + + @staticmethod + def oe_organize_message_handler(pg_url, oe_spider_method): + f = open('./docs/organize.txt', 'r', encoding='utf-8') + lines = f.readlines() + st = 0 + en = 0 + filed_list = ['role', 'name', 'personal_message'] + OeMessageManager.clear_oe_community_organization_structure(pg_url) + while st < len(lines): + 
while en < len(lines) and lines[en] != '\n': + en += 1 + lines[st] = lines[st].replace('\n', '') + committee_name = lines[st] + for i in range(st + 1, en): + data = lines[i].replace('\n', '').split(' ') + info = {'committee_name': committee_name} + for j in range(min(len(filed_list), len(data))): + data[j] = data[j].replace('\n', '') + if j == 2: + data[j] = data[j].replace('_', ' ') + info[filed_list[j]] = data[j] + OeMessageManager.add_oe_community_organization_structure(pg_url, info) + st = en + 1 + en = st + + @staticmethod + def oe_openeuler_version_message_handler(pg_url, oe_spider_method): + f = open('./docs/openeuler_version.txt', 'r', encoding='utf-8') + lines = f.readlines() + filed_list = ['openeuler_version', 'kernel_version', 'publish_time', 'version_type'] + OeMessageManager.clear_oe_community_openEuler_version(pg_url) + for line in lines: + tmp_list = line.replace('\n', '').split(' ') + tmp_list[2] = datetime.strptime(tmp_list[2], "%Y-%m-%d") + tmp_dict = {} + for i in range(len(filed_list)): + tmp_dict[filed_list[i]] = tmp_list[i] + OeMessageManager.add_oe_community_openEuler_version(pg_url, tmp_dict) + + +def work(args): + func_map = {'card': PullMessageFromOeWeb.pull_oe_compatibility_card, + 'commercial_software': PullMessageFromOeWeb.pull_oe_compatibility_commercial_software, + 'cve_database': PullMessageFromOeWeb.pull_oe_compatibility_cve_database, + 'oepkgs': PullMessageFromOeWeb.pull_oe_compatibility_oepkgs, + 'solution':PullMessageFromOeWeb.pull_oe_compatibility_solution, + 'open_source_software': PullMessageFromOeWeb.pull_oe_compatibility_open_source_software, + 'overall_unit': PullMessageFromOeWeb.pull_oe_compatibility_overall_unit, + 'security_notice': PullMessageFromOeWeb.pull_oe_compatibility_security_notice, + 'osv': PullMessageFromOeWeb.pull_oe_compatibility_osv, + 'openeuler_version_message': PullMessageFromOeWeb.oe_openeuler_version_message_handler, + 'organize_message': PullMessageFromOeWeb.oe_organize_message_handler, + 
'openeuler_sig': PullMessageFromOeWeb.pull_oe_openeuler_sig} + prompt_map = {'card': 'openEuler支持的板卡信息', + 'commercial_software': 'openEuler支持的商业软件', + 'cve_database': 'openEuler的cve信息', + 'oepkgs': 'openEuler支持的软件包信息', + 'solution':'openEuler支持的解决方案', + 'open_source_software': 'openEuler支持的开源软件信息', + 'overall_unit': 'openEuler支持的整机信息', + 'security_notice': 'openEuler官网的安全公告', + 'osv': 'openEuler相关的osv厂商', + 'openeuler_version_message': 'openEuler的版本信息', + 'organize_message': 'openEuler社区成员组织架构', + 'openeuler_sig': 'openeuler_sig(openEuler SIG组成员信息)'} + method = args['method'] + pg_host = args['pg_host'] + pg_port = args['pg_port'] + pg_user = args['pg_user'] + pg_pwd = args['pg_pwd'] + pg_database = args['pg_database'] + oepkgs_user = args['oepkgs_user'] + oepkgs_pwd = args['oepkgs_pwd'] + oe_spider_method = args['oe_spider_method'] + if method == 'init_pg_info': + if pg_host is None or pg_port is None or pg_pwd is None or pg_pwd is None or pg_database is None: + print('请入完整pg配置信息') + exit() + pg_info = {'pg_host': pg_host, 'pg_port': pg_port, + 'pg_user': pg_user, 'pg_pwd': pg_pwd, 'pg_database': pg_database} + if not os.path.exists('./config'): + os.mkdir('./config') + with open('./config/pg_info.yaml', 'w', encoding='utf-8') as f: + yaml.dump(pg_info, f) + elif method == 'init_oepkgs_info': + if oepkgs_user is None or oepkgs_pwd is None: + print('请入完整oepkgs配置信息') + exit() + oepkgs_info = {'oepkgs_user': oepkgs_user, 'oepkgs_pwd': oepkgs_pwd} + if not os.path.exists('./config'): + os.mkdir('./config') + with open('./config/oepkgs_info.yaml', 'w', encoding='utf-8') as f: + yaml.dump(oepkgs_info, f) + elif method == 'update_oe_table': + if not os.path.exists('./config/pg_info.yaml'): + print('请先配置postgres数据库信息') + exit() + if (oe_spider_method == 'all' or oe_spider_method == 'oepkgs') and not os.path.exists( + './config/oepkgs_info.yaml'): + print('请先配置oepkgs用户信息') + exit() + pg_info = {} + try: + with open('./config/pg_info.yaml', 'r', encoding='utf-8') as f: + 
pg_info = yaml.load(f, Loader=SafeLoader) + except: + print('postgres数据库配置文件加载失败') + exit() + pg_host = pg_info.get('pg_host', '') + ':' + pg_info.get('pg_port', '') + pg_user = pg_info.get('pg_user', '') + pg_pwd = pg_info.get('pg_pwd', '') + pg_database = pg_info.get('pg_database', '') + pg_url = f'postgresql+psycopg2://{pg_user}:{pg_pwd}@{pg_host}/{pg_database}' + if oe_spider_method == 'all' or oe_spider_method == 'oepkgs': + with open('./config/oepkgs_info.yaml', 'r', encoding='utf-8') as f: + try: + oepkgs_info = yaml.load(f, Loader=SafeLoader) + except: + print('oepkgs配置信息读取失败') + exit() + if oe_spider_method == 'all': + for func in func_map: + print(f'开始爬取{prompt_map[func]}') + try: + if func == 'oepkgs': + func_map[func](pg_url, oepkgs_info, func) + elif func == 'openeuler_sig': + func_map[func](pg_url) + else: + func_map[func](pg_url, func) + print(prompt_map[func] + '入库成功') + except Exception as e: + import traceback + print(traceback.format_exc) + print(f'{prompt_map[func]}入库失败由于:{e}') + else: + try: + print(f'开始爬取{prompt_map[oe_spider_method]}') + if oe_spider_method == 'oepkgs': + func_map[oe_spider_method](pg_url, oepkgs_info, oe_spider_method) + elif oe_spider_method == 'openeuler_sig': + func_map[oe_spider_method](pg_url) + else: + func_map[oe_spider_method](pg_url, oe_spider_method) + print(prompt_map[oe_spider_method] + '入库成功') + except Exception as e: + print(f'{prompt_map[oe_spider_method]}入库失败由于:{e}') + + +def init_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--method", type=str, required=True, + choices=['init_pg_info', 'init_oepkgs_info', 'update_oe_table'], + help='''脚本使用模式,有初始化pg信息、初始化oepkgs账号信息和爬取openEuler官网数据的功能''') + parser.add_argument("--pg_host", default=None, required=False, help="postres的ip") + parser.add_argument("--pg_port", default=None, required=False, help="postres的端口") + parser.add_argument("--pg_user", default=None, required=False, help="postres的用户") + parser.add_argument("--pg_pwd", default=None, 
required=False, help="postres的密码") + parser.add_argument("--pg_database", default=None, required=False, help="postres的数据库") + parser.add_argument("--oepkgs_user", default=None, required=False, help="语料库所在postres的ip") + parser.add_argument("--oepkgs_pwd", default=None, required=False, help="语料库所在postres的端口") + parser.add_argument( + "--oe_spider_method", default='all', required=False, + choices=['all', 'card', 'commercial_software', 'cve_database', 'oepkgs','solution','open_source_software', + 'overall_unit', 'security_notice', 'osv', 'cve_database', 'openeuler_version_message', + 'organize_message', 'openeuler_sig'], + help="需要爬取的openEuler数据类型,有all(所有内容),card(openEuler支持的板卡信息),commercial_software(openEuler支持的商业软件)," + "cve_database(openEuler的cve信息),oepkgs(openEuler支持的软件包信息),solution(openEuler支持的解决方案),open_source_software(openEuler支持的开源软件信息)," + "overall_unit(openEuler支持的整机信息),security_notice(openEuler官网的安全公告),osv(openEuler相关的osv厂商)," + "cve_database(openEuler的cve漏洞),openeuler_version_message(openEuler的版本信息)," + "organize_message(openEuler社区成员组织架构),openeuler_sig(openEuler SIG组成员信息)") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + set_start_method('spawn') + args = init_args() + work(vars(args))