diff --git a/chat2db/app/__init__.py b/chat2db/app/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/chat2db/app/app.py b/chat2db/app/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce8892578b2f1d4151e0ec210d8b212cd5647bf
--- /dev/null
+++ b/chat2db/app/app.py
@@ -0,0 +1,37 @@
+import uvicorn
+from fastapi import FastAPI
+from chat2db.app.router import sql_example
+from chat2db.app.router import sql_generate
+from chat2db.app.router import database
+from chat2db.app.router import table
+from chat2db.config.config import config
+import logging
+
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+app = FastAPI()
+
+app.include_router(sql_example.router)
+app.include_router(sql_generate.router)
+app.include_router(database.router)
+app.include_router(table.router)
+
+if __name__ == '__main__':
+    try:
+        # Shared uvicorn options; SSL credentials are added only when enabled.
+        uvicorn_kwargs = {
+            'host': config["UVICORN_IP"],
+            'port': int(config["UVICORN_PORT"]),
+            'proxy_headers': True,
+            'forwarded_allow_ips': '*',
+        }
+        if config["SSL_ENABLE"]:
+            uvicorn_kwargs['ssl_certfile'] = config["SSL_CERTFILE"]
+            uvicorn_kwargs['ssl_keyfile'] = config["SSL_KEYFILE"]
+        uvicorn.run(app, **uvicorn_kwargs)
+    except Exception as e:
+        # Log the failure before exiting; the bare exit(1) hid the root cause.
+        logging.error(f'服务启动失败由于{e}')
+        exit(1)
diff --git a/chat2db/app/base/ac_automation.py b/chat2db/app/base/ac_automation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a408abbd9c87f4d1b0f66d806e8b09ca2e9f11f
--- /dev/null
+++ b/chat2db/app/base/ac_automation.py
@@ -0,0 +1,102 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import copy
+import logging
+
+
+class Node:
+    """Node of the keyword trie (Aho-Corasick style automaton)."""
+
+    def __init__(self, dep, pre_id):
+        self.dep = dep  # depth in the trie; root is 0
+        self.pre_id = pre_id  # index of the fail/fallback node
+        self.pre_nearest_children_id = {}  # transition map filled by init_pre
+        self.children_id = {}  # direct children: char -> node index
+        self.data_frame = None  # payload stored on the last char of a keyword
+
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+
+class DictTree:
+    """Keyword dictionary tree that finds keyword payloads inside a text."""
+
+    def __init__(self):
+        self.root = 0
+        self.node_list = [Node(0, -1)]
+
+    def load_data(self, data_dict):
+        """Insert every keyword -> payload pair, then build the fail links."""
+        for key in data_dict:
+            self.insert_data(key, data_dict[key])
+        self.init_pre()
+
+    def insert_data(self, keyword, data_frame):
+        """Insert one keyword, storing its payload on the final node."""
+        if not isinstance(keyword, str):
+            return
+        if len(keyword) == 0:
+            return
+        node_index = self.root
+        try:
+            for ch in keyword:
+                if ch not in self.node_list[node_index].children_id:
+                    self.node_list.append(Node(self.node_list[node_index].dep+1, 0))
+                    self.node_list[node_index].children_id[ch] = len(self.node_list)-1
+                node_index = self.node_list[node_index].children_id[ch]
+        except Exception as e:
+            logging.error(f'关键字插入失败由于:{e}')
+            return
+        self.node_list[node_index].data_frame = data_frame
+
+    def init_pre(self):
+        """Build fail links and per-node transition maps with a BFS."""
+        q = [self.root]
+        l = 0
+        r = 1
+        try:
+            while l < r:
+                node_index = q[l]
+                if node_index == self.root:
+                    # Fix: the root has no fail node. pre_id=-1 used to index
+                    # node_list[-1] (the last node, always a leaf), which only
+                    # worked by accident; start from an empty transition map.
+                    self.node_list[node_index].pre_nearest_children_id = {}
+                else:
+                    self.node_list[node_index].pre_nearest_children_id = self.node_list[self.node_list[node_index].pre_id].children_id.copy()
+                l += 1
+                for key, val in self.node_list[node_index].children_id.items():
+                    q.append(val)
+                    r += 1
+                    if key in self.node_list[node_index].pre_nearest_children_id.keys():
+                        pre_id = self.node_list[node_index].pre_nearest_children_id[key]
+                        self.node_list[val].pre_id = pre_id
+                    self.node_list[node_index].pre_nearest_children_id[key] = val
+        except Exception as e:
+            logging.error(f'字典树前缀构建失败由于:{e}')
+            return
+
+    def get_results(self, content: str):
+        """Scan content (lowercased) and collect payloads of matched keywords."""
+        content = content.lower()
+        pre_node_index = self.root
+        nex_node_index = None
+        results = []
+        logging.info(f'当前问题{content}')
+        try:
+            for i in range(len(content)):
+                if content[i] in self.node_list[pre_node_index].pre_nearest_children_id.keys():
+                    nex_node_index = self.node_list[pre_node_index].pre_nearest_children_id[content[i]]
+                else:
+                    nex_node_index = 0
+                # Leaving a deeper node means a keyword may have just ended.
+                if self.node_list[pre_node_index].dep >= self.node_list[nex_node_index].dep:
+                    if self.node_list[pre_node_index].data_frame is not None:
+                        results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame))
+                pre_node_index = nex_node_index
+                logging.info(f'当前深度{self.node_list[pre_node_index].dep}')
+            if self.node_list[pre_node_index].data_frame is not None:
+                results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame))
+        except Exception as e:
+            logging.error(f'结果获取失败由于:{e}')
+        return results
diff --git a/chat2db/app/base/mysql.py b/chat2db/app/base/mysql.py
new file mode 100644
index 0000000000000000000000000000000000000000..176ddd908b99ac027851841a6ebc2479cf3b4491
--- /dev/null
+++ b/chat2db/app/base/mysql.py
@@ -0,0 +1,204 @@
+
+import asyncio
+import aiomysql
+import concurrent.futures
+import logging
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy import create_engine, text
+from concurrent.futures import ThreadPoolExecutor
+from urllib.parse import urlparse
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+
+def _make_engine(database_url):
+    """Create a SQLAlchemy engine with the pool settings shared by all helpers."""
+    return create_engine(
+        database_url,
+        pool_size=20,
+        max_overflow=80,
+        pool_recycle=300,
+        pool_pre_ping=True
+    )
+
+
+class Mysql():
+    """MySQL helpers: connectivity check, metadata lookup and data sampling."""
+    executor = ThreadPoolExecutor(max_workers=10)
+
+    @staticmethod
+    async def test_database_connection(database_url):
+        """Return True when a test query succeeds within the timeout."""
+        try:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+                future = executor.submit(Mysql._connect_and_query, database_url)
+                # NOTE(review): 0.1s is a tight budget for engine creation plus
+                # a network round trip; confirm it is intentional.
+                result = future.result(timeout=0.1)
+                return result
+        except concurrent.futures.TimeoutError:
+            logging.error('mysql数据库连接超时')
+            return False
+        except Exception as e:
+            logging.error(f'mysql数据库连接失败由于{e}')
+            return False
+
+    @staticmethod
+    def _connect_and_query(database_url):
+        """Run SELECT 1 against the database; raise on any failure."""
+        try:
+            engine = _make_engine(database_url)
+            session = sessionmaker(bind=engine)()
+            session.execute(text("SELECT 1"))
+            session.close()
+            return True
+        except Exception as e:
+            raise e
+
+    @staticmethod
+    async def drop_table(database_url, table_name):
+        """Drop table_name if it exists; table_name must come from a trusted source."""
+        engine = _make_engine(database_url)
+        with sessionmaker(engine)() as session:
+            sql_str = f"DROP TABLE IF EXISTS {table_name};"
+            session.execute(text(sql_str))
+
+    @staticmethod
+    async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword):
+        """Group primary-key values of each row by the keyword column value.
+
+        Returns {'primary_key_list': [...], 'keyword_value_dict': {...}},
+        [] when the table has no primary key, or None on error.
+        """
+        try:
+            url = urlparse(database_url)
+            db_config = {
+                'host': url.hostname or 'localhost',
+                'port': int(url.port or 3306),
+                'user': url.username or 'root',
+                'password': url.password or '',
+                'db': url.path.strip('/')
+            }
+
+            async with aiomysql.create_pool(**db_config) as pool:
+                async with pool.acquire() as conn:
+                    async with conn.cursor() as cur:
+                        primary_key_query = """
+                        SELECT
+                            COLUMNS.column_name
+                        FROM
+                            information_schema.tables AS TABLES
+                            INNER JOIN information_schema.columns AS COLUMNS ON TABLES.table_name = COLUMNS.table_name
+                        WHERE
+                            TABLES.table_schema = %s AND TABLES.table_name = %s AND COLUMNS.column_key = 'PRI';
+                        """
+
+                        # run the primary-key lookup (parameterized)
+                        await cur.execute(primary_key_query, (db_config['db'], table_name))
+                        primary_key_list = await cur.fetchall()
+                        if not primary_key_list:
+                            return []
+                        primary_key_names = ', '.join([record[0] for record in primary_key_list])
+                        columns = f'{primary_key_names}, {keyword}'
+                        query = f'SELECT {columns} FROM {table_name};'
+                        await cur.execute(query)
+                        results = await cur.fetchall()
+
+            def _process_results(results, primary_key_list):
+                # Group primary-key tuples by the keyword value (last column).
+                tmp_dict = {}
+                for row in results:
+                    key = str(row[-1])
+                    if key not in tmp_dict:
+                        tmp_dict[key] = []
+                    pk_values = [str(row[i]) for i in range(len(primary_key_list))]
+                    tmp_dict[key].append(pk_values)
+
+                return {
+                    'primary_key_list': [record[0] for record in primary_key_list],
+                    'keyword_value_dict': tmp_dict
+                }
+            # Offload the pure-Python grouping so the event loop stays free.
+            result = await asyncio.get_event_loop().run_in_executor(
+                Mysql.executor,
+                _process_results,
+                results,
+                primary_key_list
+            )
+            return result
+
+        except Exception as e:
+            logging.error(f'mysql数据检索失败由于 {e}')
+
+    @staticmethod
+    async def assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list):
+        """Build "SELECT * FROM t where pk='v' and ...;" from parallel lists."""
+        sql_str = f'SELECT * FROM {table_name} where '
+        for i in range(len(primary_key_list)):
+            sql_str += primary_key_list[i]+'= \''+primary_key_value_list[i]+'\''
+            if i != len(primary_key_list)-1:
+                sql_str += ' and '
+        sql_str += ';'
+        return sql_str
+
+    @staticmethod
+    async def get_table_info(database_url, table_name):
+        """Return {'table_name', 'table_note'}; empty comments fall back to the name."""
+        engine = _make_engine(database_url)
+        with sessionmaker(engine)() as session:
+            sql_str = f"""SELECT TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table_name}';"""
+            table_note = session.execute(text(sql_str)).one()[0]
+            if table_note == '':
+                table_note = table_name
+            table_note = {
+                'table_name': table_name,
+                'table_note': table_note
+            }
+            return table_note
+
+    @staticmethod
+    async def get_column_info(database_url, table_name):
+        """Return name/type/comment dicts for every column of the table."""
+        engine = _make_engine(database_url)
+        with engine.connect() as conn:
+            sql_str = f"""
+            SELECT column_name, column_type, column_comment FROM information_schema.columns where TABLE_NAME='{table_name}';
+            """
+            results = conn.execute(text(sql_str), {'table_name': table_name}).all()
+            column_info_list = []
+            for result in results:
+                column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]})
+        return column_info_list
+
+    @staticmethod
+    async def get_all_table_name_from_database_url(database_url):
+        """List every table name in the connected database."""
+        engine = _make_engine(database_url)
+        with engine.connect() as connection:
+            result = connection.execute(text("SHOW TABLES"))
+            table_name_list = [row[0] for row in result]
+        return table_name_list
+
+    @staticmethod
+    async def get_rand_data(database_url, table_name, cnt=10):
+        """Return up to cnt random rows rendered as a string; '' on failure."""
+        engine = _make_engine(database_url)
+        try:
+            with sessionmaker(engine)() as session:
+                sql_str = f'''SELECT *
+                FROM {table_name}
+                ORDER BY RAND()
+                LIMIT {cnt};'''
+                dataframe = str(session.execute(text(sql_str)).all())
+        except Exception as e:
+            dataframe = ''
+            logging.error(f'随机从数据库中获取数据失败由于{e}')
+        return dataframe
+
+    @staticmethod
+    async def try_excute(database_url, sql_str):
+        """Execute an arbitrary SQL statement and return all result rows."""
+        engine = _make_engine(database_url)
+        with sessionmaker(engine)() as session:
+            result = session.execute(text(sql_str)).all()
+        return result
diff --git a/chat2db/app/base/postgres.py b/chat2db/app/base/postgres.py
new file mode 100644
index 0000000000000000000000000000000000000000..11dc22f86fa1752cca991f0db29837c275eb76a2
--- /dev/null
+++ b/chat2db/app/base/postgres.py
@@ -0,0 +1,226 @@
+import asyncio
+import asyncpg
+import concurrent.futures
+import logging
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy import create_engine, text
+from concurrent.futures import ThreadPoolExecutor
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+
+def handler(signum, frame):
+    # NOTE(review): not referenced in this module; kept for compatibility.
+    raise TimeoutError("超时")
+
+
+def _make_engine(database_url):
+    """Create a SQLAlchemy engine with the pool settings shared by all helpers."""
+    return create_engine(
+        database_url,
+        pool_size=20,
+        max_overflow=80,
+        pool_recycle=300,
+        pool_pre_ping=True
+    )
+
+
+class Postgres():
+    """PostgreSQL helpers: connectivity check, metadata lookup and sampling."""
+    executor = ThreadPoolExecutor(max_workers=10)
+
+    @staticmethod
+    async def test_database_connection(database_url):
+        """Return True when a test query succeeds within the timeout."""
+        try:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+                future = executor.submit(Postgres._connect_and_query, database_url)
+                # NOTE(review): 0.1s is a tight budget for engine creation plus
+                # a network round trip; confirm it is intentional.
+                result = future.result(timeout=0.1)
+                return result
+        except concurrent.futures.TimeoutError:
+            logging.error('postgres数据库连接超时')
+            return False
+        except Exception as e:
+            logging.error(f'postgres数据库连接失败由于{e}')
+            return False
+
+    @staticmethod
+    def _connect_and_query(database_url):
+        """Run SELECT 1 against the database; raise on any failure."""
+        try:
+            engine = _make_engine(database_url)
+            session = sessionmaker(bind=engine)()
+            session.execute(text("SELECT 1"))
+            session.close()
+            return True
+        except Exception as e:
+            raise e
+
+    @staticmethod
+    async def drop_table(database_url, table_name):
+        """Drop table_name if it exists; table_name must come from a trusted source."""
+        engine = _make_engine(database_url)
+        with sessionmaker(engine)() as session:
+            sql_str = f"DROP TABLE IF EXISTS {table_name};"
+            session.execute(text(sql_str))
+
+    @staticmethod
+    async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword):
+        """Group primary-key values of each row by the keyword column value.
+
+        Returns {'primary_key_list': [...], 'keyword_value_dict': {...}},
+        [] when the table has no primary key, or None on error.
+        """
+        try:
+            # asyncpg does not understand SQLAlchemy's '+psycopg2' driver suffix.
+            dsn = database_url.replace('+psycopg2', '')
+            conn = await asyncpg.connect(dsn=dsn)
+            primary_key_query = """
+                SELECT
+                    kcu.column_name
+                FROM
+                    information_schema.table_constraints AS tc
+                JOIN information_schema.key_column_usage AS kcu
+                    ON tc.constraint_name = kcu.constraint_name
+                WHERE
+                    tc.constraint_type = 'PRIMARY KEY'
+                    AND tc.table_name = $1;
+            """
+            primary_key_list = await conn.fetch(primary_key_query, table_name)
+            if not primary_key_list:
+                # Fix: close the connection on this early-return path too.
+                await conn.close()
+                return []
+            columns = ', '.join([record['column_name'] for record in primary_key_list]) + f', {keyword}'
+            query = f'SELECT {columns} FROM {table_name};'
+            results = await conn.fetch(query)
+
+            def _process_results(results, primary_key_list):
+                # Group primary-key tuples by the keyword value (last column).
+                tmp_dict = {}
+                for row in results:
+                    key = str(row[-1])
+                    if key not in tmp_dict:
+                        tmp_dict[key] = []
+                    pk_values = [str(row[i]) for i in range(len(primary_key_list))]
+                    tmp_dict[key].append(pk_values)
+
+                return {
+                    'primary_key_list': [record['column_name'] for record in primary_key_list],
+                    'keyword_value_dict': tmp_dict
+                }
+            # Offload the pure-Python grouping so the event loop stays free.
+            result = await asyncio.get_event_loop().run_in_executor(
+                Postgres.executor,
+                _process_results,
+                results,
+                primary_key_list
+            )
+            await conn.close()
+
+            return result
+        except Exception as e:
+            logging.error(f'postgres数据检索失败由于 {e}')
+            return None
+
+    @staticmethod
+    async def assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list):
+        """Build "SELECT * FROM t where pk='v' and ...;" from parallel lists."""
+        sql_str = f'SELECT * FROM {table_name} where '
+        for i in range(len(primary_key_list)):
+            sql_str += primary_key_list[i]+'='+'\''+primary_key_value_list[i]+'\''
+            if i != len(primary_key_list)-1:
+                sql_str += ' and '
+        sql_str += ';'
+        return sql_str
+
+    @staticmethod
+    async def get_table_info(database_url, table_name):
+        """Return {'table_name', 'table_note'}; empty comments fall back to the name."""
+        engine = _make_engine(database_url)
+        with engine.connect() as conn:
+            sql_str = """
+                    SELECT
+                        d.description AS table_description
+                    FROM
+                        pg_class t
+                    JOIN
+                        pg_description d ON t.oid = d.objoid
+                    WHERE
+                        t.relkind = 'r' AND
+                        d.objsubid = 0 AND
+                        t.relname = :table_name; """
+            table_note = conn.execute(text(sql_str), {'table_name': table_name}).one()[0]
+            if table_note == '':
+                table_note = table_name
+            table_note = {
+                'table_name': table_name,
+                'table_note': table_note
+            }
+            return table_note
+
+    @staticmethod
+    async def get_column_info(database_url, table_name):
+        """Return name/type/comment dicts for every column of the table."""
+        engine = _make_engine(database_url)
+        with engine.connect() as conn:
+            sql_str = """
+            SELECT
+            a.attname as 字段名,
+            format_type(a.atttypid,a.atttypmod) as 类型,
+            col_description(a.attrelid,a.attnum) as 注释
+            FROM
+            pg_class as c,pg_attribute as a
+            where
+            a.attrelid = c.oid
+            and
+            a.attnum>0
+            and
+            c.relname = :table_name;
+            """
+            results = conn.execute(text(sql_str), {'table_name': table_name}).all()
+            column_info_list = []
+            for result in results:
+                column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]})
+        return column_info_list
+
+    @staticmethod
+    async def get_all_table_name_from_database_url(database_url):
+        """List every table name in the public schema."""
+        engine = _make_engine(database_url)
+        with engine.connect() as connection:
+            sql_str = '''
+            SELECT table_name
+            FROM information_schema.tables
+            WHERE table_schema = 'public';
+            '''
+            result = connection.execute(text(sql_str))
+            table_name_list = [row[0] for row in result]
+        return table_name_list
+
+    @staticmethod
+    async def get_rand_data(database_url, table_name, cnt=10):
+        """Return up to cnt random rows rendered as a string; '' on failure."""
+        engine = _make_engine(database_url)
+        try:
+            with sessionmaker(engine)() as session:
+                sql_str = f'''SELECT *
+                FROM {table_name}
+                ORDER BY RANDOM()
+                LIMIT {cnt};'''
+                dataframe = str(session.execute(text(sql_str)).all())
+        except Exception as e:
+            dataframe = ''
+            logging.error(f'随机从数据库中获取数据失败由于{e}')
+        return dataframe
+
+    @staticmethod
+    async def try_excute(database_url, sql_str):
+        """Execute an arbitrary SQL statement and return all result rows."""
+        engine = _make_engine(database_url)
+        with sessionmaker(engine)() as session:
+            result = session.execute(text(sql_str)).all()
+        return result
diff --git a/chat2db/app/base/vectorize.py b/chat2db/app/base/vectorize.py
new file mode 100644
index 0000000000000000000000000000000000000000..53fe4af3fe090d2ccae66b6ea61121fa4931cebc
--- /dev/null
+++ b/chat2db/app/base/vectorize.py
@@ -0,0 +1,52 @@
+import requests
+import urllib3
+from chat2db.config.config import config
+import json
+import logging
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+# TLS verification is disabled below (verify=False); silence the warning spam.
+urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+
+class Vectorize():
+    """Text-embedding client supporting openai-compatible and mindie backends."""
+
+    @staticmethod
+    async def vectorize_embedding(text):
+        """Return the embedding vector for `text`, or None on any failure."""
+        if config['EMBEDDING_TYPE'] == 'openai':
+            headers = {
+                "Authorization": f"Bearer {config['EMBEDDING_API_KEY']}"
+            }
+            data = {
+                "input": text,
+                "model": config["EMBEDDING_MODEL_NAME"],
+                "encoding_format": "float"
+            }
+            try:
+                # timeout prevents an unresponsive endpoint from hanging the caller
+                res = requests.post(url=config["EMBEDDING_ENDPOINT"], headers=headers, json=data,
+                                    verify=False, timeout=30)
+                if res.status_code != 200:
+                    return None
+                return res.json()['data'][0]['embedding']
+            except Exception as e:
+                logging.error(f"Embedding error failed due to: {e}")
+                return None
+        elif config['EMBEDDING_TYPE'] == 'mindie':
+            try:
+                data = {
+                    "inputs": text,
+                }
+                res = requests.post(url=config["EMBEDDING_ENDPOINT"], json=data, verify=False, timeout=30)
+                if res.status_code != 200:
+                    return None
+                return json.loads(res.text)[0]
+            except Exception as e:
+                logging.error(f"Embedding error failed due to: {e}")
+                return None
+        else:
+            return None
diff --git a/chat2db/app/router/database.py b/chat2db/app/router/database.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd9a3f735bbf5fa227970406c717f1a0dc121edd
--- /dev/null
+++ b/chat2db/app/router/database.py
@@ -0,0 +1,118 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+import logging
+import uuid
+from fastapi import APIRouter, status
+
+from chat2db.model.request import DatabaseAddRequest, DatabaseDelRequest
+from chat2db.model.response import ResponseData
+from chat2db.manager.database_info_manager import DatabaseInfoManager
+from chat2db.app.service.diff_database_service import DiffDatabaseService
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+router = APIRouter(
+    prefix="/database"
+)
+
+
+def _detect_database_type(database_url):
+    """Return 'mysql' when the url mentions mysql, otherwise 'postgres'."""
+    if 'mysql' in database_url:
+        return 'mysql'
+    return 'postgres'
+
+
+@router.post("/add", response_model=ResponseData)
+async def add_database_info(request: DatabaseAddRequest):
+    """Register a database connection after validating type and connectivity."""
+    database_url = request.database_url
+    if 'mysql' not in database_url and 'postgres' not in database_url:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="不支持当前数据库",
+            result={}
+        )
+    database_type = _detect_database_type(database_url)
+    flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url)
+    if not flag:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="无法连接当前数据库",
+            result={}
+        )
+    database_id = await DatabaseInfoManager.add_database(database_url)
+    if database_id is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="数据库连接添加失败,当前存在重复数据库配置",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="success",
+        result={'database_id': database_id}
+    )
+
+
+@router.post("/del", response_model=ResponseData)
+async def del_database_info(request: DatabaseDelRequest):
+    """Delete a stored database configuration by id."""
+    database_id = request.database_id
+    flag = await DatabaseInfoManager.del_database_by_id(database_id)
+    if not flag:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="删除数据库配置失败,数据库配置不存在",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="删除数据库配置成功",
+        result={}
+    )
+
+
+@router.get("/query", response_model=ResponseData)
+async def query_database_info():
+    """List every stored database configuration."""
+    database_info_list = await DatabaseInfoManager.get_all_database_info()
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="查询数据库配置成功",
+        result={'database_info_list': database_info_list}
+    )
+
+
+@router.get("/list", response_model=ResponseData)
+async def list_table_in_database(database_id: uuid.UUID, table_filter: str = ''):
+    """List tables in a database, keeping only names containing table_filter."""
+    database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+    if database_url is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="查询数据库内表格配置失败,数据库配置不存在",
+            result={}
+        )
+    if 'mysql' not in database_url and 'postgres' not in database_url:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="不支持当前数据库",
+            result={}
+        )
+    database_type = _detect_database_type(database_url)
+    flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url)
+    if not flag:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="无法连接当前数据库",
+            result={}
+        )
+    table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url)
+    results = [table_name for table_name in table_name_list if table_filter in table_name]
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="查询数据库配置成功",
+        result={'table_name_list': results}
+    )
diff --git a/chat2db/app/router/sql_example.py b/chat2db/app/router/sql_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b794e25e53459c90d6d666756a7d7066adc6124
--- /dev/null
+++ b/chat2db/app/router/sql_example.py
@@ -0,0 +1,143 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+import logging
+import uuid
+from fastapi import APIRouter, status
+
+from chat2db.model.request import SqlExampleAddRequest, SqlExampleDelRequest, SqlExampleUpdateRequest, SqlExampleGenerateRequest
+from chat2db.model.response import ResponseData
+from chat2db.manager.database_info_manager import DatabaseInfoManager
+from chat2db.manager.table_info_manager import TableInfoManager
+from chat2db.manager.sql_example_manager import SqlExampleManager
+from chat2db.app.service.sql_generate_service import SqlGenerateService
+from chat2db.app.base.vectorize import Vectorize
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+router = APIRouter(
+    prefix="/sql/example"
+)
+
+
+@router.post("/add", response_model=ResponseData)
+async def add_sql_example(request: SqlExampleAddRequest):
+    """Store one question/SQL example (with question embedding) for a table."""
+    table_id = request.table_id
+    table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
+    if table_info is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="表格不存在",
+            result={}
+        )
+    question = request.question
+    # May be None when the embedding service is unavailable.
+    question_vector = await Vectorize.vectorize_embedding(question)
+    sql = request.sql
+    try:
+        sql_example_id = await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector)
+    except Exception as e:
+        logging.error(f'sql案例添加失败由于{e}')
+        return ResponseData(
+            code=status.HTTP_400_BAD_REQUEST,
+            message="sql案例添加失败",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="success",
+        result={'sql_example_id': sql_example_id}
+    )
+
+
+@router.post("/del", response_model=ResponseData)
+async def del_sql_example(request: SqlExampleDelRequest):
+    """Delete a stored SQL example by id."""
+    sql_example_id = request.sql_example_id
+    flag = await SqlExampleManager.del_sql_example_by_id(sql_example_id)
+    if not flag:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="sql案例不存在",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="sql案例删除成功",
+        result={}
+    )
+
+
+@router.get("/query", response_model=ResponseData)
+async def query_sql_example(table_id: uuid.UUID):
+    """List every SQL example attached to the given table."""
+    sql_example_list = await SqlExampleManager.query_sql_example_by_table_id(table_id)
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="查询sql案例成功",
+        result={'sql_example_list': sql_example_list}
+    )
+
+
+# NOTE(review): route path keeps the historical 'udpate' typo so existing
+# clients do not break; fix it together with the callers.
+@router.post("/udpate", response_model=ResponseData)
+async def update_sql_example(request: SqlExampleUpdateRequest):
+    """Replace question/SQL (and embedding) of an existing example."""
+    sql_example_id = request.sql_example_id
+    question = request.question
+    question_vector = await Vectorize.vectorize_embedding(question)
+    sql = request.sql
+    flag = await SqlExampleManager.update_sql_example_by_id(sql_example_id, question, sql, question_vector)
+    if not flag:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="sql案例不存在",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="sql案例更新成功",
+        result={}
+    )
+
+
+@router.post("/generate", response_model=ResponseData)
+async def generate_sql_example(request: SqlExampleGenerateRequest):
+    """Auto-generate up to generate_cnt question/SQL pairs and store them."""
+    table_id = request.table_id
+    generate_cnt = request.generate_cnt
+    table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
+    if table_info is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="表格不存在",
+            result={}
+        )
+    table_name = table_info['table_name']
+    database_id = table_info['database_id']
+    database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+    sql_var = request.sql_var
+    sql_example_list = []
+    for _ in range(generate_cnt):
+        try:
+            tmp_dict = await SqlGenerateService.generate_sql_base_on_data(database_url, table_name, sql_var)
+        except Exception as e:
+            logging.error(f'sql案例生成失败由于{e}')
+            continue
+        if tmp_dict is None:
+            continue
+        question = tmp_dict['question']
+        question_vector = await Vectorize.vectorize_embedding(question)
+        sql = tmp_dict['sql']
+        await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector)
+        tmp_dict['database_id'] = database_id
+        tmp_dict['table_id'] = table_id
+        sql_example_list.append(tmp_dict)
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="sql案例生成成功",
+        result={
+            'sql_example_list': sql_example_list
+        }
+    )
diff --git a/chat2db/app/router/sql_generate.py b/chat2db/app/router/sql_generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a09e52de54c108eb2f73aa2b418f7744f256fb3
--- /dev/null
+++ b/chat2db/app/router/sql_generate.py
@@ -0,0 +1,135 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+import logging
+from fastapi import APIRouter, status
+
+from chat2db.manager.database_info_manager import DatabaseInfoManager
+from chat2db.manager.table_info_manager import TableInfoManager
+from chat2db.manager.column_info_manager import ColumnInfoManager
+from chat2db.model.request import SqlGenerateRequest, SqlRepairRequest, SqlExcuteRequest
+from chat2db.model.response import ResponseData
+from chat2db.app.service.sql_generate_service import SqlGenerateService
+from chat2db.app.service.keyword_service import keyword_service
+from chat2db.app.service.diff_database_service import DiffDatabaseService
+logging.basicConfig(filename='app.log', level=logging.INFO,
+                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+router = APIRouter(
+    prefix="/sql"
+)
+
+
+@router.post("/generate", response_model=ResponseData)
+async def generate_sql(request: SqlGenerateRequest):
+    """Generate candidate SQL for a question from stored examples and keywords."""
+    database_id = request.database_id
+    database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+    table_id_list = request.table_id_list
+    question = request.question
+    use_llm_enhancements = request.use_llm_enhancements
+    # An empty list means "no restriction": search across all tables.
+    if table_id_list == []:
+        table_id_list = None
+    results = {}
+    sql_list = await SqlGenerateService.generate_sql_base_on_exmpale(
+        database_id=database_id, question=question, table_id_list=table_id_list,
+        use_llm_enhancements=use_llm_enhancements)
+    try:
+        sql_list += await keyword_service.generate_sql(question, database_id, table_id_list)
+        results['sql_list'] = sql_list[:request.topk]
+        results['database_url'] = database_url
+    except Exception as e:
+        logging.error(f'sql生成失败由于{e}')
+        return ResponseData(
+            code=status.HTTP_400_BAD_REQUEST,
+            message="sql生成失败",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK, message="success",
+        result=results
+    )
+
+
+@router.post("/repair", response_model=ResponseData)
+async def repair_sql(request: SqlRepairRequest):
+    """Repair a failing SQL statement using table/column metadata and the error."""
+    database_id = request.database_id
+    table_id = request.table_id
+    database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+    if database_url is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="当前数据库配置不存在",
+            result={}
+        )
+    database_type = 'postgres'
+    if 'mysql' in database_url:
+        database_type = 'mysql'
+    table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
+    if table_info is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="表格不存在",
+            result={}
+        )
+    if table_info['database_id'] != database_id:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="表格不属于当前数据库",
+            result={}
+        )
+    column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
+    sql = request.sql
+    message = request.message
+    question = request.question
+    try:
+        sql = await SqlGenerateService.repair_sql(database_type, table_info, column_info_list, sql, message, question)
+    except Exception as e:
+        logging.error(f'sql修复失败由于{e}')
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="sql修复失败",
+            result={}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="sql修复成功",
+        result={'database_id': database_id,
+                'table_id': table_id,
+                'sql': sql}
+    )
+
+
+@router.post("/execute", response_model=ResponseData)
+async def execute_sql(request: SqlExcuteRequest):
+    """Execute caller-provided SQL against a stored database and return rows."""
+    database_id = request.database_id
+    sql = request.sql
+    database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+    if database_url is None:
+        return ResponseData(
+            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            message="当前数据库配置不存在",
+            result={}
+        )
+    database_type = 'postgres'
+    if 'mysql' in database_url:
+        database_type = 'mysql'
+    # NOTE(review): this runs caller-provided SQL verbatim; expose the endpoint
+    # to trusted users only.
+    try:
+        results = await DiffDatabaseService.database_map[database_type].try_excute(database_url, sql)
+        results = str(results)
+    except Exception as e:
+        logging.error(f'sql执行失败由于{e}')
+        return ResponseData(
+            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            message="sql执行失败",
+            result={'Error': str(e)}
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="sql执行成功",
+        result={'results': results}
+    )
diff --git a/chat2db/app/router/table.py b/chat2db/app/router/table.py
new file mode 100644
index 0000000000000000000000000000000000000000..85ad02b143913ca75bddd0a677bb9792663ec582
--- /dev/null
+++ b/chat2db/app/router/table.py
@@ -0,0 +1,148 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+import logging
+import uuid
+from fastapi import APIRouter, status
+
+from chat2db.model.request import TableAddRequest, TableDelRequest, EnableColumnRequest
+from chat2db.model.response import ResponseData
+from chat2db.manager.database_info_manager import DatabaseInfoManager
+from chat2db.manager.table_info_manager import TableInfoManager
+from chat2db.manager.column_info_manager import ColumnInfoManager
+from chat2db.app.service.diff_database_service import DiffDatabaseService
+from chat2db.app.base.vectorize import Vectorize
+from chat2db.app.service.keyword_service import keyword_service
# NOTE(review): basicConfig at import time only takes effect for the first
# module that calls it; several modules in this package repeat the same call,
# so the duplicates are no-ops — confirm a single shared logging setup is intended.
logging.basicConfig(filename='app.log', level=logging.INFO,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')

# All endpoints in this module are mounted under the /table prefix.
router = APIRouter(
    prefix="/table"
)
+
+
@router.post("/add", response_model=ResponseData)
async def add_database_info(request: TableAddRequest):
    """Register one table of an already-configured database.

    Validates the database config and connectivity, checks the table exists
    in the live database, stores the table note (plus its embedding) and the
    metadata of every column. Responds 422 on any validation failure,
    200 with the new table_id on success.
    """
    database_url = await DatabaseInfoManager.get_database_url_by_id(request.database_id)
    if database_url is None:
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="当前数据库配置不存在",
            result={}
        )
    database_type = 'mysql' if 'mysql' in database_url else 'postgres'
    # One service object drives every dialect-specific call below.
    service = DiffDatabaseService.get_database_service(database_type)
    if not await service.test_database_connection(database_url):
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="无法连接当前数据库",
            result={}
        )
    table_name = request.table_name
    if table_name not in await service.get_all_table_name_from_database_url(database_url):
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="表格不存在",
            result={}
        )
    table_detail = await service.get_table_info(database_url, table_name)
    table_note = table_detail['table_note']
    # The embedding is stored alongside the note for similarity search later.
    table_note_vector = await Vectorize.vectorize_embedding(table_note)
    table_id = await TableInfoManager.add_table_info(
        request.database_id, table_name, table_note, table_note_vector)
    if table_id is None:
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="表格添加失败,当前存在重复表格",
            result={}
        )
    for column_info in await service.get_column_info(database_url, table_name):
        await ColumnInfoManager.add_column_info_with_table_id(
            table_id, column_info['column_name'],
            column_info['column_type'],
            column_info['column_note'])
    return ResponseData(
        code=status.HTTP_200_OK,
        message="success",
        result={'table_id': table_id}
    )
+
+
@router.post("/del", response_model=ResponseData)
async def del_table_info(request: TableDelRequest):
    """Delete a registered table by id; 422 when the id is unknown."""
    deleted = await TableInfoManager.del_table_by_id(request.table_id)
    if not deleted:
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="表格不存在",
            result={}
        )
    return ResponseData(
        code=status.HTTP_200_OK,
        message="删除表格成功",
        result={}
    )
+
+
@router.get("/query", response_model=ResponseData)
async def query_table_info(database_id: uuid.UUID):
    """List every table registered under the given database id."""
    if await DatabaseInfoManager.get_database_url_by_id(database_id) is None:
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="当前数据库配置不存在",
            result={}
        )
    tables = await TableInfoManager.get_table_info_by_database_id(database_id)
    return ResponseData(
        code=status.HTTP_200_OK,
        message="查询表格成功",
        result={'table_info_list': tables}
    )
+
+
@router.get("/column/query", response_model=ResponseData)
async def query_column(table_id: uuid.UUID):
    """Return the column metadata recorded for one table."""
    columns = await ColumnInfoManager.get_column_info_by_table_id(table_id)
    return ResponseData(
        code=status.HTTP_200_OK,
        message="",
        result={'column_info_list': columns}
    )
+
+
@router.post("/column/enable", response_model=ResponseData)
async def enable_column(request: EnableColumnRequest):
    """Toggle keyword matching for a single column.

    Enabling registers the column's values with the keyword service;
    disabling removes them from it.
    """
    updated = await ColumnInfoManager.update_column_info_enable(request.column_id, request.enable)
    if not updated:
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="列不存在",
            result={}
        )
    column_info = await ColumnInfoManager.get_column_info_by_column_id(request.column_id)
    column_name = column_info['column_name']
    table_id = column_info['table_id']
    table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
    database_id = table_info['database_id']
    if request.enable:
        ok = await keyword_service.add(database_id, table_id, column_name)
    else:
        ok = await keyword_service.del_by_column_name(database_id, table_id, column_name)
    if not ok:
        return ResponseData(
            code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            message="列关键字功能开启/关闭失败",
            result={}
        )
    return ResponseData(
        code=status.HTTP_200_OK,
        message="列关键字功能开启/关闭成功",
        result={}
    )
diff --git a/chat2db/app/service/diff_database_service.py b/chat2db/app/service/diff_database_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fe77b0b2eac045d3645f6d143ed1619060127e3
--- /dev/null
+++ b/chat2db/app/service/diff_database_service.py
@@ -0,0 +1,10 @@
+from chat2db.app.base.mysql import Mysql
+from chat2db.app.base.postgres import Postgres
+
+
class DiffDatabaseService():
    """Dispatch table mapping a database-type string to its service class."""

    # Public: router code also indexes this map directly.
    database_map = {"mysql": Mysql, "postgres": Postgres}

    @classmethod
    def get_database_service(cls, database_type):
        """Return the service class registered for *database_type* (KeyError if unknown)."""
        return cls.database_map[database_type]
diff --git a/chat2db/app/service/keyword_service.py b/chat2db/app/service/keyword_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..335410a8f3ff0155c7eeee7a3b54b77699fe2e68
--- /dev/null
+++ b/chat2db/app/service/keyword_service.py
@@ -0,0 +1,134 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import asyncio
+import copy
+import uuid
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from chat2db.app.service.diff_database_service import DiffDatabaseService
+from chat2db.app.base.ac_automation import DictTree
+from chat2db.manager.database_info_manager import DatabaseInfoManager
+from chat2db.manager.table_info_manager import TableInfoManager
+from chat2db.manager.column_info_manager import ColumnInfoManager
+import logging
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+ format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+
class KeywordManager():
    """In-memory keyword index over selected database columns.

    For every enabled column a dict tree (AC automaton) is built from the
    column's values so that a natural-language question can be matched back
    to primary keys and turned into lookup SQL.
    """

    def __init__(self):
        # database_id -> table_id -> {'table_info', 'column_info_list',
        #                             'primary_key_list', 'dict_tree_dict'}
        self.keyword_asset_dict = {}
        # Guards keyword_asset_dict, which is mutated from worker threads.
        self.lock = threading.Lock()
        # Staging area keyed by a per-call random id; consumed by add_excutor.
        self.data_frame_dict = {}

    async def load_keywords(self):
        """Build dict trees for every enabled column of every known database."""
        database_info_list = await DatabaseInfoManager.get_all_database_info()
        for database_info in database_info_list:
            database_id = database_info['database_id']
            table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
            for table_info in table_info_list:
                table_id = table_info['table_id']
                column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id, True)
                for column_info in column_info_list:
                    try:
                        await self.add(database_id, table_id, column_info['column_name'])
                    except Exception as e:
                        # Include the cause; the original swallowed the detail.
                        logging.error(f'关键字数据结构生成失败由于{e}')

    def add_excutor(self, rd_id, database_id, table_id, table_info, column_info_list, column_name):
        """Worker-thread half of add(): build the dict tree and publish it."""
        tmp_dict = self.data_frame_dict[rd_id]
        tmp_dict_tree = DictTree()
        tmp_dict_tree.load_data(tmp_dict['keyword_value_dict'])
        # BUG FIX: the membership check/creation used to happen outside the
        # lock, racing with concurrent add_excutor threads.
        with self.lock:
            if database_id not in self.keyword_asset_dict.keys():
                self.keyword_asset_dict[database_id] = {}
            if table_id not in self.keyword_asset_dict[database_id].keys():
                self.keyword_asset_dict[database_id][table_id] = {}
            self.keyword_asset_dict[database_id][table_id]['table_info'] = table_info
            self.keyword_asset_dict[database_id][table_id]['column_info_list'] = column_info_list
            self.keyword_asset_dict[database_id][table_id]['primary_key_list'] = copy.deepcopy(
                tmp_dict['primary_key_list'])
            self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'] = {}
            self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name] = tmp_dict_tree
        del self.data_frame_dict[rd_id]

    async def add(self, database_id, table_id, column_name):
        """Index *column_name* of one table; returns True when the build thread started."""
        database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
        database_type = 'mysql' if 'mysql' in database_url else 'postgres'
        table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
        table_name = table_info['table_name']
        tmp_dict = await DiffDatabaseService.get_database_service(
            database_type).select_primary_key_and_keyword_from_table(database_url, table_name, column_name)
        if tmp_dict is None:
            # Was a bare `return` (None); False keeps the contract uniform.
            return False
        # BUG FIX: uuid.uuid4 was not called — str(uuid.uuid4) is the repr of
        # the function object, so every pending add shared one staging key.
        rd_id = str(uuid.uuid4())
        self.data_frame_dict[rd_id] = tmp_dict
        del database_url
        column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
        try:
            thread = threading.Thread(target=self.add_excutor, args=(rd_id, database_id, table_id,
                                                                     table_info, column_info_list, column_name,))
            thread.start()
        except Exception as e:
            logging.error(f'创建增加线程失败由于{e}')
            return False
        return True

    async def update_keyword_asset(self):
        """Rebuild the keyword index from the current database contents."""
        # BUG FIX: these manager methods are coroutines and were not awaited,
        # so the original iterated coroutine objects and never refreshed anything.
        database_info_list = await DatabaseInfoManager.get_all_database_info()
        for database_info in database_info_list:
            database_id = database_info['database_id']
            table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
            for table_info in table_info_list:
                table_id = table_info['table_id']
                column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id, True)
                for column_info in column_info_list:
                    await self.add(database_id, table_id, column_info['column_name'])

    async def del_by_column_name(self, database_id, table_id, column_name):
        """Drop the dict tree of one column; True unless the removal raised."""
        try:
            with self.lock:
                if database_id in self.keyword_asset_dict.keys():
                    if table_id in self.keyword_asset_dict[database_id].keys():
                        if column_name in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].keys():
                            del self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name]
        except Exception as e:
            logging.error(f'字典树删除失败由于{e}')
            return False
        return True

    async def generate_sql(self, question, database_id, table_id_list=None):
        """Match *question* against the index and emit primary-key lookup SQL.

        Returns a list of {'database_id', 'table_id', 'sql'} dicts, one per
        primary-key hit; empty when the database is not indexed.
        """
        with self.lock:
            results = []
            if database_id in self.keyword_asset_dict.keys():
                database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
                database_type = 'mysql' if 'mysql' in database_url else 'postgres'
                for table_id in self.keyword_asset_dict[database_id].keys():
                    if table_id_list is None or table_id in table_id_list:
                        table_info = self.keyword_asset_dict[database_id][table_id]['table_info']
                        primary_key_list = self.keyword_asset_dict[database_id][table_id]['primary_key_list']
                        primary_key_value_list = []
                        try:
                            for dict_tree in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].values():
                                primary_key_value_list += dict_tree.get_results(question)
                        except Exception as e:
                            logging.error(f'从字典树中获取结果失败由于{e}')
                            continue
                        for primary_key_value in primary_key_value_list:
                            sql_str = await DiffDatabaseService.get_database_service(database_type).assemble_sql_query_base_on_primary_key(
                                table_info['table_name'], primary_key_list, primary_key_value)
                            results.append({'database_id': database_id, 'table_id': table_id, 'sql': sql_str})
                del database_url
            return results
+
+
# Module-level singleton shared by the routers.
# NOTE(review): building the index at import time blocks module import on
# database access, and asyncio.run fails if an event loop is already running
# in the importing process — confirm this eager load is intended.
keyword_service = KeywordManager()
asyncio.run(keyword_service.load_keywords())
diff --git a/chat2db/app/service/sql_generate_service.py b/chat2db/app/service/sql_generate_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..e02e430ecf968f7e2d050b9f68c94030d0c83f61
--- /dev/null
+++ b/chat2db/app/service/sql_generate_service.py
@@ -0,0 +1,351 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import yaml
+import re
+import json
+import random
+import uuid
+import logging
+from pandas.core.api import DataFrame as DataFrame
+
+from chat2db.manager.database_info_manager import DatabaseInfoManager
+from chat2db.manager.table_info_manager import TableInfoManager
+from chat2db.manager.column_info_manager import ColumnInfoManager
+from chat2db.manager.sql_example_manager import SqlExampleManager
+from chat2db.app.service.diff_database_service import DiffDatabaseService
+from chat2db.llm.chat_with_model import LLM
+from chat2db.config.config import config
+from chat2db.app.base.vectorize import Vectorize
+
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+ format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+
class SqlGenerateService():
    """LLM-backed sql generation helpers: table selection, few-shot prompting,
    statement extraction and sql repair."""

    @staticmethod
    async def merge_table_and_column_info(table_info, column_info_list):
        """Render a table's name/note and its column metadata into one note string.

        NOTE(review): every markup literal below is corrupted in this copy
        (strings open and immediately break to the next line); the ' | '
        separators suggest HTML table tags were stripped by whatever produced
        this text — restore the literals from version control before editing.
        """
        table_name = table_info.get('table_name', '')
        table_note = table_info.get('table_note', '')
        note = '
\n'
        note += '\n'+'表名 | \n'+'
\n'
        note += '\n'+f'{table_name} | \n'+'
\n'
        note += '\n'+'表的注释 | \n'+'
\n'
        note += '\n'+f'{table_note} | \n'+'
\n'
        note += '\n'+' 字段 | \n字段类型 | \n字段注释 | \n'+'
\n'
        for column_info in column_info_list:
            column_name = column_info.get('column_name', '')
            column_type = column_info.get('column_type', '')
            column_note = column_info.get('column_note', '')
            note += '\n'+f' {column_name} | \n{column_type} | \n{column_note} | \n'+'
\n'
        note += '
'
        return note
+
+ @staticmethod
+ def extract_list_statements(list_string):
+ pattern = r'\[.*?\]'
+ matches = re.findall(pattern, list_string)
+ if len(matches) == 0:
+ return ''
+ tmp = matches[0]
+ tmp.replace('\'', '\"')
+ tmp.replace(',', ',')
+ return tmp
+
    @staticmethod
    async def get_most_similar_table_id_list(database_id, question, table_choose_cnt):
        """Ask the LLM to pick up to *table_choose_cnt* tables relevant to *question*.

        Falls back to (shuffled) arbitrary tables when the LLM returns too few.
        NOTE(review): the markup literals below are corrupted in this copy
        (opening quotes with no content before a line break) — restore from
        version control before editing the prompt construction.
        """
        table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
        # Shuffled so the fallback fill at the bottom varies between calls.
        random.shuffle(table_info_list)
        table_id_set = set()
        for table_info in table_info_list:
            table_id = table_info['table_id']
            table_id_set.add(str(table_id))
        try:
            with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f:
                prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
            prompt = prompt_dict.get('table_choose_prompt', '')
            table_entries = '\n'
            table_entries += '\n'+' 主键 | \n表注释 | \n'+'
\n'
            # Rough character budget for the table section of the prompt.
            token_upper = 2048
            for table_info in table_info_list:
                table_id = table_info['table_id']
                table_note = table_info['table_note']
                if len(table_entries) + len(
                        '\n' + f' {table_id} | \n{table_note} | \n' + '
\n') > token_upper:
                    break
                table_entries += '\n'+f' {table_id} | \n{table_note} | \n'+'
\n'
            table_entries += '
'
            prompt = prompt.format(table_cnt=table_choose_cnt, table_entries=table_entries, question=question)
            logging.info(f'在大模型增强模式下,选择表的prompt构造成功:{prompt}')
        except Exception as e:
            logging.error(f'在大模型增强模式下,选择表的prompt构造失败由于:{e}')
            return []
        try:
            llm = LLM(model_name=config['LLM_MODEL'],
                      openai_api_base=config['LLM_URL'],
                      openai_api_key=config['LLM_KEY'],
                      max_tokens=config['LLM_MAX_TOKENS'],
                      request_timeout=60,
                      temperature=0.5)
        except Exception as e:
            llm = None
            logging.error(f'在大模型增强模式下,选择表的过程中,与大模型建立连接失败由于:{e}')
        table_id_list = []
        if llm is not None:
            # Two attempts: LLM output is not reliably parseable on the first try.
            for i in range(2):
                content = await llm.chat_with_model(prompt, '请输包含选择表主键的列表')
                try:
                    sub_table_id_list = json.loads(SqlGenerateService.extract_list_statements(content))
                except:
                    sub_table_id_list = []
                for j in range(len(sub_table_id_list)):
                    # Keep only ids that really exist, de-duplicated.
                    if sub_table_id_list[j] in table_id_set and uuid.UUID(sub_table_id_list[j]) not in table_id_list:
                        table_id_list.append(uuid.UUID(sub_table_id_list[j]))
        if len(table_id_list) < table_choose_cnt:
            # Pad with arbitrary (shuffled) tables up to the requested count.
            table_choose_cnt -= len(table_id_list)
            for i in range(min(table_choose_cnt, len(table_info_list))):
                table_id = table_info_list[i]['table_id']
                if table_id is not None and table_id not in table_id_list:
                    table_id_list.append(table_id)
        return table_id_list
+
+ @staticmethod
+ async def find_most_similar_sql_example(
+ database_id, table_id_list, question, use_llm_enhancements=False, table_choose_cnt=2, sql_example_choose_cnt=10,
+ topk=5):
+ try:
+ database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
+ except Exception as e:
+ logging.error(f'数据库{database_id}信息获取失败由于{e}')
+ return []
+ database_type = 'psotgres'
+ if 'mysql' in database_url:
+ database_type = 'mysql'
+ del database_url
+ try:
+ question_vector = await Vectorize.vectorize_embedding(question)
+ except Exception as e:
+ logging.error(f'问题向量化失败由于:{e}')
+ return {}
+ sql_example = []
+ data_frame_list = []
+ if table_id_list is None:
+ if use_llm_enhancements:
+ table_id_list = await SqlGenerateService.get_most_similar_table_id_list(database_id, question, table_choose_cnt)
+ else:
+ try:
+ table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
+ table_id_list = []
+ for table_info in table_info_list:
+ table_id_list.append(table_info['table_id'])
+ sql_example_list = await SqlExampleManager.get_topk_sql_example_by_cos_dis(
+ question_vector=question_vector,
+ table_id_list=table_id_list, topk=table_choose_cnt * 2)
+ table_id_list = []
+ for sql_example in sql_example_list:
+ table_id_list.append(sql_example['table_id'])
+ except Exception as e:
+ logging.error(f'非增强模式下,表id获取失败由于:{e}')
+ return []
+ table_id_list = list(set(table_id_list))
+ if len(table_id_list) < table_choose_cnt:
+ try:
+ expand_table_id_list = await TableInfoManager.get_topk_table_by_cos_dis(
+ database_id, question_vector, table_choose_cnt - len(table_id_list))
+ table_id_list += expand_table_id_list
+ except Exception as e:
+ logging.error(f'非增强模式下,表id补充失败由于:{e}')
+ exist_table_id = set()
+ note_list = []
+ for i in range(min(2, len(table_id_list))):
+ table_id = table_id_list[i]
+ if table_id in exist_table_id:
+ continue
+ exist_table_id.add(table_id)
+ try:
+ table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
+ column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
+ except Exception as e:
+ logging.error(f'表{table_id}注释获取失败由于{e}')
+ note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list)
+ note_list.append(note)
+ try:
+ sql_example_list = await SqlExampleManager.get_topk_sql_example_by_cos_dis(
+ question_vector, table_id_list=[table_id], topk=sql_example_choose_cnt)
+ except Exception as e:
+ sql_example_list = []
+ logging.error(f'id为{table_id}的表的最相近的{topk}条sql案例获取失败由于:{e}')
+ question_sql_list = []
+ for i in range(len(sql_example_list)):
+ question_sql_list.append(
+ {'question': sql_example_list[i]['question'],
+ 'sql': sql_example_list[i]['sql']})
+ data_frame_list.append({'table_id': table_id, 'table_info': table_info,
+ 'column_info_list': column_info_list, 'sql_example_list': question_sql_list})
+ return data_frame_list
+
+ @staticmethod
+ async def megre_sql_example(sql_example_list):
+ sql_example = ''
+ for i in range(len(sql_example_list)):
+ sql_example += '问题'+str(i)+':\n'+sql_example_list[i].get('question',
+ '')+'\nsql'+str(i)+':\n'+sql_example_list[i].get('sql', '')+'\n'
+ return sql_example
+
+ @staticmethod
+ async def extract_select_statements(sql_string):
+ pattern = r"(?i)select[^;]*;"
+ matches = re.findall(pattern, sql_string)
+ if len(matches) == 0:
+ return ''
+ sql = matches[0]
+ sql = sql.strip()
+ sql.replace(',', ',')
+ return sql
+
    @staticmethod
    async def generate_sql_base_on_exmpale(
            database_id, question, table_id_list=None, sql_generate_cnt=1, use_llm_enhancements=False):
        """Generate candidate sql for *question* from few-shot examples.

        NOTE(review): this copy of the method is corrupted — see the inline
        notes below; restore the original text from version control before
        changing any logic here.
        """
        try:
            database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
        except Exception as e:
            logging.error(f'数据库{database_id}信息获取失败由于{e}')
            return {}
        if database_url is None:
            raise Exception('数据库配置不存在')
        # NOTE(review): 'psotgres' is a typo for 'postgres'; database_type is
        # later used to index DiffDatabaseService.database_map, which would
        # raise KeyError for any non-mysql url — fix once the file is restored.
        database_type = 'psotgres'
        if 'mysql' in database_url:
            database_type = 'mysql'
        data_frame_list = await SqlGenerateService.find_most_similar_sql_example(database_id, table_id_list, question, use_llm_enhancements)
        logging.info(f'问题{question}关联到的表信息如下{str(data_frame_list)}')
        try:
            with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f:
                prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
            llm = LLM(model_name=config['LLM_MODEL'],
                      openai_api_base=config['LLM_URL'],
                      openai_api_key=config['LLM_KEY'],
                      max_tokens=config['LLM_MAX_TOKENS'],
                      request_timeout=60,
                      temperature=0.5)
            results = []
            for data_frame in data_frame_list:
                prompt = prompt_dict.get('sql_generate_base_on_example_prompt', '')
                table_info = data_frame.get('table_info', '')
                table_id = table_info['table_id']
                column_info_list = data_frame.get('column_info_list', '')
                note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list)
                sql_example = await SqlGenerateService.megre_sql_example(data_frame.get('sql_example_list', []))
                try:
                    prompt = prompt.format(
                        database_url=database_url, note=note, k=len(data_frame.get('sql_example_list', [])),
                        sql_example=sql_example, question=question)
                except Exception as e:
                    logging.info(f'sql生成失败{e}')
                    return []
                ge_cnt=0
                ge_sql_cnt=0
                # NOTE(review): the source is truncated/fused from here on —
                # the while-condition below is syntactically broken, and the
                # names count_char and sql_var are never defined in this view;
                # lines from at least one other method (a question/sql
                # generator using 'sql_generate_base_on_data_prompt') appear
                # to have been spliced in. Do not trust anything below.
                while ge_cnt<10*sql_generate_cnt and ge_sql_cnt 1 or count_char(question, '?') > 1:
                    continue
            except Exception as e:
                logging.error(f'问题生成失败由于{e}')
                continue
            try:
                with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f:
                    prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
                prompt = prompt_dict['sql_generate_base_on_data_prompt'].format(
                    database_type=database_type,
                    note=note, data_frame=data_frame, question=question)
                sql = await llm.chat_with_model(prompt, f'请输出一条可以用于查询{database_type}的sql,要以分号结尾')
                sql = await SqlGenerateService.extract_select_statements(sql)
                if not sql:
                    continue
            except Exception as e:
                logging.error(f'sql生成失败由于{e}')
                continue
            try:
                if sql_var:
                    await DiffDatabaseService.get_database_service(database_type).try_excute(database_url,sql)
            except Exception as e:
                logging.error(f'生成的sql执行失败由于{e}')
                continue
            return {
                'question': question,
                'sql': sql
            }
        return None
+
+ @staticmethod
+ async def repair_sql(database_type, table_info, column_info_list, sql_failed, sql_failed_message, question):
+ try:
+ with open('./chat2db/docs/prompt.yaml', 'r', encoding='utf-8') as f:
+ prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
+ llm = LLM(model_name=config['LLM_MODEL'],
+ openai_api_base=config['LLM_URL'],
+ openai_api_key=config['LLM_KEY'],
+ max_tokens=config['LLM_MAX_TOKENS'],
+ request_timeout=60,
+ temperature=0.5)
+ try:
+ note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list)
+ prompt = prompt_dict.get('sql_expand_prompt', '')
+ prompt = prompt.format(
+ database_type=database_type, note=note, sql_failed=sql_failed,
+ sql_failed_message=sql_failed_message,
+ question=question)
+ except Exception as e:
+ logging.error(f'sql修复失败由于{e}')
+ return ''
+ sql = await llm.chat_with_model(prompt, f'请输出一条在与{database_type}下能运行的sql,要以分号结尾')
+ sql = await SqlGenerateService.extract_select_statements(sql)
+ logging.info(f"修复前的sql为{sql_failed}修复后的sql为{sql}")
+ except Exception as e:
+ logging.error(f'sql生成失败由于:{e}')
+ return ''
+ return sql
diff --git a/chat2db/common/.env.example b/chat2db/common/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..9011268f306402ccb853ecca76d219888d11f03c
--- /dev/null
+++ b/chat2db/common/.env.example
@@ -0,0 +1,26 @@
+# FastAPI
+UVICORN_IP=
+UVICORN_PORT=
+SSL_CERTFILE=
+SSL_KEYFILE=
+SSL_ENABLE=
+
+# Postgres
+DATABASE_URL=
+
+# QWEN
+LLM_KEY=
+LLM_URL=
+LLM_MAX_TOKENS=
+LLM_MODEL=
+
+# Vectorize
+EMBEDDING_TYPE=
+EMBEDDING_API_KEY=
+EMBEDDING_ENDPOINT=
+EMBEDDING_MODEL_NAME=
+
+# security
+HALF_KEY1=
+HALF_KEY2=
+HALF_KEY3=
\ No newline at end of file
diff --git a/chat2db/common/init_sql_example.py b/chat2db/common/init_sql_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..4288dfa8256098db5bb69204e4f80b03d11678f5
--- /dev/null
+++ b/chat2db/common/init_sql_example.py
@@ -0,0 +1,29 @@
+import yaml
+import requests
chat2db_url = 'http://0.0.0.0:9015'

# Load the table-name -> table-id map and the per-table example lists, then
# push every example through the /sql/example/add endpoint.
with open('table_name_id.yaml') as f:
    table_name_id = yaml.load(f, Loader=yaml.SafeLoader)
with open('table_name_sql_exmple.yaml') as f:
    table_name_sql_example_list = yaml.load(f, Loader=yaml.SafeLoader)

for entry in table_name_sql_example_list:
    table_name = entry['table_name']
    if table_name not in table_name_id:
        continue
    table_id = table_name_id[table_name]
    for sql_example in entry['sql_example_list']:
        request_data = {
            "table_id": str(table_id),
            "question": sql_example['question'],
            "sql": sql_example['sql']
        }
        url = f"{chat2db_url}/sql/example/add"  # replace with the real API host
        try:
            response = requests.post(url, json=request_data)
            if response.status_code != 200:
                print(f'添加sql案例失败{response.text}')
            else:
                print(f'添加sql案例成功{response.text}')
        except Exception as e:
            print(f'添加sql案例失败由于{e}')
\ No newline at end of file
diff --git a/chat2db/common/table_name_id.yaml b/chat2db/common/table_name_id.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..22d04d50455fb44439d9bc34dac84ffe409dae8a
--- /dev/null
+++ b/chat2db/common/table_name_id.yaml
@@ -0,0 +1,10 @@
+oe_community_openeuler_version: 5ffcd194-b895-4d3c-a3ad-c4af4a270d6a
+oe_community_organization_structure: 89cd2fd0-a5ca-4ee8-8bfe-44da86a32208
+oe_compatibility_card: fb4dbdda-88c5-482a-bec8-28d51669284e
+oe_compatibility_commercial_software: 86ab7dad-4848-48da-8667-72be6c99780f
+oe_compatibility_cve_database: 984d1c82-c6d5-4d3d-93d9-8d5bc11254ba
+oe_compatibility_oepkgs: bb2698c7-f715-487f-95c4-1061a9c33851
+oe_compatibility_osv: 8c9f6608-e2e2-475d-8f67-b3bdab3e9234
+oe_compatibility_overall_unit: 82b9521a-c924-4c52-aedc-229bca5ea4c0
+oe_compatibility_security_notice: f0fa7cc5-4e5d-4c69-b202-39d522f18383
+oe_compatibility_solution: 7bdaf15c-af9f-4cb8-adec-e7e45f1de6ca
diff --git a/chat2db/common/table_name_sql_exmple.yaml b/chat2db/common/table_name_sql_exmple.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e87a1100ebebd05e579c244395ec6e97698f094
--- /dev/null
+++ b/chat2db/common/table_name_sql_exmple.yaml
@@ -0,0 +1,490 @@
+- keyword_list:
+ - test_organization
+ - product_name
+ - company_name
+ sql_example_list:
+ - question: openEuler支持的哪些商业软件在江苏鲲鹏&欧拉生态创新中心测试通过
+ sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software
+ WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%';
+ - question: 哪个版本的openEuler支持的商业软件最多
+ sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP
+ BY openeuler_version ORDER BY software_count DESC LIMIT 1;
+ - question: openEuler支持测试商业软件的机构有哪些?
+ sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software;
+ - question: openEuler支持的商业软件有哪些类别
+ sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software;
+ - question: openEuler有哪些虚拟化类别的商业软件
+ sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
+ "type" ILIKE '%虚拟化%';
+ - question: openEuler支持哪些ISV商业软件呢,请列出10个
+    sql: SELECT product_name FROM public.oe_compatibility_commercial_software LIMIT 10;
+ - question: openEuler支持的适配Kunpeng 920的互联网商业软件有哪些?
+ sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM
+ public.oe_compatibility_commercial_software WHERE platform_type_and_server_model
+ ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30;
+ - question: openEuler-22.03版本支持哪些商业软件?
+ sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software
+ WHERE openeuler_version ILIKE '%22.03%';
+ - question: openEuler支持的数字政府类型的商业软件有哪些
+ sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software
+ WHERE type ILIKE '%数字政府%';
+ - question: 有哪些商业软件支持超过一种服务器平台
+ sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
+ platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model
+ ILIKE '%Kunpeng%';
+ - question: 每个openEuler版本有多少种类型的商业软件支持
+ sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP
+ BY openeuler_version;
+ - question: openEuler支持的哪些商业ISV在江苏鲲鹏&欧拉生态创新中心测试通过
+ sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software
+ WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%';
+ - question: 哪个版本的openEuler支持的商业ISV最多
+ sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP
+ BY openeuler_version ORDER BY software_count DESC LIMIT 1;
+ - question: openEuler支持测试商业ISV的机构有哪些?
+ sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software;
+ - question: openEuler支持的商业ISV有哪些类别
+ sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software;
+ - question: openEuler有哪些虚拟化类别的商业ISV
+ sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
+ "type" ILIKE '%虚拟化%';
+ - question: openEuler支持哪些ISV商业ISV呢,请列出10个
+    sql: SELECT product_name FROM public.oe_compatibility_commercial_software LIMIT 10;
+ - question: openEuler支持的适配Kunpeng 920的互联网商业ISV有哪些?
+ sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM
+ public.oe_compatibility_commercial_software WHERE platform_type_and_server_model
+ ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30;
+ - question: openEuler-22.03版本支持哪些商业ISV?
+ sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software
+ WHERE openeuler_version ILIKE '%22.03%';
+ - question: openEuler支持的数字政府类型的商业ISV有哪些
+ sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software
+ WHERE type ILIKE '%数字政府%';
+ - question: 有哪些商业ISV支持超过一种服务器平台
+ sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
+ platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model
+ ILIKE '%Kunpeng%';
+ - question: 每个openEuler版本有多少种类型的商业ISV支持
+ sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP
+ BY openeuler_version;
+ - question: 卓智校园网接入门户系统基于openeuelr的什么版本?
+ sql: select * from oe_compatibility_commercial_software where product_name ilike
+ '%卓智校园网接入门户系统%';
+ table_name: oe_compatibility_commercial_software
+- keyword_list:
+ - softwareName
+ sql_example_list:
+ - question: openEuler-20.03-LTS-SP1支持哪些开源软件?
+ sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE
+ openeuler_version ILIKE '%20.03-LTS-SP1%';
+ - question: openEuler的aarch64下支持开源软件
+ sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
+ "arch" ILIKE '%aarch64%';
+ - question: openEuler支持开源软件使用了GPLv2+许可证
+ sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
+ "license" ILIKE '%GPLv2+%';
+ - question: tcplay支持的架构是什么
+ sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName"
+ ILIKE '%tcplay%';
+ - question: openEuler支持哪些开源软件,请列出10个
+ sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT
+ 10;
+ - question: openEuler支持开源软件支持哪些结构
+ sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by
+ "arch";
+ - question: openEuler支持多少个开源软件?
+ sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from
+ (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software)
+ as tmp_table group by tmp_table.openeuler_version;
+ - question: openEuler-20.03-LTS-SP1支持哪些开源ISV?
+ sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE
+ openeuler_version ILIKE '%20.03-LTS-SP1%';
+ - question: openEuler的aarch64下支持开源ISV
+ sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
+ "arch" ILIKE '%aarch64%';
+ - question: openEuler支持开源ISV使用了GPLv2+许可证
+ sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
+ "license" ILIKE '%GPLv2+%';
+ - question: tcplay支持的架构是什么
+ sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName"
+ ILIKE '%tcplay%';
+ - question: openEuler支持哪些开源ISV,请列出10个
+ sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT
+ 10;
+ - question: openEuler支持开源ISV支持哪些结构
+ sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by
+ "arch";
+ - question: openEuler-20.03-LTS-SP1支持多少个开源ISV?
+ sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from
+ (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software
+ where openeuler_version ilike 'openEuler-20.03-LTS-SP1') as tmp_table group
+ by tmp_table.openeuler_version;
+ - question: openEuler支持多少个开源ISV?
+ sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from
+ (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software)
+ as tmp_table group by tmp_table.openeuler_version;
+ table_name: oe_compatibility_open_source_software
+- keyword_list: []
+ sql_example_list:
+ - question: 在openEuler技术委员会担任委员的人有哪些
+ sql: SELECT name FROM oe_community_organization_structure WHERE committee_name
+ ILIKE '%技术委员会%' AND role = '委员';
+ - question: openEuler的委员会中哪些人是教授
+ sql: SELECT name FROM oe_community_organization_structure WHERE personal_message
+ ILIKE '%教授%';
+ - question: openEuler各委员会中担任主席有多少个?
+ sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure
+ WHERE role = '主席' GROUP BY committee_name;
+ - question: openEuler 用户委员会中有多少位成员
+ sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name
+ ILIKE '%用户委员会%';
+ - question: openEuler 技术委员会有多少位成员
+ sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name
+ ILIKE '%技术委员会%';
+ - question: openEuler委员会的委员常务委员会委员有哪些人
+ sql: SELECT name FROM oe_community_organization_structure WHERE committee_name
+ ILIKE '%委员会%' AND role ILIKE '%常务委员会委员%';
+ - question: openEuler委员会有哪些人属于华为技术有限公司?
+ sql: SELECT DISTINCT name FROM oe_community_organization_structure WHERE personal_message
+ ILIKE '%华为技术有限公司%';
+ - question: openEuler每个委员会有多少人?
+ sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure
+ GROUP BY committee_name;
+ - question: openEuler的执行总监是谁
+ sql: SELECT name FROM oe_community_organization_structure WHERE role = '执行总监';
+ - question: openEuler委员会有哪些组织?
+ sql: SELECT DISTINCT committee_name from oe_community_organization_structure;
+ - question: openEuler技术委员会的主席是谁?
+ sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
+ role = '主席' and committee_name ilike '%技术委员会%';
+ - question: openEuler品牌委员会的主席是谁?
+ sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
+ role = '主席' and committee_name ilike '%品牌委员会%';
+ - question: openEuler委员会的主席是谁?
+ sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
+ role = '主席' and committee_name ilike '%openEuler 委员会%';
+ - question: openEuler委员会的执行总监是谁?
+ sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
+ role = '执行总监' and committee_name ilike '%openEuler 委员会%';
+ - question: openEuler委员会的执行秘书是谁?
+ sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
+ role = '执行秘书' and committee_name ilike '%openEuler 委员会%';
+ table_name: oe_community_organization_structure
+- keyword_list:
+ - cve_id
+ sql_example_list:
+ - question: 安全公告openEuler-SA-2024-2059的详细信息在哪里?
+ sql: select DISTINCT security_notice_no,details from oe_compatibility_security_notice
+ where security_notice_no='openEuler-SA-2024-2059';
+ table_name: oe_compatibility_security_notice
+- keyword_list:
+ - hardware_model
+ sql_example_list:
+ - question: openEuler-22.03 LTS支持哪些整机?
+ sql: SELECT main_board_model, cpu, ram FROM oe_compatibility_overall_unit WHERE
+ openeuler_version ILIKE '%openEuler-22.03-LTS%';
+ - question: 查询所有支持`openEuler-22.09`,并且提供详细产品介绍链接的整机型号和它们的内存配置?
+ sql: SELECT hardware_model, ram FROM oe_compatibility_overall_unit WHERE openeuler_version
+ ILIKE '%openEuler-22.09%' AND product_information IS NOT NULL;
+ - question: 显示所有由新华三生产,支持`openEuler-20.03 LTS SP2`版本的整机,列出它们的型号和架构类型
+ sql: SELECT hardware_model, architecture FROM oe_compatibility_overall_unit WHERE
+ hardware_factory = '新华三' AND openeuler_version ILIKE '%openEuler-20.03 LTS SP2%';
+ - question: openEuler支持多少种整机?
+ sql: SELECT count(DISTINCT main_board_model) FROM oe_compatibility_overall_unit;
+ - question: openEuler每个版本支持多少种整机?
+ sql: select openeuler_version,count(*) from (SELECT DISTINCT openeuler_version,main_board_model
+ FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version;
+ - question: openEuler每个版本多少种架构的整机?
+ sql: select openeuler_version,architecture,count(*) from (SELECT DISTINCT openeuler_version,architecture,main_board_model
+ FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version,architecture;
+ table_name: oe_compatibility_overall_unit
+- keyword_list:
+ - osv_name
+ - os_version
+ sql_example_list:
+ - question: 深圳开鸿数字产业发展有限公司基于openEuler的什么版本发行了什么商用版本?
+ sql: select os_version,openeuler_version,os_download_link from oe_compatibility_osv
+ where osv_name='深圳开鸿数字产业发展有限公司';
+ - question: 统计各个openEuler版本下的商用操作系统数量
+ sql: SELECT openeuler_version, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP
+ BY openeuler_version;
+ - question: 哪个OS厂商基于openEuler发布的商用操作系统最多
+ sql: SELECT osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP
+ BY osv_name ORDER BY os_count DESC LIMIT 1;
+ - question: 不同OS厂商基于openEuler发布不同架构的商用操作系统数量是多少?
+ sql: SELECT arch, osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP
+ BY arch, osv_name ORDER BY arch, os_count DESC;
+ - question: 深圳开鸿数字产业发展有限公司的商用操作系统是基于什么openEuler版本发布的
+ sql: SELECT os_version, openeuler_version FROM public.oe_compatibility_osv WHERE
+ osv_name ILIKE '%深圳开鸿数字产业发展有限公司%';
+ - question: openEuler有哪些OSV伙伴
+ sql: SELECT DISTINCT osv_name FROM public.oe_compatibility_osv;
+ - question: 有哪些OSV友商的操作系统是x86_64架构的
+ sql: SELECT osv_name, os_version FROM public.oe_compatibility_osv WHERE arch ILIKE
+ '%x86_64%';
+ - question: 哪些OSV友商操作系统是嵌入式类型的
+ sql: SELECT osv_name, os_version,openeuler_version FROM public.oe_compatibility_osv
+ WHERE type ILIKE '%嵌入式%';
+ - question: 成都鼎桥的商用操作系统版本是基于openEuler 22.03的版本吗
+ sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE
+ osv_name ILIKE '%成都鼎桥通信技术有限公司%' AND openeuler_version ILIKE '%22.03%';
+ - question: 最近发布的基于openEuler 23.09的商用系统有哪些
+ sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE
+ openeuler_version ILIKE '%23.09%' ORDER BY date DESC limit 10;
+ - question: 帮我查下成都智明达发布的所有嵌入式系统
+ sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE
+ osv_name ILIKE '%成都智明达电子股份有限公司%' AND type = '嵌入式';
+ - question: 基于openEuler发布的商用操作系统有哪些类型
+ sql: SELECT DISTINCT type FROM public.oe_compatibility_osv;
+ - question: 江苏润和系统版本HopeOS-V22-x86_64-dvd.iso基于openEuler哪个版本
+ sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv
+ WHERE "osv_name" ILIKE '%江苏润和%' AND os_version ILIKE '%HopeOS-V22-x86_64-dvd.iso%'
+ ;
+ - question: 浙江大华DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso系统版本基于openEuler哪个版本
+ sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv
+ WHERE "osv_name" ILIKE '%浙江大华%' AND os_version ILIKE '%DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso%'
+ ;
+ table_name: oe_compatibility_osv
+- keyword_list:
+ - board_model
+ - chip_model
+ - chip_vendor
+ - product
+ sql_example_list:
+ - question: openEuler 22.03支持哪些网络接口卡型号?
+ sql: SELECT board_model, chip_model,type FROM oe_compatibility_card WHERE type
+ ILIKE '%NIC%' AND openeuler_version ILIKE '%22.03%' limit 30;
+ - question: 请列出openEuler支持的所有Renesas公司的密码卡
+ sql: SELECT * FROM oe_compatibility_card WHERE chip_vendor ILIKE '%Renesas%' AND
+ type ILIKE '%密码卡%' limit 30;
+ - question: openEuler各种架构支持的板卡数量是多少
+ sql: SELECT architecture, COUNT(*) AS total_cards FROM oe_compatibility_card GROUP
+ BY architecture limit 30;
+ - question: 每个openEuler版本支持了多少种板卡
+ sql: SELECT openeuler_version, COUNT(*) AS number_of_cards FROM oe_compatibility_card
+ GROUP BY openeuler_version limit 30;
+ - question: openEuler总共支持多少种不同的板卡型号
+ sql: SELECT COUNT(DISTINCT board_model) AS board_model_cnt FROM oe_compatibility_card
+ limit 30;
+ - question: openEuler支持的GPU型号有哪些?
+ sql: SELECT chip_model, openeuler_version,type FROM public.oe_compatibility_card WHERE
+ type ILIKE '%GPU%' ORDER BY driver_date DESC limit 30;
+ - question: openEuler 20.03 LTS-SP4版本支持哪些类型的设备
+ sql: SELECT DISTINCT openeuler_version,type FROM public.oe_compatibility_card WHERE
+ openeuler_version ILIKE '%20.03-LTS-SP4%' limit 30;
+ - question: openEuler支持的板卡驱动在2023年后发布
+ sql: SELECT board_model, driver_date, driver_name FROM oe_compatibility_card WHERE
+ driver_date >= '2023-01-01' limit 30;
+ - question: 给些支持openEuler的aarch64架构下支持的的板卡的驱动下载链接
+ sql: SELECT openeuler_version,board_model, download_link FROM oe_compatibility_card
+ WHERE architecture ILIKE '%aarch64%' AND download_link IS NOT NULL limit 30;
+ - question: openEuler-22.03-LTS-SP1支持的存储卡有哪些?
+ sql: SELECT openeuler_version,board_model, chip_model,type FROM oe_compatibility_card
+ WHERE type ILIKE '%SSD%' AND openeuler_version ILIKE '%openEuler-22.03-LTS-SP1%'
+ limit 30;
+ table_name: oe_compatibility_card
+- keyword_list:
+ - cve_id
+ sql_example_list:
+ - question: CVE-2024-41053的详细信息在哪里可以看到?
+ sql: select DISTINCT cve_id,details from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
+ - question: CVE-2024-41053是个怎么样的漏洞?
+ sql: select DISTINCT cve_id,summary from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
+ - question: CVE-2024-41053影响了哪些包?
+ sql: select DISTINCT cve_id,package_name from oe_compatibility_cve_database where
+ cve_id='CVE-2024-41053';
+ - question: CVE-2024-41053的cvss评分是多少?
+ sql: select DISTINCT cve_id,cvsss_core_nvd from oe_compatibility_cve_database
+ where cve_id='CVE-2024-41053';
+ - question: CVE-2024-41053现在修复了么?
+ sql: select DISTINCT cve_id, status from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
+ - question: CVE-2024-41053影响了openEuler哪些版本?
+ sql: select DISTINCT cve_id, affected_product from oe_compatibility_cve_database
+ where cve_id='CVE-2024-41053';
+ - question: CVE-2024-41053发布时间是?
+ sql: select DISTINCT cve_id, announcement_time from oe_compatibility_cve_database
+ where cve_id='CVE-2024-41053';
+  - question: openEuler-20.03-LTS-SP4在2024年8月发布哪些漏洞?
+    sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database
+      where affected_product='openEuler-20.03-LTS-SP4' and EXTRACT(YEAR FROM announcement_time)=2024
+      and EXTRACT(MONTH FROM announcement_time)=8;
+  - question: openEuler-20.03-LTS-SP4在2024年发布哪些漏洞?
+    sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database
+      where affected_product='openEuler-20.03-LTS-SP4'
+      and EXTRACT(YEAR FROM announcement_time)=2024;
+ - question: CVE-2024-41053的威胁程度是怎样的?
+ sql: select DISTINCT affected_product,cve_id,cvsss_core_nvd,attack_complexity_nvd,attack_complexity_oe,attack_vector_nvd,attack_vector_oe
+ from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
+ table_name: oe_compatibility_cve_database
+- keyword_list:
+ - name
+ sql_example_list:
+ - question: openEuler-20.03-LTS的非官方软件包有多少个?
+ sql: SELECT COUNT(*) FROM oe_compatibility_oepkgs WHERE repotype = 'openeuler_compatible'
+ AND openeuler_version ILIKE '%openEuler-20.03-LTS%';
+ - question: openEuler支持的nginx版本有哪些?
+ sql: SELECT DISTINCT name,version, srcrpmpackurl FROM oe_compatibility_oepkgs
+ WHERE name ILIKE 'nginx';
+ - question: openEuler的支持哪些架构的glibc?
+ sql: SELECT DISTINCT name,arch FROM oe_compatibility_oepkgs WHERE name ILIKE 'glibc';
+ - question: openEuler-22.03-LTS带GPLv2许可的软件包有哪些
+ sql: SELECT name,rpmlicense FROM oe_compatibility_oepkgs WHERE openeuler_version
+ ILIKE '%openEuler-22.03-LTS%' AND rpmlicense = 'GPLv2';
+ - question: openEuler支持的python3这个软件包是用来干什么的?
+ sql: SELECT DISTINCT name,summary FROM oe_compatibility_oepkgs WHERE name ILIKE
+ 'python3';
+ - question: 哪些版本的openEuler的zlib中有官方源的?
+ sql: SELECT DISTINCT openeuler_version,name,version FROM oe_compatibility_oepkgs
+ WHERE name ILIKE '%zlib%' AND repotype = 'openeuler_official';
+ - question: 请以表格的形式提供openEuler-20.09的gcc软件包的下载链接
+ sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
+ - question: 请以表格的形式提供openEuler-20.09的glibc软件包的下载链接
+ sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
+ - question: 请以表格的形式提供openEuler-20.09的redis软件包的下载链接
+ sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
+ - question: openEuler-20.09的支持多少个软件包?
+ sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT
+ openeuler_version,name from oe_compatibility_oepkgs WHERE openeuler_version
+ ILIKE '%openEuler-20.09') as tmp_table group by tmp_table.openeuler_version;
+ - question: openEuler支持多少个软件包?
+ sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT
+ openeuler_version,name from oe_compatibility_oepkgs) as tmp_table group by tmp_table.openeuler_version;
+ - question: 请以表格的形式提供openEuler-20.09的gcc的版本
+ sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
+ - question: 请以表格的形式提供openEuler-20.09的glibc的版本
+ sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
+ - question: 请以表格的形式提供openEuler-20.09的redis的版本
+ sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
+ - question: openEuler-20.09支持哪些gcc的版本
+ sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
+ - question: openEuler-20.09支持哪些glibc的版本
+ sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
+ - question: openEuler-20.09支持哪些redis的版本
+ sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
+  - question: openEuler-20.09支持的gcc版本有哪些
+    sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+      WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
+  - question: openEuler-20.09支持的glibc版本有哪些
+    sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+      WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
+  - question: openEuler-20.09支持的redis版本有哪些
+    sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+      WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
+ - question: openEuler-20.09支持gcc 9.3.1么?
+ sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
+ WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc' AND version
+ ilike '9.3.1';
+ table_name: oe_compatibility_oepkgs
+- keyword_list: []
+ sql_example_list:
+ - question: openEuler社区创新版本有哪些
+ sql: SELECT DISTINCT openeuler_version,version_type FROM oe_community_openeuler_version
+ where version_type ILIKE '%社区创新版本%';
+ - question: openEuler有哪些版本
+ sql: SELECT openeuler_version FROM public.oe_community_openeuler_version;
+ - question: 查询openeuler各版本对应的内核版本
+ sql: SELECT DISTINCT openeuler_version, kernel_version FROM public.oe_community_openeuler_version;
+ - question: openEuler有多少个长期支持版本(LTS)
+ sql: SELECT COUNT(*) as publish_version_count FROM public.oe_community_openeuler_version
+ WHERE version_type ILIKE '%长期支持版本%';
+ - question: 查询openEuler-20.03的所有SP版本
+ sql: SELECT openeuler_version FROM public.oe_community_openeuler_version WHERE
+ openeuler_version ILIKE '%openEuler-20.03-LTS-SP%';
+ - question: openEuler最新的社区创新版本内核是啥
+ sql: SELECT kernel_version FROM public.oe_community_openeuler_version WHERE version_type
+ ILIKE '%社区创新版本%' ORDER BY publish_time DESC LIMIT 1;
+ - question: 最早的openEuler版本是什么时候发布的
+ sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version
+ ORDER BY publish_time ASC LIMIT 1;
+  - question: 最新的openEuler版本是哪个
+    sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version
+      ORDER BY publish_time DESC LIMIT 1;
+ - question: openEuler有哪些版本使用了Linux 5.10.0内核
+ sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
+ WHERE kernel_version ILIKE '5.10.0%';
+ - question: 哪个openEuler版本是最近更新的长期支持版本
+ sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version
+ WHERE version_type ILIKE '%长期支持版本%' ORDER BY publish_time DESC LIMIT 1;
+ - question: openEuler每个年份发布了多少个版本
+ sql: SELECT EXTRACT(YEAR FROM publish_time) AS year, COUNT(*) AS publish_version_count
+ FROM oe_community_openeuler_version group by EXTRACT(YEAR FROM publish_time);
+ - question: openEuler-20.03-LTS版本的linux内核是多少?
+ sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
+ WHERE openeuler_version = 'openEuler-20.03-LTS';
+  - question: openEuler-24.09版本的linux内核是多少?
+ sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
+ WHERE openeuler_version = 'openEuler-24.09';
+ table_name: oe_community_openeuler_version
+- keyword_list:
+ - product
+ sql_example_list:
+ - question: 哪些openEuler版本支持使用至强6338N的解决方案
+ sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE cpu
+ ILIKE '%6338N%';
+ - question: 使用intel XXV710作为网卡的解决方案对应的是哪些服务器型号
+ sql: SELECT DISTINCT server_model FROM oe_compatibility_solution WHERE network_card
+ ILIKE '%intel XXV710%';
+ - question: 哪些解决方案的硬盘驱动为SATA-SSD Skhynix
+ sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE hard_disk_drive
+ ILIKE 'SATA-SSD Skhynix';
+ - question: 查询所有使用6230R系列CPU且支持磁盘阵列支持PERC H740P Adapter的解决方案的产品名
+ sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%'
+ AND raid ILIKE '%PERC H740P Adapter%';
+ - question: R4900-G3有哪些驱动版本
+ sql: SELECT DISTINCT driver FROM oe_compatibility_solution WHERE product ILIKE
+ '%R4900-G3%';
+ - question: DL380 Gen10支持哪些架构
+ sql: SELECT DISTINCT architecture FROM oe_compatibility_solution WHERE server_model
+ ILIKE '%DL380 Gen10%';
+ - question: 列出所有使用Intel(R) Xeon(R)系列cpu且磁盘冗余阵列为LSI SAS3408的解决方案的服务器厂家
+ sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE cpu ILIKE
+ '%Intel(R) Xeon(R)%' AND raid ILIKE '%LSI SAS3408%';
+ - question: 哪些解决方案提供了针对SEAGATE ST4000NM0025硬盘驱动的支持
+ sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive ILIKE '%SEAGATE
+ ST4000NM0025%';
+ - question: 查询所有使用4316系列CPU的解决方案
+ sql: SELECT * FROM oe_compatibility_solution WHERE cpu ILIKE '%4316%';
+ - question: 支持openEuler-22.03-LTS-SP2版本的解决方案中,哪款服务器型号出现次数最多
+ sql: SELECT server_model, COUNT(*) as count FROM oe_compatibility_solution WHERE
+ openeuler_version ILIKE '%openEuler-22.03-LTS-SP2%' GROUP BY server_model ORDER
+ BY count DESC LIMIT 1;
+ - question: HPE提供的解决方案的介绍链接是什么
+ sql: SELECT DISTINCT introduce_link FROM oe_compatibility_solution WHERE server_vendor
+ ILIKE '%HPE%';
+ - question: 列出所有使用intel XXV710网络卡接口的解决方案的CPU型号
+ sql: SELECT DISTINCT cpu FROM oe_compatibility_solution WHERE network_card ILIKE
+ '%intel XXV710%';
+ - question: 服务器型号为2288H V5的解决方案支持哪些不同的openEuler版本
+ sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE server_model
+      ILIKE '%2288H V5%';
+ - question: 使用6230R系列CPU的解决方案内存最小是多少GB
+ sql: SELECT MIN(ram) FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%';
+ - question: 哪些解决方案的磁盘驱动为MegaRAID 9560-8i
+ sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive LIKE '%MegaRAID
+ 9560-8i%';
+ - question: 列出所有使用6330N系列CPU且服务器厂家为Dell的解决方案的产品名
+ sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6330N%'
+ AND server_vendor ILIKE '%Dell%';
+ - question: R4900-G3的驱动版本是多少
+ sql: SELECT driver FROM oe_compatibility_solution WHERE product ILIKE '%R4900-G3%';
+ - question: 哪些解决方案的服务器型号为2288H V7
+ sql: SELECT * FROM oe_compatibility_solution WHERE server_model ILIKE '%2288H
+ V7%';
+ - question: 使用Intel i350网卡且硬盘驱动为ST4000NM0025的解决方案的服务器厂家有哪些
+ sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE network_card
+ ILIKE '%Intel i350%' AND hard_disk_drive ILIKE '%ST4000NM0025%';
+ - question: 有多少种不同的驱动版本被用于支持openEuler-22.03-LTS-SP2版本的解决方案
+ sql: SELECT COUNT(DISTINCT driver) FROM oe_compatibility_solution WHERE openeuler_version
+ ILIKE '%openEuler-22.03-LTS-SP2%';
+ table_name: oe_compatibility_solution
diff --git a/chat2db/config/config.py b/chat2db/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..93faa3aa13e07a629bb38c3d7732c8cef8c21bfd
--- /dev/null
+++ b/chat2db/config/config.py
@@ -0,0 +1,55 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import os
+
+from dotenv import dotenv_values
+from pydantic import BaseModel, Field
+
+
+class ConfigModel(BaseModel):
+    """Typed view of the service's .env configuration.
+
+    Every field defaults to None so a partially-populated .env file still
+    loads; consumers are responsible for handling missing values.
+    """
+    # FastAPI
+    UVICORN_IP: str = Field(None, description="FastAPI 服务的IP地址")
+    UVICORN_PORT: int = Field(None, description="FastAPI 服务的端口号")
+    SSL_CERTFILE: str = Field(None, description="SSL证书文件的路径")
+    SSL_KEYFILE: str = Field(None, description="SSL密钥文件的路径")
+    SSL_ENABLE: str = Field(None, description="是否启用SSL连接")
+
+    # Postgres
+    DATABASE_URL: str = Field(None, description="数据库url")
+
+    # QWEN
+    LLM_KEY: str = Field(None, description="语言模型访问密钥")
+    LLM_URL: str = Field(None, description="语言模型服务的基础URL")
+    LLM_MAX_TOKENS: int = Field(None, description="单次请求中允许的最大Token数")
+    LLM_MODEL: str = Field(None, description="使用的语言模型名称或版本")
+
+    # Vectorize
+    EMBEDDING_TYPE: str = Field("openai", description="embedding 服务的类型")
+    EMBEDDING_API_KEY: str = Field(None, description="embedding服务api key")
+    EMBEDDING_ENDPOINT: str = Field(None, description="embedding服务url地址")
+    EMBEDDING_MODEL_NAME: str = Field(None, description="embedding模型名称")
+
+    # security
+    HALF_KEY1: str = Field(None, description='加密的密钥组件1')
+    HALF_KEY2: str = Field(None, description='加密的密钥组件2')
+    HALF_KEY3: str = Field(None, description='加密的密钥组件3')
+
+
+class Config:
+    """Loads ConfigModel from a dotenv file and exposes dict-style access.
+
+    The file path comes from the CONFIG environment variable when set,
+    otherwise a repository-relative default.  When PROD is set, the file is
+    deleted after loading so secrets do not linger on disk.
+    """
+    config: ConfigModel
+
+    def __init__(self):
+        # CONFIG env var overrides the default dotenv location.
+        if os.getenv("CONFIG"):
+            config_file = os.getenv("CONFIG")
+        else:
+            config_file = "./chat2db/common/.env"
+        self.config = ConfigModel(**(dotenv_values(config_file)))
+        if os.getenv("PROD"):
+            # Remove the secrets file after it has been read in production.
+            os.remove(config_file)
+
+    def __getitem__(self, key):
+        # Dict-style lookup over the model's fields; returns None for
+        # unknown keys instead of raising KeyError.
+        if key in self.config.__dict__:
+            return self.config.__dict__[key]
+        return None
+
+
+config = Config()
diff --git a/chat2db/database/postgres.py b/chat2db/database/postgres.py
new file mode 100644
index 0000000000000000000000000000000000000000..0148c4bd586daf0bf279549ed2725f89b45d6be9
--- /dev/null
+++ b/chat2db/database/postgres.py
@@ -0,0 +1,119 @@
+import logging
+from uuid import uuid4
+from pgvector.sqlalchemy import Vector
+from sqlalchemy.orm import sessionmaker, declarative_base
+from sqlalchemy import TIMESTAMP, UUID, Column, String, Boolean, ForeignKey, create_engine, func, Index
+from chat2db.config.config import config
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+ format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+Base = declarative_base()
+
+
+class DatabaseInfo(Base):
+    """Registered target database; connection data is stored encrypted."""
+    __tablename__ = 'database_info_table'
+    id = Column(UUID(), default=uuid4, primary_key=True)
+    # Connection URL, encrypted at rest.
+    encrypted_database_url = Column(String())
+    encrypted_config = Column(String())
+    # MAC used to verify the integrity of the encrypted payload.
+    hashmac = Column(String())
+    created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
+
+
+class TableInfo(Base):
+    """Table registered for NL2SQL; its note is embedded for vector retrieval."""
+    __tablename__ = 'table_info_table'
+    id = Column(UUID(), default=uuid4, primary_key=True)
+    # Cascade delete: removing a database removes its tables.
+    database_id = Column(UUID(), ForeignKey('database_info_table.id', ondelete='CASCADE'))
+    table_name = Column(String())
+    table_note = Column(String())
+    # Embedding of table_note; dimension must match the embedding model (1024).
+    table_note_vector = Column(Vector(1024))
+    enable = Column(Boolean, default=False)
+    created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
+    updated_at = Column(
+        TIMESTAMP(timezone=True),
+        server_default=func.current_timestamp(),
+        onupdate=func.current_timestamp())
+    __table_args__ = (
+        # HNSW index for fast cosine-similarity search over table notes.
+        Index(
+            'table_note_vector_index',
+            table_note_vector,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 16, 'ef_construction': 200},
+            postgresql_ops={'table_note_vector': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class ColumnInfo(Base):
+    """Column metadata (name, type, note) for a registered table."""
+    __tablename__ = 'column_info_table'
+    id = Column(UUID(), default=uuid4, primary_key=True)
+    # Cascade delete: removing a table removes its column records.
+    table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE'))
+    column_name = Column(String)
+    column_type = Column(String)
+    column_note = Column(String)
+    enable = Column(Boolean, default=False)
+
+
+class SqlExample(Base):
+    """A (question, sql) example pair; the question is embedded for retrieval."""
+    __tablename__ = 'sql_example_table'
+    id = Column(UUID(), default=uuid4, primary_key=True)
+    # Cascade delete: removing a table removes its examples.
+    table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE'))
+    question = Column(String())
+    sql = Column(String())
+    # Embedding of the question; dimension must match the embedding model (1024).
+    question_vector = Column(Vector(1024))
+    created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
+    updated_at = Column(
+        TIMESTAMP(timezone=True),
+        server_default=func.current_timestamp(),
+        onupdate=func.current_timestamp())
+    __table_args__ = (
+        # HNSW index for fast cosine-similarity search over example questions.
+        Index(
+            'question_vector_index',
+            question_vector,
+            postgresql_using='hnsw',
+            postgresql_with={'m': 16, 'ef_construction': 200},
+            postgresql_ops={'question_vector': 'vector_cosine_ops'}
+        ),
+    )
+
+
+class PostgresDB:
+ _engine = None
+
+ @classmethod
+ def get_mysql_engine(cls):
+ if not cls._engine:
+ cls.engine = create_engine(
+ config['DATABASE_URL'],
+ hide_parameters=True,
+ echo=False,
+ pool_recycle=300,
+ pool_pre_ping=True)
+
+ Base.metadata.create_all(cls.engine)
+ return cls._engine
+
+ @classmethod
+ def get_session(cls):
+ connection = None
+ try:
+ connection = sessionmaker(bind=cls.engine)()
+ except Exception as e:
+ logging.error(f"Error creating a postgres sessiondue to error: {e}")
+ return None
+ return cls._ConnectionManager(connection)
+
+ class _ConnectionManager:
+ def __init__(self, connection):
+ self.connection = connection
+
+ def __enter__(self):
+ return self.connection
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ try:
+ self.connection.close()
+ except Exception as e:
+ logging.error(f"Postgres connection close failed due to error: {e}")
+
+
+PostgresDB.get_mysql_engine()
diff --git a/chat2db/docs/change_txt_to_yaml.py b/chat2db/docs/change_txt_to_yaml.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e673d817e146c9762c7c712ab5ecf7a8689bf3b
--- /dev/null
+++ b/chat2db/docs/change_txt_to_yaml.py
@@ -0,0 +1,92 @@
+import yaml
+text = {
+ 'sql_generate_base_on_example_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释和示例,基于给出的问题生成一条在{database_url}连接下可进行查询的sql语句。
+注意:
+#01 sql语句中,特殊字段需要带上双引号。
+#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
+#03 sql语句中,查询字段必须使用`distinct`关键字去重。
+#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
+#05 sql语句中,参考问题,对查询字段进行冗余。
+#06 sql语句中,需要以分号结尾。
+
+以下是表结构以及表注释:
+{note}
+以下是{k}个示例:
+{sql_example}
+以下是问题:
+{question}
+''',
+ 'question_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是根据给出的表结构和表内数据,输出一个用户可能针对这张表内的信息提出的问题。
+注意:
+#01 问题内容和形式需要多样化,例如要用到统计、排序、模糊匹配等相关问题。
+#02 要以口语化的方式输出问题,不要机械的使用表内字段输出问题。
+#03 不要输出问题之外多余的内容!
+#04 要基于用户的角度取提出问题,问题内容需要口语化、拟人化。
+#05 优先生成有注释的字段相关的sql语句。
+
+以下是表结构和注释:
+{note}
+以下是表内数据
+{data_frame}
+''',
+ 'sql_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是参考给出的表结构以及表注释和表内数据,基于给出的问题生成一条查询{database_type}的sql语句。
+注意:
+#01 sql语句中,特殊字段需要带上双引号。
+#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
+#03 sql语句中,查询字段必须使用`distinct`关键字去重。
+#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
+#05 sql语句中,参考问题,对查询字段进行冗余。
+#06 sql语句中,需要以分号结尾。
+
+以下是表结构以及表注释:
+{note}
+以下是表内的数据:
+{data_frame}
+以下是问题:
+{question}
+''',
+ 'sql_expand_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释、执行失败的sql和执行失败的报错,基于给出的问题修改执行失败的sql生成一条在{database_type}连接下可进行查询的sql语句。
+
+ 注意:
+
+ #01 假设sql中有特殊字符干扰了sql的执行,请优先替换这些特殊字符保证sql可执行。
+
+ #02 假设sql用于检索或者过滤的字段导致了sql执行的失败,请尝试替换这些字段保证sql可执行。
+
+ #03 假设sql检索结果为空,请尝试将 = 的匹配方式替换为 ilike \'\%\%\' 保证sql执行给出结果。
+
+ #04 假设sql检索结果为空,可以使用问题中的关键字的子集作为sql的过滤条件保证sql执行给出结果。
+
+ 以下是表结构以及表注释:
+
+ {note}
+
+ 以下是执行失败的sql:
+
+ {sql_failed}
+
+ 以下是执行失败的报错:
+
+ {sql_failed_message}
+
+ 以下是问题:
+
+ {question}
+''',
+ 'table_choose_prompt': '''你是一个数据库专家,你的任务是参考给出的表名以及表的条目(主键,表名、表注释),输出最适配于问题回答检索的{table_cnt}张表,并返回表对应的主键。
+注意:
+#01 输出的表名用python的list格式返回,下面是list的一个示例:
+[\"prime_key1\",\"prime_key2\"]。
+#02 只输出包含主键的list即可不要输出其他内容!!!
+#03 list重主键的顺序,按表与问题的适配程度从高到底排列。
+#04 若无任何一张表适用于问题的回答,请返回空列表。
+
+以下是表的条目:
+{table_entries}
+以下是问题:
+{question}
+'''
+}
+print(text)
+with open('./prompt.yaml', 'w', encoding='utf-8') as f:
+ yaml.dump(text, f, allow_unicode=True)
diff --git a/chat2db/docs/prompt.yaml b/chat2db/docs/prompt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0013b12f6df19007bbc0467d2dc8add497469a1d
--- /dev/null
+++ b/chat2db/docs/prompt.yaml
@@ -0,0 +1,115 @@
+question_generate_base_on_data_prompt: '你是一个postgres数据库专家,你的任务是根据给出的表结构和表内数据,输出一个用户可能针对这张表内的信息提出的问题。
+
+ 注意:
+
+ #01 问题内容和形式需要多样化,例如要用到统计、排序、模糊匹配等相关问题。
+
+ #02 要以口语化的方式输出问题,不要机械的使用表内字段输出问题。
+
+ #03 不要输出问题之外多余的内容!
+
+ #04 要基于用户的角度取提出问题,问题内容需要口语化、拟人化。
+
+ #05 优先生成有注释的字段相关的sql语句。
+
+ #06 不要对生成的sql进行解释。
+
+ 以下是表结构和注释:
+
+ {note}
+
+ 以下是表内数据
+
+ {data_frame}
+
+ '
+sql_expand_prompt: "你是一个数据库专家,你的任务是参考给出的表结构以及表注释、执行失败的sql和执行失败的报错,基于给出的问题修改执行失败的sql生成一条在{database_type}连接下可进行查询的sql语句。\n\
+ \n 注意:\n\n #01 假设sql中有特殊字符干扰了sql的执行,请优先替换这些特殊字符保证sql可执行。\n\n #02 假设sql用于检索或者过滤的字段导致了sql执行的失败,请尝试替换这些字段保证sql可执行。\n\
+ \n #03 假设sql检索结果为空,请尝试将 = 的匹配方式替换为 ilike '\\%\\%' 保证sql执行给出结果。\n\n #04 假设sql检索结果为空,可以使用问题中的关键字的子集作为sql的过滤条件保证sql执行给出结果。\n\
+ \n 以下是表结构以及表注释:\n\n {note}\n\n 以下是执行失败的sql:\n\n {sql_failed}\n\n 以下是执行失败的报错:\n\
+ \n {sql_failed_message}\n\n 以下是问题:\n\n {question}\n"
+sql_generate_base_on_data_prompt: '你是一个postgres数据库专家,你的任务是参考给出的表结构以及表注释和表内数据,基于给出的问题生成一条查询{database_type}的sql语句。
+
+ 注意:
+
+ #01 sql语句中,特殊字段需要带上双引号。
+
+ #02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
+
+ #03 sql语句中,查询字段必须使用`distinct`关键字去重。
+
+ #04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
+
+ #05 sql语句中,参考问题,对查询字段进行冗余。
+
+ #06 sql语句中,需要以分号结尾。
+
+ #07 不要对生成的sql进行解释。
+
+ 以下是表结构以及表注释:
+
+ {note}
+
+ 以下是表内的数据:
+
+ {data_frame}
+
+ 以下是问题:
+
+ {question}
+
+ '
+sql_generate_base_on_example_prompt: '你是一个数据库专家,你的任务是参考给出的表结构以及表注释和示例,基于给出的问题生成一条在{database_url}连接下可进行查询的sql语句。
+
+ 注意:
+
+ #01 sql语句中,特殊字段需要带上双引号。
+
+ #02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
+
+ #03 sql语句中,查询字段必须使用`distinct`关键字去重。
+
+ #04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
+
+ #05 sql语句中,参考问题,对查询字段进行冗余。
+
+ #06 sql语句中,需要以分号结尾。
+
+
+ 以下是表结构以及表注释:
+
+ {note}
+
+ 以下是{k}个示例:
+
+ {sql_example}
+
+ 以下是问题:
+
+ {question}
+
+ '
+table_choose_prompt: '你是一个数据库专家,你的任务是参考给出的表名以及表的条目(主键,表名、表注释),输出最适配于问题回答检索的{table_cnt}张表,并返回表对应的主键。
+
+ 注意:
+
+ #01 输出的表名用python的list格式返回,下面是list的一个示例:
+
+ ["prime_key1","prime_key2"]。
+
+ #02 只输出包含主键的list即可不要输出其他内容!!!
+
+ #03 list重主键的顺序,按表与问题的适配程度从高到底排列。
+
+ #04 若无任何一张表适用于问题的回答,请返回空列表。
+
+
+ 以下是表的条目:
+
+ {table_entries}
+
+ 以下是问题:
+
+ {question}
+
+ '
diff --git a/chat2db/llm/chat_with_model.py b/chat2db/llm/chat_with_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cc1ad2d60bd40e79318ef4df29fc9d47f15c250
--- /dev/null
+++ b/chat2db/llm/chat_with_model.py
@@ -0,0 +1,25 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+from langchain_openai import ChatOpenAI
+from langchain.schema import SystemMessage, HumanMessage
+import re
+
+class LLM:
+ def __init__(self, model_name, openai_api_base, openai_api_key, request_timeout, max_tokens, temperature):
+ self.client = ChatOpenAI(model_name=model_name,
+ openai_api_base=openai_api_base,
+ openai_api_key=openai_api_key,
+ request_timeout=request_timeout,
+ max_tokens=max_tokens,
+ temperature=temperature)
+
+ def assemble_chat(self, system_call, user_call):
+ chat = []
+ chat.append(SystemMessage(content=system_call))
+ chat.append(HumanMessage(content=user_call))
+ return chat
+
+ async def chat_with_model(self, system_call, user_call):
+ chat = self.assemble_chat(system_call, user_call)
+ response = await self.client.ainvoke(chat)
+ content = re.sub(r'.*?\n\n', '', response.content, flags=re.DOTALL)
+ return content
diff --git a/chat2db/manager/column_info_manager.py b/chat2db/manager/column_info_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..3816bcf01fcdfe3e8e192979de585a1aa645ec8c
--- /dev/null
+++ b/chat2db/manager/column_info_manager.py
@@ -0,0 +1,69 @@
+from sqlalchemy import and_
+
+from chat2db.database.postgres import ColumnInfo, PostgresDB
+
+
+class ColumnInfoManager():
+ @staticmethod
+ async def add_column_info_with_table_id(table_id, column_name, column_type, column_note):
+ column_info_entry = ColumnInfo(table_id=table_id, column_name=column_name,
+ column_type=column_type, column_note=column_note)
+ with PostgresDB.get_session() as session:
+ session.add(column_info_entry)
+ session.commit()
+
+ @staticmethod
+ async def del_column_info_by_column_id(column_id):
+ with PostgresDB.get_session() as session:
+ column_info_to_delete = session.query(ColumnInfo).filter(ColumnInfo.id == column_id)
+ column_info_to_delete.delete()
+ session.commit()
+
+ @staticmethod
+ async def get_column_info_by_column_id(column_id):
+ tmp_dict = {}
+ with PostgresDB.get_session() as session:
+ result = session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first()
+ session.commit()
+ if not result:
+ return None
+ tmp_dict = {
+ 'column_id': result.id,
+ 'table_id': result.table_id,
+ 'column_name': result.column_name,
+ 'column_type': result.column_type,
+ 'column_note': result.column_note,
+ 'enable': result.enable
+ }
+ return tmp_dict
+
+ @staticmethod
+ async def update_column_info_enable(column_id, enable=True):
+ with PostgresDB.get_session() as session:
+ column_info = session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first()
+ if column_info is not None:
+ column_info.enable = enable
+ session.commit()
+ else:
+ return False
+ return True
+
+ @staticmethod
+ async def get_column_info_by_table_id(table_id, enable=None):
+ column_info_list = []
+ with PostgresDB.get_session() as session:
+ if enable is None:
+ results = session.query(ColumnInfo).filter(ColumnInfo.table_id == table_id).all()
+ else:
+ results = session.query(ColumnInfo).filter(
+ and_(ColumnInfo.table_id == table_id, ColumnInfo.enable == enable)).all()
+ for result in results:
+ tmp_dict = {
+ 'column_id': result.id,
+ 'column_name': result.column_name,
+ 'column_type': result.column_type,
+ 'column_note': result.column_note,
+ 'enable': result.enable
+ }
+ column_info_list.append(tmp_dict)
+ return column_info_list
diff --git a/chat2db/manager/database_info_manager.py b/chat2db/manager/database_info_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6109c193d62ef65a4aa185c6e9305a313936784
--- /dev/null
+++ b/chat2db/manager/database_info_manager.py
@@ -0,0 +1,74 @@
+import json
+import hashlib
+import logging
+from chat2db.database.postgres import DatabaseInfo, PostgresDB
+from chat2db.security.security import Security
+
+logging.basicConfig(filename='app.log', level=logging.INFO,
+ format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+
+class DatabaseInfoManager():
+ @staticmethod
+ async def add_database(database_url: str):
+ id = None
+ with PostgresDB.get_session() as session:
+ encrypted_database_url, encrypted_config = Security.encrypt(database_url)
+ hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest()
+ counter = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first()
+ if counter:
+ return id
+ encrypted_config = json.dumps(encrypted_config)
+ database_info_entry = DatabaseInfo(encrypted_database_url=encrypted_database_url,
+ encrypted_config=encrypted_config, hashmac=hashmac)
+ session.add(database_info_entry)
+ session.commit()
+ id = database_info_entry.id
+ return id
+
+ @staticmethod
+ async def del_database_by_id(id):
+ with PostgresDB.get_session() as session:
+ database_info_to_delete = session.query(DatabaseInfo).filter(DatabaseInfo.id == id).first()
+ if database_info_to_delete:
+ session.delete(database_info_to_delete)
+ else:
+ return False
+ session.commit()
+ return True
+
+ @staticmethod
+ async def get_database_url_by_id(id):
+ with PostgresDB.get_session() as session:
+ result = session.query(
+ DatabaseInfo.encrypted_database_url, DatabaseInfo.encrypted_config).filter(
+ DatabaseInfo.id == id).first()
+ if result is None:
+ return None
+ try:
+ encrypted_database_url, encrypted_config = result
+ encrypted_config = json.loads(encrypted_config)
+ except Exception as e:
+ logging.error(f'数据库url解密失败由于{e}')
+ return None
+ if encrypted_database_url:
+ database_url = Security.decrypt(encrypted_database_url, encrypted_config)
+ else:
+ return None
+ return database_url
+
+ @staticmethod
+ async def get_all_database_info():
+ with PostgresDB.get_session() as session:
+ results = session.query(DatabaseInfo).order_by(DatabaseInfo.created_at).all()
+ database_info_list = []
+ for i in range(len(results)):
+ database_id = results[i].id
+ encrypted_database_url = results[i].encrypted_database_url
+ encrypted_config = json.loads(results[i].encrypted_config)
+ created_at = results[i].created_at
+ if encrypted_database_url:
+ database_url = Security.decrypt(encrypted_database_url, encrypted_config)
+ tmp_dict = {'database_id': database_id, 'database_url': database_url, 'created_at': created_at}
+ database_info_list.append(tmp_dict)
+ return database_info_list
diff --git a/chat2db/manager/sql_example_manager.py b/chat2db/manager/sql_example_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..09805e24cd52c47c40febc10b3e67727bec072ca
--- /dev/null
+++ b/chat2db/manager/sql_example_manager.py
@@ -0,0 +1,75 @@
+import json
+from sqlalchemy import and_
+from chat2db.database.postgres import SqlExample, PostgresDB
+from chat2db.security.security import Security
+
+
+class SqlExampleManager():
+ @staticmethod
+ async def add_sql_example(question, sql, table_id, question_vector):
+ id = None
+ sql_example_entry = SqlExample(question=question, sql=sql,
+ table_id=table_id, question_vector=question_vector)
+ with PostgresDB.get_session() as session:
+ session.add(sql_example_entry)
+ session.commit()
+ id = sql_example_entry.id
+ return id
+
+ @staticmethod
+ async def del_sql_example_by_id(id):
+ with PostgresDB.get_session() as session:
+ sql_example_to_delete = session.query(SqlExample).filter(SqlExample.id == id).first()
+ if sql_example_to_delete:
+ session.delete(sql_example_to_delete)
+ else:
+ return False
+ session.commit()
+ return True
+
+ @staticmethod
+ async def update_sql_example_by_id(id, question, sql, question_vector):
+ with PostgresDB.get_session() as session:
+ sql_example_to_update = session.query(SqlExample).filter(SqlExample.id == id).first()
+ if sql_example_to_update:
+ sql_example_to_update.sql = sql
+ sql_example_to_update.question = question
+ sql_example_to_update.question_vector = question_vector
+ session.commit()
+ else:
+ return False
+ return True
+
+ @staticmethod
+ async def query_sql_example_by_table_id(table_id):
+ with PostgresDB.get_session() as session:
+ results = session.query(SqlExample).filter(SqlExample.table_id == table_id).all()
+ sql_example_list = []
+ for result in results:
+ tmp_dict = {
+ 'sql_example_id': result.id,
+ 'question': result.question,
+ 'sql': result.sql
+ }
+ sql_example_list.append(tmp_dict)
+ return sql_example_list
+
+ @staticmethod
+ async def get_topk_sql_example_by_cos_dis(question_vector, table_id_list=None, topk=3):
+ with PostgresDB.get_session() as session:
+ if table_id_list is not None:
+ sql_example_list = session.query(
+ SqlExample
+ ).filter(SqlExample.table_id.in_(table_id_list)).order_by(
+ SqlExample.question_vector.cosine_distance(question_vector)
+ ).limit(topk).all()
+ else:
+ sql_example_list = session.query(
+ SqlExample
+ ).order_by(
+ SqlExample.question_vector.cosine_distance(question_vector)
+ ).limit(topk).all()
+ sql_example_list = [
+ {'table_id': sql_example.table_id, 'question': sql_example.question, 'sql': sql_example.sql}
+ for sql_example in sql_example_list]
+ return sql_example_list
diff --git a/chat2db/manager/table_info_manager.py b/chat2db/manager/table_info_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a5e0c94c5b8e126dab79b9734eb04e801ffdcb6
--- /dev/null
+++ b/chat2db/manager/table_info_manager.py
@@ -0,0 +1,75 @@
+from sqlalchemy import and_
+
+from chat2db.database.postgres import TableInfo, PostgresDB
+
+
+class TableInfoManager():
+ @staticmethod
+ async def add_table_info(database_id, table_name, table_note, table_note_vector):
+ id = None
+ with PostgresDB.get_session() as session:
+ counter = session.query(TableInfo).filter(
+ and_(TableInfo.database_id == database_id, TableInfo.table_name == table_name)).first()
+ if counter:
+ return id
+ table_info_entry = TableInfo(database_id=database_id, table_name=table_name,
+ table_note=table_note, table_note_vector=table_note_vector)
+ session.add(table_info_entry)
+ session.commit()
+ id = table_info_entry.id
+ return id
+
+ @staticmethod
+ async def del_table_by_id(id):
+ with PostgresDB.get_session() as session:
+ table_info_to_delete = session.query(TableInfo).filter(TableInfo.id == id).first()
+ if table_info_to_delete:
+ session.delete(table_info_to_delete)
+ else:
+ return False
+ session.commit()
+ return True
+
+ @staticmethod
+ async def get_table_info_by_table_id(table_id):
+ with PostgresDB.get_session() as session:
+ result = session.query(
+ TableInfo.id, TableInfo.database_id, TableInfo.table_name, TableInfo.table_note).filter(
+ TableInfo.id == table_id).first()
+ if result is None:
+ return None
+ table_id, database_id, table_name, table_note = result
+ return {
+ 'table_id': table_id,
+ 'database_id': database_id,
+ 'table_name': table_name,
+ 'table_note': table_note}
+
+ @staticmethod
+ async def get_table_info_by_database_id(database_id, enable=None):
+ with PostgresDB.get_session() as session:
+ if enable is None:
+ results = session.query(
+ TableInfo).filter(TableInfo.database_id == database_id).all()
+ else:
+ results = session.query(
+ TableInfo).filter(
+ and_(TableInfo.database_id == database_id,
+ TableInfo.enable == enable
+ )).all()
+ table_info_list = []
+ for result in results:
+ table_info_list.append({'table_id': result.id, 'table_name': result.table_name,
+ 'table_note': result.table_note, 'created_at': result.created_at})
+ return table_info_list
+
+ @staticmethod
+ async def get_topk_table_by_cos_dis(database_id, tmp_vector, topk=3):
+ with PostgresDB.get_session() as session:
+ results = session.query(
+ TableInfo.id
+ ).filter(TableInfo.database_id == database_id).order_by(
+ TableInfo.table_note_vector.cosine_distance(tmp_vector)
+ ).limit(topk).all()
+ table_id_list = [result[0] for result in results]
+ return table_id_list
diff --git a/chat2db/model/request.py b/chat2db/model/request.py
new file mode 100644
index 0000000000000000000000000000000000000000..beba6859b66f77d239a444ab019ec04a1cb7bd2b
--- /dev/null
+++ b/chat2db/model/request.py
@@ -0,0 +1,83 @@
+import uuid
+from pydantic import BaseModel, Field
+from enum import Enum
+
+
+class QueryRequest(BaseModel):
+ question: str
+ topk_sql: int = 5
+ topk_answer: int = 15
+ use_llm_enhancements: bool = False
+
+
+class DatabaseAddRequest(BaseModel):
+ database_url: str
+
+
+class DatabaseDelRequest(BaseModel):
+ database_id: uuid.UUID
+
+
+class TableAddRequest(BaseModel):
+ database_id: uuid.UUID
+ table_name: str
+
+
+class TableDelRequest(BaseModel):
+ table_id: uuid.UUID
+
+
+class TableQueryRequest(BaseModel):
+ database_id: uuid.UUID
+
+
+class EnableColumnRequest(BaseModel):
+ column_id: uuid.UUID
+ enable: bool
+
+
+class SqlExampleAddRequest(BaseModel):
+ table_id: uuid.UUID
+ question: str
+ sql: str
+
+
+class SqlExampleDelRequest(BaseModel):
+ sql_example_id: uuid.UUID
+
+
+class SqlExampleQueryRequest(BaseModel):
+ table_id: uuid.UUID
+
+
+class SqlExampleUpdateRequest(BaseModel):
+ sql_example_id: uuid.UUID
+ question: str
+ sql: str
+
+
+class SqlGenerateRequest(BaseModel):
+ database_id: uuid.UUID
+ table_id_list: list[uuid.UUID] = []
+ question: str
+ topk: int = 5
+ use_llm_enhancements: bool = True
+
+
+class SqlRepairRequest(BaseModel):
+ database_id: uuid.UUID
+ table_id: uuid.UUID
+ sql: str
+ message: str = Field(..., max_length=2048)
+ question: str
+
+
+class SqlExcuteRequest(BaseModel):
+ database_id: uuid.UUID
+ sql: str
+
+
+class SqlExampleGenerateRequest(BaseModel):
+ table_id: uuid.UUID
+ generate_cnt: int = 1
+ sql_var: bool = False
diff --git a/chat2db/model/response.py b/chat2db/model/response.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aba13235e0ecb0dfd0814dc92534a544f275a0a
--- /dev/null
+++ b/chat2db/model/response.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+class ResponseData(BaseModel):
+ code: int
+ message: str
+ result: dict
\ No newline at end of file
diff --git a/chat2db/scripts/chat2db_config/config.yaml b/chat2db/scripts/chat2db_config/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78e3719e8a65cf870cad6994a66bf4120dc113e4
--- /dev/null
+++ b/chat2db/scripts/chat2db_config/config.yaml
@@ -0,0 +1,2 @@
+UVICORN_IP: 0.0.0.0
+UVICORN_PORT: '9015'
diff --git a/chat2db/scripts/docs/output_examples.xlsx b/chat2db/scripts/docs/output_examples.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..599501ca6c0f1d2b88fe5235d40fb11a56fbf005
Binary files /dev/null and b/chat2db/scripts/docs/output_examples.xlsx differ
diff --git a/chat2db/scripts/run_chat2db.py b/chat2db/scripts/run_chat2db.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8139142b6197b6c4909abc7045b1d7eabf7f1f1
--- /dev/null
+++ b/chat2db/scripts/run_chat2db.py
@@ -0,0 +1,436 @@
+import argparse
+import os
+import pandas as pd
+import requests
+import yaml
+from fastapi import FastAPI
+import shutil
+
+terminal_width = shutil.get_terminal_size().columns
+app = FastAPI()
+
+CHAT2DB_CONFIG_PATH = './chat2db_config'
+CONFIG_YAML_PATH = './chat2db_config/config.yaml'
+DEFAULT_CHAT2DB_CONFIG = {
+ "UVICORN_IP": "127.0.0.1",
+ "UVICORN_PORT": "8000"
+}
+
+
+# 修改
+def update_config(uvicorn_ip, uvicorn_port):
+ try:
+ yml = {'UVICORN_IP': uvicorn_ip, 'UVICORN_PORT': uvicorn_port}
+ with open(CONFIG_YAML_PATH, 'w') as file:
+ yaml.dump(yml, file)
+ return {"message": "修改成功"}
+ except Exception as e:
+ return {"message": f"修改失败,由于:{e}"}
+
+
+# 增加数据库
+def call_add_database_info(database_url):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/add"
+ request_body = {
+ "database_url": database_url
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 删除数据库
+def call_del_database_info(database_id):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/del"
+ request_body = {
+ "database_id": database_id
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 查询数据库配置
+def call_query_database_info():
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/query"
+ response = requests.get(url)
+ return response.json()
+
+
+# 查询数据库内表格配置
+def call_list_table_in_database(database_id, table_filter=''):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/list"
+ params = {
+ "database_id": database_id,
+ "table_filter": table_filter
+ }
+ print(params)
+ response = requests.get(url, params=params)
+ return response.json()
+
+
+# 增加数据表
+def call_add_table_info(database_id, table_name):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/add"
+ request_body = {
+ "database_id": database_id,
+ "table_name": table_name
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 删除数据表
+def call_del_table_info(table_id):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/del"
+ request_body = {
+ "table_id": table_id
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 查询数据表配置
+def call_query_table_info(database_id):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/query"
+ params = {
+ "database_id": database_id
+ }
+ response = requests.get(url, params=params)
+ return response.json()
+
+
+# 查询数据表列信息
+def call_query_column(table_id):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/column/query"
+ params = {
+ "table_id": table_id
+ }
+ response = requests.get(url, params=params)
+ return response.json()
+
+
+# 启用禁用列
+def call_enable_column(column_id, enable):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/column/enable"
+ request_body = {
+ "column_id": column_id,
+ "enable": enable
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 增加sql_example案例
+def call_add_sql_example(table_id, question, sql):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/add"
+ request_body = {
+ "table_id": table_id,
+ "question": question,
+ "sql": sql
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 删除sql_example案例
+def call_del_sql_example(sql_example_id):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/del"
+ request_body = {
+ "sql_example_id": sql_example_id
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 查询sql_example案例
+def call_query_sql_example(table_id):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/query"
+ params = {
+ "table_id": table_id
+ }
+ response = requests.get(url, params=params)
+ return response.json()
+
+
+# 更新sql_example案例
+def call_update_sql_example(sql_example_id, question, sql):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/update"
+ request_body = {
+ "sql_example_id": sql_example_id,
+ "question": question,
+ "sql": sql
+ }
+ response = requests.post(url, json=request_body)
+ return response.json()
+
+
+# 生成sql_example案例
+def call_generate_sql_example(table_id, generate_cnt=1, sql_var=False):
+ url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/generate"
+ response_body = {
+ "table_id": table_id,
+ "generate_cnt": generate_cnt,
+ "sql_var": sql_var
+ }
+ response = requests.post(url, json=response_body)
+ return response.json()
+
+
+def write_sql_example_to_excel(dir, sql_example_list):
+ try:
+ if not os.path.exists(os.path.dirname(dir)):
+ os.makedirs(os.path.dirname(dir))
+ data = {
+ 'question': [],
+ 'sql': []
+ }
+ for sql_example in sql_example_list:
+ data['question'].append(sql_example['question'])
+ data['sql'].append(sql_example['sql'])
+
+ df = pd.DataFrame(data)
+ df.to_excel(dir, index=False)
+
+ print("Data written to Excel file successfully.")
+ except Exception as e:
+ print("Error writing data to Excel file:", str(e))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="chat2DB脚本")
+ subparsers = parser.add_subparsers(dest="command", help="子命令列表")
+
+ # 修改config.yaml
+ parser_config = subparsers.add_parser("config", help="修改config.yaml")
+ parser_config.add_argument("--ip", type=str, required=True, help="uvicorn ip")
+ parser_config.add_argument("--port", type=str, required=True, help="uvicorn port")
+
+ # 增加数据库
+ parser_add_database = subparsers.add_parser("add_db", help="增加指定数据库")
+ parser_add_database.add_argument("--database_url", type=str, required=True,
+ help="数据库连接地址,如postgresql+psycopg2://postgres:123456@0.0.0.0:5432/postgres")
+
+ # 删除数据库
+ parser_del_database = subparsers.add_parser("del_db", help="删除指定数据库")
+ parser_del_database.add_argument("--database_id", type=str, required=True, help="数据库id")
+
+ # 查询数据库配置
+ parser_query_database = subparsers.add_parser("query_db", help="查询指定数据库配置")
+
+ # 查询数据库内表格配置
+ parser_list_table_in_database = subparsers.add_parser("list_tb_in_db", help="查询数据库内表格配置")
+ parser_list_table_in_database.add_argument("--database_id", type=str, required=True, help="数据库id")
+ parser_list_table_in_database.add_argument("--table_filter", type=str, required=False, help="表格名称过滤条件")
+
+ # 增加数据表
+ parser_add_table = subparsers.add_parser("add_tb", help="增加指定数据库内的数据表")
+ parser_add_table.add_argument("--database_id", type=str, required=True, help="数据库id")
+ parser_add_table.add_argument("--table_name", type=str, required=True, help="数据表名称")
+
+ # 删除数据表
+ parser_del_table = subparsers.add_parser("del_tb", help="删除指定数据表")
+ parser_del_table.add_argument("--table_id", type=str, required=True, help="数据表id")
+
+ # 查询数据表配置
+ parser_query_table = subparsers.add_parser("query_tb", help="查询指定数据表配置")
+ parser_query_table.add_argument("--database_id", type=str, required=True, help="数据库id")
+
+ # 查询数据表列信息
+ parser_query_column = subparsers.add_parser("query_col", help="查询指定数据表详细列信息")
+ parser_query_column.add_argument("--table_id", type=str, required=True, help="数据表id")
+
+ # 启用禁用列
+ parser_enable_column = subparsers.add_parser("enable_col", help="启用禁用指定列")
+ parser_enable_column.add_argument("--column_id", type=str, required=True, help="列id")
+ parser_enable_column.add_argument("--enable", type=lambda s: s.lower() in ("true", "1", "yes"), required=True, help="是否启用")
+
+ # 增加sql案例
+ parser_add_sql_example = subparsers.add_parser("add_sql_exp", help="增加指定数据表sql案例")
+ parser_add_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
+ parser_add_sql_example.add_argument("--question", type=str, required=False, help="问题")
+ parser_add_sql_example.add_argument("--sql", type=str, required=False, help="sql")
+ parser_add_sql_example.add_argument("--dir", type=str, required=False, help="输入路径")
+
+ # 删除sql_exp
+ parser_del_sql_example = subparsers.add_parser("del_sql_exp", help="删除指定sql案例")
+ parser_del_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id")
+
+ # 查询sql案例
+ parser_query_sql_example = subparsers.add_parser("query_sql_exp", help="查询指定数据表sql对案例")
+ parser_query_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
+
+ # 更新sql案例
+ parser_update_sql_example = subparsers.add_parser("update_sql_exp", help="更新sql对案例")
+ parser_update_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id")
+ parser_update_sql_example.add_argument("--question", type=str, required=True, help="sql语句对应的问题")
+ parser_update_sql_example.add_argument("--sql", type=str, required=True, help="sql语句")
+
+ # 生成sql案例
+ parser_generate_sql_example = subparsers.add_parser("generate_sql_exp", help="生成指定数据表sql对案例")
+ parser_generate_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
+ parser_generate_sql_example.add_argument("--generate_cnt", type=int, required=False, help="生成sql对数量",
+ default=1)
+ parser_generate_sql_example.add_argument("--sql_var", type=lambda s: s.lower() in ("true", "1", "yes"), required=False,
+ help="是否验证生成的sql对,True为验证,False不验证",
+ default=False)
+ parser_generate_sql_example.add_argument("--dir", type=str, required=False, help="生成的sql对输出路径",
+ default="docs/output_examples.xlsx")
+
+ args = parser.parse_args()
+
+ if os.path.exists(CONFIG_YAML_PATH):
+ exist = True
+ with open(CONFIG_YAML_PATH, 'r') as file:
+ yml = yaml.safe_load(file)
+ config = {
+ 'UVICORN_IP': yml.get('UVICORN_IP'),
+ 'UVICORN_PORT': yml.get('UVICORN_PORT'),
+ }
+ else:
+ exist = False
+
+ if args.command == "config":
+ if not exist:
+ os.makedirs(CHAT2DB_CONFIG_PATH, exist_ok=True)
+ with open(CONFIG_YAML_PATH, 'w') as file:
+ yaml.dump(DEFAULT_CHAT2DB_CONFIG, file, default_flow_style=False)
+ response = update_config(args.ip, args.port)
+ with open(CONFIG_YAML_PATH, 'r') as file:
+ yml = yaml.safe_load(file)
+ config = {
+ 'UVICORN_IP': yml.get('UVICORN_IP'),
+ 'UVICORN_PORT': yml.get('UVICORN_PORT'),
+ }
+ print(response.get("message"))
+ elif not exist:
+ print("please update_config first")
+
+ elif args.command == "add_db":
+ response = call_add_database_info(args.database_url)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ database_id = response.get("result")['database_id']
+ print('database_id: ', database_id)
+
+ elif args.command == "del_db":
+ response = call_del_database_info(args.database_id)
+ print(response.get("message"))
+
+ elif args.command == "query_db":
+ response = call_query_database_info()
+ print(response.get("message"))
+ if response.get("code") == 200:
+ database_info = response.get("result")['database_info_list']
+ for database in database_info:
+ print('-' * terminal_width)
+ print("database_id:", database["database_id"])
+ print("database_url:", database["database_url"])
+ print("created_at:", database["created_at"])
+ print('-' * terminal_width)
+
+ elif args.command == "list_tb_in_db":
+ response = call_list_table_in_database(args.database_id, args.table_filter)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ table_name_list = response.get("result")['table_name_list']
+ for table_name in table_name_list:
+ print(table_name)
+
+ elif args.command == "add_tb":
+ response = call_add_table_info(args.database_id, args.table_name)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ table_id = response.get("result")['table_id']
+ print('table_id: ', table_id)
+
+ elif args.command == "del_tb":
+ response = call_del_table_info(args.table_id)
+ print(response.get("message"))
+
+ elif args.command == "query_tb":
+ response = call_query_table_info(args.database_id)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ table_list = response.get("result")['table_info_list']
+ for table in table_list:
+ print('-' * terminal_width)
+ print("table_id:", table['table_id'])
+ print("table_name:", table['table_name'])
+ print("table_note:", table['table_note'])
+ print("created_at:", table['created_at'])
+ print('-' * terminal_width)
+
+ elif args.command == "query_col":
+ response = call_query_column(args.table_id)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ column_list = response.get("result")['column_info_list']
+ for column in column_list:
+ print('-' * terminal_width)
+ print("column_id:", column['column_id'])
+ print("column_name:", column['column_name'])
+ print("column_note:", column['column_note'])
+ print("column_type:", column['column_type'])
+ print("enable:", column['enable'])
+ print('-' * terminal_width)
+
+ elif args.command == "enable_col":
+ response = call_enable_column(args.column_id, args.enable)
+ print(response.get("message"))
+
+ elif args.command == "add_sql_exp":
+ def get_sql_exp(dir):
+ if not os.path.exists(os.path.dirname(dir)):
+ return None
+ # 读取 xlsx 文件
+ df = pd.read_excel(dir)
+
+ # 遍历每一行数据
+ for index, row in df.iterrows():
+ question = row['question']
+ sql = row['sql']
+
+ # 调用 call_add_sql_example 函数
+ response = call_add_sql_example(args.table_id, question, sql)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ print('sql_example_id: ', response.get("result")['sql_example_id'])
+ print(question, sql)
+
+
+ if args.dir:
+ get_sql_exp(args.dir)
+ else:
+ response = call_add_sql_example(args.table_id, args.question, args.sql)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ print('sql_example_id: ', response.get("result")['sql_example_id'])
+
+ elif args.command == "del_sql_exp":
+ response = call_del_sql_example(args.sql_example_id)
+ print(response.get("message"))
+
+ elif args.command == "query_sql_exp":
+ response = call_query_sql_example(args.table_id)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ sql_example_list = response.get("result")['sql_example_list']
+ for sql_example in sql_example_list:
+ print('-' * terminal_width)
+ print("sql_example_id:", sql_example['sql_example_id'])
+ print("question:", sql_example['question'])
+ print("sql:", sql_example['sql'])
+ print('-' * terminal_width)
+
+ elif args.command == "update_sql_exp":
+ response = call_update_sql_example(args.sql_example_id, args.question, args.sql)
+ print(response.get("message"))
+
+ elif args.command == "generate_sql_exp":
+ response = call_generate_sql_example(args.table_id, args.generate_cnt, args.sql_var)
+ print(response.get("message"))
+ if response.get("code") == 200:
+ # 输出到execl中
+ sql_example_list = response.get("result")['sql_example_list']
+ write_sql_example_to_excel(args.dir, sql_example_list)
+ else:
+ print("未知命令,请检查输入的命令是否正确。")
diff --git a/chat2db/security/security.py b/chat2db/security/security.py
new file mode 100644
index 0000000000000000000000000000000000000000..0909f27bf29fa5cc8c405ff0e4d998c7f1fbf03d
--- /dev/null
+++ b/chat2db/security/security.py
@@ -0,0 +1,116 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+import base64
+import binascii
+import hashlib
+import secrets
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+
+from chat2db.config.config import config
+
+
+class Security:
+
+ @staticmethod
+ def encrypt(plaintext: str) -> tuple[str, dict]:
+ """
+ 加密公共方法
+ :param plaintext:
+ :return:
+ """
+ half_key1 = config['HALF_KEY1']
+
+ encrypted_work_key, encrypted_work_key_iv = Security._generate_encrypted_work_key(
+ half_key1)
+ encrypted_plaintext, encrypted_iv = Security._encrypt_plaintext(half_key1, encrypted_work_key,
+ encrypted_work_key_iv, plaintext)
+ del plaintext
+ secret_dict = {
+ "encrypted_work_key": encrypted_work_key,
+ "encrypted_work_key_iv": encrypted_work_key_iv,
+ "encrypted_iv": encrypted_iv,
+ "half_key1": half_key1
+ }
+ return encrypted_plaintext, secret_dict
+
+ @staticmethod
+ def decrypt(encrypted_plaintext: str, secret_dict: dict):
+ """
+ 解密公共方法
+ :param encrypted_plaintext: 待解密的字符串
+ :param secret_dict: 存放工作密钥的dict
+ :return:
+ """
+ plaintext = Security._decrypt_plaintext(half_key1=secret_dict.get("half_key1"),
+ encrypted_work_key=secret_dict.get(
+ "encrypted_work_key"),
+ encrypted_work_key_iv=secret_dict.get(
+ "encrypted_work_key_iv"),
+ encrypted_iv=secret_dict.get(
+ "encrypted_iv"),
+ encrypted_plaintext=encrypted_plaintext)
+ return plaintext
+
+ @staticmethod
+ def _get_root_key(half_key1: str) -> bytes:
+ half_key2 = config['HALF_KEY2']
+ key = (half_key1 + half_key2).encode("utf-8")
+ half_key3 = config['HALF_KEY3'].encode("utf-8")
+ hash_key = hashlib.pbkdf2_hmac("sha256", key, half_key3, 10000)
+ return binascii.hexlify(hash_key)[13:45]
+
+ @staticmethod
+ def _generate_encrypted_work_key(half_key1: str) -> tuple[str, str]:
+ bin_root_key = Security._get_root_key(half_key1)
+ bin_work_key = secrets.token_bytes(32)
+ bin_encrypted_work_key_iv = secrets.token_bytes(16)
+ bin_encrypted_work_key = Security._root_encrypt(bin_root_key, bin_encrypted_work_key_iv, bin_work_key)
+ encrypted_work_key = base64.b64encode(bin_encrypted_work_key).decode("ascii")
+ encrypted_work_key_iv = base64.b64encode(bin_encrypted_work_key_iv).decode("ascii")
+ return encrypted_work_key, encrypted_work_key_iv
+
+ @staticmethod
+ def _get_work_key(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str) -> bytes:
+ bin_root_key = Security._get_root_key(half_key1)
+ bin_encrypted_work_key = base64.b64decode(encrypted_work_key.encode("ascii"))
+ bin_encrypted_work_key_iv = base64.b64decode(encrypted_work_key_iv.encode("ascii"))
+ return Security._root_decrypt(bin_root_key, bin_encrypted_work_key_iv, bin_encrypted_work_key)
+
+ @staticmethod
+ def _root_encrypt(key: bytes, encrypted_iv: bytes, plaintext: bytes) -> bytes:
+ encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor()
+ encrypted = encryptor.update(plaintext) + encryptor.finalize()
+ return encrypted
+
+ @staticmethod
+ def _root_decrypt(key: bytes, encrypted_iv: bytes, encrypted: bytes) -> bytes:
+ encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor()  # NOTE(review): decrypting through an encryptor skips GCM tag verification — confirm a decryptor with the stored tag isn't required
+ plaintext = encryptor.update(encrypted)
+ return plaintext
+
+ @staticmethod
+ def _encrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str,
+ plaintext: str) -> tuple[str, str]:
+ bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv)
+ salt = f"{half_key1}{plaintext}"
+ plaintext_temp = salt.encode("utf-8")
+ del plaintext
+ del salt
+ bin_encrypted_iv = secrets.token_bytes(16)
+ bin_encrypted_plaintext = Security._root_encrypt(bin_work_key, bin_encrypted_iv, plaintext_temp)
+ encrypted_plaintext = base64.b64encode(bin_encrypted_plaintext).decode("ascii")
+ encrypted_iv = base64.b64encode(bin_encrypted_iv).decode("ascii")
+ return encrypted_plaintext, encrypted_iv
+
+ @staticmethod
+ def _decrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str,
+ encrypted_plaintext: str, encrypted_iv: str) -> str:
+ bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv)
+ bin_encrypted_plaintext = base64.b64decode(encrypted_plaintext.encode("ascii"))
+ bin_encrypted_iv = base64.b64decode(encrypted_iv.encode("ascii"))
+ plaintext_temp = Security._root_decrypt(bin_work_key, bin_encrypted_iv, bin_encrypted_plaintext)
+ plaintext_salt = plaintext_temp.decode("utf-8")
+ plaintext = plaintext_salt[len(half_key1):]
+ return plaintext
\ No newline at end of file
diff --git a/data_chain/common/.env.example b/data_chain/common/.env.example
index f1b5e2f13eaddd47a46eda373997308c75e20910..1e0836f8522075eeb20a25bf987955d44e7f50cb 100644
--- a/data_chain/common/.env.example
+++ b/data_chain/common/.env.example
@@ -1,54 +1,57 @@
- # Fastapi
+# FastAPI
UVICORN_IP=
UVICORN_PORT=
SSL_CERTFILE=
SSL_KEYFILE=
SSL_ENABLE=
-LOG=
-
-# Postgres 5444
+# LOG METHOD
+LOG_METHOD=
+# Database
+DATABASE_TYPE=
DATABASE_URL=
-
# MinIO
MINIO_ENDPOINT=
MINIO_ACCESS_KEY=
MINIO_SECRET_KEY=
MINIO_SECURE=
-
# Redis
REDIS_HOST=
REDIS_PORT=
-# REDIS_PWD=
+REDIS_PWD=
REDIS_PENDING_TASK_QUEUE_NAME=
REDIS_SUCCESS_TASK_QUEUE_NAME=
REDIS_RESTART_TASK_QUEUE_NAME=
-# Embedding Service
-REMOTE_EMBEDDING_ENDPOINT=
-
-# Key
-CSRF_KEY=
+REDIS_SILENT_ERROR_TASK_QUEUE_NAME=
+# Task
+TASK_RETRY_TIME=
+# Embedding
+EMBEDDING_TYPE=
+EMBEDDING_API_KEY=
+EMBEDDING_ENDPOINT=
+EMBEDDING_MODEL_NAME=
+# Token
SESSION_TTL=
-
-# PROMPT_PATH
-PROMPT_PATH=
-# Stop Words PATH
-STOP_WORDS_PATH=
-
-#Security
+CSRF_KEY=
+# Security
HALF_KEY1=
HALF_KEY2=
HALF_KEY3=
-
-#LLM config
-MODEL_NAME=
-OPENAI_API_BASE=
-OPENAI_API_KEY=
-REQUEST_TIMEOUT=
-MAX_TOKENS=
+# Prompt file
+PROMPT_PATH=
+# Stop Words PATH
+STOP_WORDS_PATH=
+# LLM config
+MODEL_1_MODEL_ID=
+MODEL_1_MODEL_NAME=
+MODEL_1_MODEL_TYPE=
+MODEL_1_OPENAI_API_BASE=
+MODEL_1_OPENAI_API_KEY=
+MODEL_1_MAX_TOKENS=
MODEL_ENH=
-
-#DEFAULT USER
+# DEFAULT USER
DEFAULT_USER_ACCOUNT=
DEFAULT_USER_PASSWD=
DEFAULT_USER_NAME=
-DEFAULT_USER_LANGUAGE=
\ No newline at end of file
+DEFAULT_USER_LANGUAGE=
+# DOCUMENT PARSER
+DOCUMENT_PARSE_USE_CPU_LIMIT=
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c7a0ecefe3ca8617fa904c2225b64674700e165f..e9363212fe1ca04367f3f7840c4d0e20d8a77492 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,40 +1,48 @@
-python-docx==1.1.0
+aiofiles==24.1.0
+aiomysql==0.2.0
+apscheduler==3.10.0
asgi-correlation-id==4.3.1
+asyncpg==0.29.0
+# asyncpg==0.29.0  (duplicate of the pin above — commented out)
+beautifulsoup4==4.12.3
+chardet==5.2.0
+cryptography==42.0.2
fastapi==0.110.2
-starlette==0.37.2
-sqlalchemy==2.0.23
-numpy==1.26.4
-scikit-learn==1.5.0
+fastapi-pagination==0.12.19
+httpx==0.27.0
+itsdangerous==2.1.2
+jieba==0.42.1
langchain==0.3.7
-redis==5.0.3
-pgvector==0.2.5
+langchain-community==0.0.34
+# langchain-openai==0.0.8  (conflicts with the 0.2.5 pin below — commented out)
+langchain-openai==0.2.5
+langchain-text-splitters==0.0.1
+minio==7.2.4
+more-itertools==10.1.0
+numpy==1.26.4
+openai==1.53.0
+opencv-python==4.9.0.80
+openpyxl==3.1.2
+paddleocr==2.9.1
+paddlepaddle==2.6.2
pandas==2.2.2
-jieba==0.42.1
-cryptography==42.0.2
-httpx==0.27.0
+pgvector==0.2.5
+Pillow==10.3.0
+psycopg2-binary==2.9.9
pydantic==2.7.4
-aiofiles==24.1.0
+PyMuPDF==1.24.13
+python-docx==1.1.0
+python-dotenv==1.0.1
+python-multipart==0.0.9
+python-pptx==0.6.23
pytz==2020.5
pyyaml==6.0.1
+redis==5.0.3
+requests==2.32.2
+scikit-learn==1.5.0
+sqlalchemy==2.0.23
+starlette==0.37.2
tika==2.6.0
+tiktoken==0.8.0
urllib3==2.2.1
-paddlepaddle==2.6.2
-paddleocr==2.9.1
-apscheduler==3.10.0
uvicorn==0.21.0
-requests==2.32.2
-Pillow==10.3.0
-minio==7.2.4
-python-dotenv==1.0.1
-langchain-openai==0.2.5
-openai==1.53.0
-opencv-python==4.9.0.80
-chardet==5.2.0
-PyMuPDF==1.24.13
-python-multipart==0.0.9
-asyncpg==0.29.0
-psycopg2-binary==2.9.9
-openpyxl==3.1.2
-beautifulsoup4==4.12.3
-tiktoken==0.8.0
-python-pptx==0.6.23
\ No newline at end of file
diff --git a/run.sh b/run.sh
index 6cc9be0ad973788a67c40166e99aef173538e21c..df3d5bcce1839f2466c944af41f18d281b24b1f4 100644
--- a/run.sh
+++ b/run.sh
@@ -1,4 +1,4 @@
#!/usr/bin/env sh
nohup java -jar tika-server-standard-2.9.2.jar &
python3 /rag-service/data_chain/apps/app.py
-
+python3 /rag-service/chat2db/app/app.py
diff --git a/spider/oe_spider/config/pg_info.yaml b/spider/oe_spider/config/pg_info.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9013b1a707330af8699caf10a0dd86bcf918c286
--- /dev/null
+++ b/spider/oe_spider/config/pg_info.yaml
@@ -0,0 +1,5 @@
+pg_database: postgres
+pg_host: 0.0.0.0
+pg_port: '5438'
+pg_pwd: '123456'  # example credential only — replace before use; never commit real passwords
+pg_user: postgres
diff --git a/spider/oe_spider/docs/openeuler_version.txt b/spider/oe_spider/docs/openeuler_version.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c70bca346d1395467bb3d2432e9c70177e278e2c
--- /dev/null
+++ b/spider/oe_spider/docs/openeuler_version.txt
@@ -0,0 +1,26 @@
+openEuler-22.03-LTS 5.10.0-60.137.0 2022-03-31 长期支持版本
+openEuler-22.03-LTS-LoongArch 5.10.0-60.18.0 2022-03-31 其他版本
+openEuler-22.03-LTS-Next 5.10.0-199.0.0 2022-03-31 其他版本
+openEuler-22.09 5.10.0-106.19.0 2022-09-30 社区创新版本
+openEuler-22.09-HeXin 5.10.0-106.18.0 2022-09-30 其他版本
+openEuler-23.03 6.1.19 2023-03-31 社区创新版本
+openEuler-23.09 6.4.0-10.1.0 2023-09-30 社区创新版本
+openEuler-24.03-LTS 6.6.0-26.0.0 2024-06-06 长期支持版本
+openEuler-24.03-LTS-Next 6.6.0-16.0.0 2024-06-06 其他版本
+sync-pr1486-master-to-openEuler-24.03-LTS-Next 6.6.0-10.0.0 2024-06-06 其他版本
+sync-pr1519-openEuler-24.03-LTS-to-openEuler-24.03-LTS-Next 6.6.0-17.0.0 2024-06-06 其他版本
+openEuler-20.03-LTS 4.19.90-2202.1.0 2020-03-27 长期支持版本
+openEuler-20.03-LTS-Next 4.19.194-2106.1.0 2020-03-27 其他版本
+openEuler-20.09 4.19.140-2104.1.0 2020-09-30 社区创新版本
+openEuler-21.03 5.10.0-4.22.0 2021-03-30 社区创新版本
+openEuler-21.09 5.10.0-5.10.1 2021-09-30 社区创新版本
+openEuler-22.03-LTS-SP1 5.10.0-136.75.0 2022-12-30 其他版本
+openEuler-22.03-LTS-SP2 5.10.0-153.54.0 2023-06-30 其他版本
+openEuler-22.03-LTS-SP3 5.10.0-199.0.0 2023-12-30 其他版本
+openEuler-22.03-LTS-SP4 5.10.0-198.0.0 2024-06-30 其他版本
+sync-pr1314-openEuler-22.03-LTS-SP3-to-openEuler-22.03-LTS-Next 5.10.0-166.0.0 2023-12-30 其他版本
+openEuler-20.03-LTS-SP1 4.19.90-2405.3.0 2020-03-30 其他版本
+openEuler-20.03-LTS-SP1-testing 4.19.90-2012.4.0 2020-03-30 其他版本
+openEuler-20.03-LTS-SP2 4.19.90-2205.1.0 2021-06-30 其他版本
+openEuler-20.03-LTS-SP3 4.19.90-2401.1.0 2021-12-30 其他版本
+openEuler-20.03-LTS-SP4 4.19.90-2405.3.0 2023-11-30 其他版本
diff --git a/spider/oe_spider/docs/organize.txt b/spider/oe_spider/docs/organize.txt
new file mode 100644
index 0000000000000000000000000000000000000000..321dd6f928bb3ea281a44f0d45902131cc9970c3
--- /dev/null
+++ b/spider/oe_spider/docs/organize.txt
@@ -0,0 +1,67 @@
+openEuler 委员会顾问专家委员会
+成员 陆首群 中国开源软件联盟名誉主席_原国务院信息办常务副主任
+成员 倪光南 中国工程院院士
+成员 廖湘科 中国工程院院士
+成员 王怀民 中国科学院院士
+成员 邬贺铨 中国科学院院士
+成员 周明辉 北京大学计算机系教授_北大博雅特聘教授
+成员 霍太稳 极客邦科技创始人兼CEO
+
+2023-2024 年 openEuler 委员会
+主席 江大勇 华为技术有限公司
+常务委员会委员 韩乃平 麒麟软件有限公司
+常务委员会委员 刘文清 湖南麒麟信安科技股份有限公司
+常务委员会委员 武延军 中国科学院软件研究所
+常务委员会委员 朱建忠 统信软件技术有限公司
+委员 蔡志旻 江苏润和软件股份有限公司
+委员 高培 软通动力信息技术(集团)股份有限公司
+委员 姜振华 超聚变数字技术有限公司
+委员 李培源 天翼云科技有限公司
+委员 张胜举 中国移动云能力中心
+委员 钟忻 联通数字科技有限公司
+执行总监 邱成锋 华为技术有限公司
+执行秘书 刘彦飞 开放原子开源基金会
+
+openEuler 技术委员会
+主席 胡欣蔚
+委员 卞乃猛
+委员 曹志
+委员 陈棋德
+委员 侯健
+委员 胡峰
+委员 胡亚弟
+委员 李永强
+委员 刘寿永
+委员 吕从庆
+委员 任慰
+委员 石勇
+委员 田俊
+委员 王建民
+委员 王伶卓
+委员 王志钢
+委员 魏刚
+委员 吴峰光
+委员 谢秀奇
+委员 熊伟
+委员 赵川峰
+
+openEuler 品牌委员会
+主席 梁冰
+委员 黄兆伟
+委员 康丽锋
+委员 丽娜
+委员 李震宁
+委员 马骏
+委员 王鑫辉
+委员 文丹
+委员 张迎
+
+openEuler 用户委员会
+主席 薛蕾
+委员 高洪鹤
+委员 黄伟旭
+委员 王军
+委员 王友海
+委员 魏建刚
+委员 谢留华
+委员 朱晨
diff --git a/spider/oe_spider/oe_message_manager.py b/spider/oe_spider/oe_message_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d15ade16522d91ccc0120a7e13739b025204e9
--- /dev/null
+++ b/spider/oe_spider/oe_message_manager.py
@@ -0,0 +1,560 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import time
+
+from sqlalchemy import text
+from pg import PostgresDB, OeSigGroupRepo, OeSigGroupMember, OeSigRepoMember
+from pg import (OeCompatibilityOverallUnit, OeCompatibilityCard, OeCompatibilitySolution,
+ OeCompatibilityOpenSourceSoftware, OeCompatibilityCommercialSoftware, OeCompatibilityOepkgs,
+ OeCompatibilityOsv, OeCompatibilitySecurityNotice, OeCompatibilityCveDatabase,
+ OeCommunityOrganizationStructure, OeCommunityOpenEulerVersion, OeOpeneulerSigGroup,
+ OeOpeneulerSigMembers, OeOpeneulerSigRepos)
+
+
+class OeMessageManager:
+ @staticmethod
+ def clear_oe_compatibility_overall_unit(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_overall_unit;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_overall_unit(pg_url, info):
+ oe_compatibility_overall_unit_slice = OeCompatibilityOverallUnit(
+ id=info.get("id", ''),
+ architecture=info.get("architecture", ''),
+ bios_uefi=info.get("biosUefi", ''),
+ certification_addr=info.get("certificationAddr", ''),
+ certification_time=info.get("certificationTime", ''),
+ commitid=info.get("commitID", ''),
+ computer_type=info.get("computerType", ''),
+ cpu=info.get("cpu", ''),
+ date=info.get("date", ''),
+ friendly_link=info.get("friendlyLink", ''),
+ hard_disk_drive=info.get("hardDiskDrive", ''),
+ hardware_factory=info.get("hardwareFactory", ''),
+ hardware_model=info.get("hardwareModel", ''),
+ host_bus_adapter=info.get("hostBusAdapter", ''),
+ lang=info.get("lang", ''),
+ main_board_model=info.get("mainboardModel", ''),
+ openeuler_version=info.get("osVersion", '').replace(' ', '-'),
+ ports_bus_types=info.get("portsBusTypes", ''),
+ product_information=info.get("productInformation", ''),
+ ram=info.get("ram", ''),
+ video_adapter=info.get("videoAdapter", ''),
+ compatibility_configuration=info.get("compatibilityConfiguration", ''),
+ boardCards=info.get("boardCards", '')
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_overall_unit_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_card(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_card;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_card(pg_url, info):
+ oe_compatibility_card_slice = OeCompatibilityCard(
+ id=info.get("id", ''),
+ architecture=info.get("architecture", ''),
+ board_model=info.get("boardModel", ''),
+ chip_model=info.get("chipModel", ''),
+ chip_vendor=info.get("chipVendor", ''),
+ device_id=info.get("deviceID", ''),
+ download_link=info.get("downloadLink", ''),
+ driver_date=info.get("driverDate", ''),
+ driver_name=info.get("driverName", ''),
+ driver_size=info.get("driverSize", ''),
+ item=info.get("item", ''),
+ lang=info.get("lang", ''),
+ openeuler_version=info.get("os", '').replace(' ', '-'),
+ sha256=info.get("sha256", ''),
+ ss_id=info.get("ssID", ''),
+ sv_id=info.get("svID", ''),
+ type=info.get("type", ''),
+ vendor_id=info.get("vendorID", ''),
+ version=info.get("version", '')
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_card_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_open_source_software(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_open_source_software;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_open_source_software(pg_url, info):
+ oe_compatibility_open_source_software_slice = OeCompatibilityOpenSourceSoftware(
+ openeuler_version=info.get("os", '').replace(' ', '-'),
+ arch=info.get("arch", ''),
+ property=info.get("property", ''),
+ result_url=info.get("result_url", ''),
+ result_root=info.get("result_root", ''),
+ bin=info.get("bin", ''),
+ uninstall=info.get("uninstall", ''),
+ license=info.get("license", ''),
+ libs=info.get("libs", ''),
+ install=info.get("install", ''),
+ src_location=info.get("src_location", ''),
+ group=info.get("group", ''),
+ cmds=info.get("cmds", ''),
+ type=info.get("type", ''),
+ softwareName=info.get("softwareName", ''),
+ category=info.get("category", ''),
+ version=info.get("version", ''),
+ downloadLink=info.get("downloadLink", '')
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_open_source_software_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_commercial_software(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_commercial_software;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_commercial_software(pg_url, info):
+ oe_compatibility_commercial_software_slice = OeCompatibilityCommercialSoftware(
+ id=info.get("certId", ''),
+ data_id=info.get("dataId", ''),
+ type=info.get("type", ''),
+ test_organization=info.get("testOrganization", ''),
+ product_name=info.get("productName", ''),
+ product_version=info.get("productVersion", ''),
+ company_name=info.get("companyName", ''),
+ platform_type_and_server_model=info.get("platformTypeAndServerModel", ''),
+ authenticate_link=info.get("authenticateLink", ''),
+ openeuler_version=(info.get("osName", '') + info.get("osVersion", '')).replace(' ', '-'),
+ region=info.get("region", '')
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_commercial_software_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_solution(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_solution;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_solution(pg_url, info):
+ oe_compatibility_solution_slice = OeCompatibilitySolution(
+ id=info.get("id", ''),
+ architecture=info.get("architecture", ''),
+ bios_uefi=info.get("biosUefi", ''),
+ certification_type=info.get("certificationType", ''),
+ cpu=info.get("cpu", ''),
+ date=info.get("date", ''),
+ driver=info.get("driver", ''),
+ hard_disk_drive=info.get("hardDiskDrive", ''),
+ introduce_link=info.get("introduceLink", ''),
+ lang=info.get("lang", ''),
+ libvirt_version=info.get("libvirtVersion", ''),
+ network_card=info.get("networkCard", ''),
+ openeuler_version=info.get("os", '').replace(' ', '-'),
+ ovs_version=info.get("OVSVersion", ''),
+ product=info.get("product", ''),
+ qemu_version=info.get("qemuVersion", ''),
+ raid=info.get("raid", ''),
+ ram=info.get("ram", ''),
+ server_model=info.get("serverModel", ''),
+ server_vendor=info.get("serverVendor", ''),
+ solution=info.get("solution", ''),
+ stratovirt_version=info.get("stratovirtVersion", '')
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_solution_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_oepkgs(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_oepkgs;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_oepkgs(pg_url, info):
+ oe_compatibility_oepkgs_slice = OeCompatibilityOepkgs(
+ id=info.get("id", ''),
+ name=info.get("name", ''),
+ summary=info.get("summary", ''),
+ repotype=info.get("repoType", ''),
+ openeuler_version=info.get("os", '') + '-' + info.get("osVer", ''),
+ rpmpackurl=info.get("rpmPackUrl", ''),
+ srcrpmpackurl=info.get("srcRpmPackUrl", ''),
+ arch=info.get("arch", ''),
+ rpmlicense=info.get("rpmLicense", ''),
+ version=info.get("version", ''),
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_oepkgs_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_osv(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_osv;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_osv(pg_url, info):
+ oe_compatibility_osv_slice = OeCompatibilityOsv(
+ id=info.get("id", ''),
+ arch=info.get("arch", ''),
+ os_version=info.get("osVersion", ''),
+ osv_name=info.get("osvName", ''),
+ date=info.get("date", ''),
+ os_download_link=info.get("osDownloadLink", ''),
+ type=info.get("type", ''),
+ details=info.get("details", ''),
+ friendly_link=info.get("friendlyLink", ''),
+ total_result=info.get("totalResult", ''),
+ checksum=info.get("checksum", ''),
+ openeuler_version=info.get('baseOpeneulerVersion', '').replace(' ', '-')
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_osv_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_security_notice(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_security_notice;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_security_notice(pg_url, info):
+ oe_compatibility_security_notice_slice = OeCompatibilitySecurityNotice(
+ id=info.get("id", ''),
+ affected_component=info.get("affectedComponent", ''),
+ openeuler_version=info.get("openeuler_version", ''),
+ announcement_time=info.get("announcementTime", ''),
+ cve_id=info.get("cveId", ''),
+ description=info.get("description", ''),
+ introduction=info.get("introduction", ''),
+ package_name=info.get("packageName", ''),
+ reference_documents=info.get("referenceDocuments", ''),
+ revision_history=info.get("revisionHistory", ''),
+ security_notice_no=info.get("securityNoticeNo", ''),
+ subject=info.get("subject", ''),
+ summary=info.get("summary", ''),
+ type=info.get("type", ''),
+ notice_type=info.get("notice_type", ''),
+ cvrf=info.get("cvrf", ''),
+ package_helper_list=info.get("packageHelperList", ''),
+ package_hotpatch_list=info.get("packageHotpatchList", ''),
+ package_list=info.get("packageList", ''),
+ reference_list=info.get("referenceList", ''),
+ cve_list=info.get("cveList", ''),
+ details=info.get("details", ''),
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_security_notice_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_compatibility_cve_database(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_compatibility_cve_database;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_compatibility_cve_database(pg_url, info):
+ oe_compatibility_cve_database_slice = OeCompatibilityCveDatabase(
+ id=info.get("id", ''),
+ affected_product=info.get("productName", ''),
+ announcement_time=info.get("announcementTime", ''),
+ attack_complexity_nvd=info.get("attackComplexityNVD", ''),
+ attack_complexity_oe=info.get("attackComplexityOE", ''),
+ attack_vector_nvd=info.get("attackVectorNVD", ''),
+ attack_vector_oe=info.get("attackVectorOE", ''),
+ availability_nvd=info.get("availabilityNVD", ''),
+ availability_oe=info.get("availabilityOE", ''),
+ confidentiality_nvd=info.get("confidentialityNVD", ''),
+ confidentiality_oe=info.get("confidentialityOE", ''),
+ cve_id=info.get("cveId", ''),
+ cvsss_core_nvd=info.get("cvsssCoreNVD", ''),
+ cvsss_core_oe=info.get("cvsssCoreOE", ''),
+ integrity_nvd=info.get("integrityNVD", ''),
+ integrity_oe=info.get("integrityOE", ''),
+ national_cyberAwareness_system=info.get("nationalCyberAwarenessSystem", ''),
+ package_name=info.get("packageName", ''),
+ privileges_required_nvd=info.get("privilegesRequiredNVD", ''),
+ privileges_required_oe=info.get("privilegesRequiredOE", ''),
+ scope_nvd=info.get("scopeNVD", ''),
+ scope_oe=info.get("scopeOE", ''),
+ status=info.get("status", ''),
+ summary=info.get("summary", ''),
+ type=info.get("type", ''),
+ user_interaction_nvd=info.get("userInteractionNVD", ''),
+ user_interactio_oe=info.get("userInteractionOE", ''),
+ update_time=info.get("updateTime", ''),
+ create_time=info.get("createTime", ''),
+ security_notice_no=info.get("securityNoticeNo", ''),
+ parser_bean=info.get("parserBean", ''),
+ cvrf=info.get("cvrf", ''),
+ package_list=info.get("packageList", ''),
+ details=info.get("details", ''),
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_compatibility_cve_database_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_openeuler_sig_members(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_openeuler_sig_members CASCADE;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_openeuler_sig_members(pg_url, info):
+ oe_openeuler_sig_members_slice = OeOpeneulerSigMembers(
+ name=info.get('name', ''),
+ gitee_id=info.get('gitee_id', ''),
+ organization=info.get('organization', ''),
+ email=info.get('email', ''),
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_openeuler_sig_members_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_openeuler_sig_group(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_openeuler_sig_groups CASCADE;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_openeuler_sig_group(pg_url, info):
+ oe_openeuler_sig_group_slice = OeOpeneulerSigGroup(
+ sig_name=info.get("sig_name", ''),
+ description=info.get("description", ''),
+ mailing_list=info.get("mailing_list", ''),
+ created_at=info.get("created_at", ''),
+ is_sig_original=info.get("is_sig_original", ''),
+ )
+ # max_retries = 10
+ # for retry in range(max_retries):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_openeuler_sig_group_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_openeuler_sig_repos(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_openeuler_sig_repos CASCADE;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_openeuler_sig_repos(pg_url, info):
+ oe_openeuler_sig_repo_slice = OeOpeneulerSigRepos(
+ repo=info.get("repo", ''),
+ url=info.get("url",''),
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_openeuler_sig_repo_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_sig_group_to_repos(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_sig_group_to_repos;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+ @staticmethod
+ def add_oe_sig_group_to_repos(pg_url, group_name, repo_name):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ repo = session.query(OeOpeneulerSigRepos).filter_by(repo=repo_name).first()
+ group = session.query(OeOpeneulerSigGroup).filter_by(sig_name=group_name).first()
+ # 插入映射关系
+ relations = OeSigGroupRepo(group_id=group.id, repo_id=repo.id)
+ session.add(relations)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_sig_group_to_members(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_sig_group_to_members;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+ @staticmethod
+ def add_oe_sig_group_to_members(pg_url, group_name, member_name, role):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ member = session.query(OeOpeneulerSigMembers).filter_by(gitee_id=member_name).first()
+ group = session.query(OeOpeneulerSigGroup).filter_by(sig_name=group_name).first()
+ # 插入映射关系
+ relations = OeSigGroupMember(member_id=member.id, group_id=group.id, member_role=role)
+ session.add(relations)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_sig_repos_to_members(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_sig_repos_to_members;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+ @staticmethod
+ def add_oe_sig_repos_to_members(pg_url, repo_name, member_name, role):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ repo = session.query(OeOpeneulerSigRepos).filter_by(repo=repo_name).first()
+ member = session.query(OeOpeneulerSigMembers).filter_by(name=member_name).first()
+ # 插入映射关系
+ relations = OeSigRepoMember(repo_id=repo.id, member_id=member.id, member_role=role)
+ session.add(relations)
+ session.commit()
+ except Exception as e:
+ return
+ @staticmethod
+ def clear_oe_community_organization_structure(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_community_organization_structure;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_community_organization_structure(pg_url, info):
+ oe_community_organization_structure_slice = OeCommunityOrganizationStructure(
+ committee_name=info.get('committee_name', ''),
+ role=info.get('role', ''),
+ name=info.get('name', ''),
+ personal_message=info.get('personal_message', '')
+ )
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_community_organization_structure_slice)
+ session.commit()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def clear_oe_community_openEuler_version(pg_url):
+ try:
+ with PostgresDB(pg_url).get_session() as session:
+ session.execute(text("DROP TABLE IF EXISTS oe_community_openEuler_version;"))
+ session.commit()
+ PostgresDB(pg_url).create_table()
+ except Exception as e:
+ return
+
+ @staticmethod
+ def add_oe_community_openEuler_version(pg_url, info):
+ oe_community_openeuler_version_slice = OeCommunityOpenEulerVersion(
+ openeuler_version=info.get('openeuler_version', ''),
+ kernel_version=info.get('kernel_version', ''),
+ publish_time=info.get('publish_time', ''),
+ version_type=info.get('version_type', '')
+ )
+ with PostgresDB(pg_url).get_session() as session:
+ session.add(oe_community_openeuler_version_slice)
+ session.commit()
diff --git a/spider/oe_spider/pg.py b/spider/oe_spider/pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8197811605f05923a89c8f37df2b26b9d418d7
--- /dev/null
+++ b/spider/oe_spider/pg.py
@@ -0,0 +1,344 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+from threading import Lock
+
+from sqlalchemy import Column, String, BigInteger, DateTime, create_engine, MetaData, Sequence, ForeignKey
+from sqlalchemy.orm import declarative_base, sessionmaker,relationship
+
+
+Base = declarative_base()
+
+
class OeCompatibilityOverallUnit(Base):
    """ORM model for whole-machine (整机) compatibility records scraped from openeuler.org.

    Column names mirror the remote API's field names; only the
    analytically relevant columns carry a DB-level comment.
    """
    __tablename__ = 'oe_compatibility_overall_unit'
    __table_args__ = {'comment': 'openEuler支持的整机信息表,存储了openEuler支持的整机的架构、CPU型号、硬件厂家、硬件型号和相关的openEuler版本'}
    id = Column(BigInteger, primary_key=True)
    architecture = Column(String(), comment='openEuler支持的整机信息的架构')
    bios_uefi = Column(String())
    certification_addr = Column(String())
    certification_time = Column(String())
    commitid = Column(String())
    computer_type = Column(String())
    cpu = Column(String(), comment='openEuler支持的整机信息的CPU型号')
    date = Column(String())
    friendly_link = Column(String())
    hard_disk_drive = Column(String())
    hardware_factory = Column(String(), comment='openEuler支持的整机信息的硬件厂家')
    hardware_model = Column(String(), comment='openEuler支持的整机信息的硬件型号')
    host_bus_adapter = Column(String())
    lang = Column(String())
    main_board_model = Column(String())
    openeuler_version = Column(String(), comment='openEuler支持的整机信息的相关的openEuler版本')
    ports_bus_types = Column(String())
    product_information = Column(String())
    ram = Column(String())
    update_time = Column(String())
    video_adapter = Column(String())
    compatibility_configuration = Column(String())
    # camelCase kept to match the upstream API field name
    boardCards = Column(String())
+
+
class OeCompatibilityCard(Base):
    """ORM model for board-card (板卡/driver) compatibility records."""
    __tablename__ = 'oe_compatibility_card'
    __table_args__ = {'comment': 'openEuler支持的板卡信息表,存储了openEuler支持的板卡信息的支持架构、板卡型号、芯片信号、芯片厂家、驱动名称、相关的openEuler版本、类型和版本'}
    id = Column(BigInteger, primary_key=True)
    architecture = Column(String(), comment='openEuler支持的板卡信息的支持架构')
    board_model = Column(String(), comment='openEuler支持的板卡信息的板卡型号')
    chip_model = Column(String(), comment='openEuler支持的板卡信息的芯片型号')
    chip_vendor = Column(String(), comment='openEuler支持的板卡信息的芯片厂家')
    device_id = Column(String())
    download_link = Column(String())
    driver_date = Column(String())
    driver_name = Column(String(), comment='openEuler支持的板卡信息的驱动名称')
    driver_size = Column(String())
    item = Column(String())
    lang = Column(String())
    openeuler_version = Column(String(), comment='openEuler支持的板卡信息的相关的openEuler版本')
    sha256 = Column(String())
    ss_id = Column(String())
    sv_id = Column(String())
    type = Column(String(), comment='openEuler支持的板卡信息的类型')
    vendor_id = Column(String())
    version = Column(String(), comment='openEuler支持的板卡信息的版本')
+
+
class OeCompatibilityOpenSourceSoftware(Base):
    """ORM model for open-source-software compatibility records.

    camelCase columns (softwareName, downloadLink) mirror the upstream API.
    NOTE(review): the '软件名称' comment sits on `category` while
    `softwareName` is uncommented — possibly swapped; confirm upstream.
    """
    __tablename__ = 'oe_compatibility_open_source_software'
    __table_args__ = {'comment': 'openEuler支持的开源软件信息表,存储了openEuler支持的开源软件的相关openEuler版本、支持的架构、软件属性、开源协议、软件类型、软件名称和软件版本'}
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    openeuler_version = Column(String(), comment='openEuler支持的开源软件信息表的相关openEuler版本')
    arch = Column(String(), comment='openEuler支持的开源软件信息的支持的架构')
    property = Column(String(), comment='openEuler支持的开源软件信息的软件属性')
    result_url = Column(String())
    result_root = Column(String())
    bin = Column(String())
    uninstall = Column(String())
    license = Column(String(), comment='openEuler支持的开源软件信息的开源协议')
    libs = Column(String())
    install = Column(String())
    src_location = Column(String())
    group = Column(String())
    cmds = Column(String())
    type = Column(String(), comment='openEuler支持的开源软件信息的软件类型')
    softwareName = Column(String())
    category = Column(String(), comment='openEuler支持的开源软件信息的软件名称')
    version = Column(String(), comment='openEuler支持的开源软件信息的软件版本')
    downloadLink = Column(String())
+
+
class OeCompatibilityCommercialSoftware(Base):
    """ORM model for commercial-software certification records."""
    __tablename__ = 'oe_compatibility_commercial_software'
    __table_args__ = {'comment': 'openEuler支持的商业软件信息表,存储了openEuler支持的商业软件的软件类型、测试机构、软件名称、厂家名称和相关的openEuler版本'}
    id = Column(BigInteger, primary_key=True,)
    data_id = Column(String())
    type = Column(String(), comment='openEuler支持的商业软件的软件类型')
    test_organization = Column(String(), comment='openEuler支持的商业软件的测试机构')
    product_name = Column(String(), comment='openEuler支持的商业软件的软件名称')
    product_version = Column(String(), comment='openEuler支持的商业软件的软件版本')
    company_name = Column(String(), comment='openEuler支持的商业软件的厂家名称')
    platform_type_and_server_model = Column(String())
    authenticate_link = Column(String())
    openeuler_version = Column(String(), comment='openEuler支持的商业软件的openEuler版本')
    region = Column(String())
+
+
class OeCompatibilityOepkgs(Base):
    """ORM model for oepkgs package records (id is the upstream string id)."""
    __tablename__ = 'oe_compatibility_oepkgs'
    __table_args__ = {
        'comment': 'openEuler支持的软件包信息表,存储了openEuler支持软件包的名称、简介、是否为官方版本标识、相关的openEuler版本、rpm包下载链接、源码包下载链接、支持的架构、版本'}
    id = Column(String(), primary_key=True)
    name = Column(String(), comment='openEuler支持的软件包的名称')
    summary = Column(String(), comment='openEuler支持的软件包的简介')
    repotype = Column(String(), comment='openEuler支持的软件包的是否为官方版本标识')
    openeuler_version = Column(String(), comment='openEuler支持的软件包的相关的openEuler版本')
    rpmpackurl = Column(String(), comment='openEuler支持的软件包的rpm包下载链接')
    srcrpmpackurl = Column(String(), comment='openEuler支持的软件包的源码包下载链接')
    arch = Column(String(), comment='openEuler支持的软件包的支持的架构')
    rpmlicense = Column(String())
    version = Column(String(), comment='openEuler支持的软件包的版本')
+
+
class OeCompatibilitySolution(Base):
    """ORM model for solution (解决方案) compatibility records."""
    __tablename__ = 'oe_compatibility_solution'
    __table_args__ = {'comment': 'openeuler支持的解决方案表,存储了openeuler支持的解决方案的架构、硬件类别、cpu型号、相关的openEuler版本、硬件型号、厂家和解决方案类型'}
    id = Column(String(), primary_key=True,)
    architecture = Column(String(), comment='openeuler支持的解决方案的架构')
    bios_uefi = Column(String())
    certification_type = Column(String(), comment='openeuler支持的解决方案的硬件')
    cpu = Column(String(), comment='openeuler支持的解决方案的硬件的cpu型号')
    date = Column(String())
    driver = Column(String())
    hard_disk_drive = Column(String())
    introduce_link = Column(String())
    lang = Column(String())
    libvirt_version = Column(String())
    network_card = Column(String())
    openeuler_version = Column(String(), comment='openeuler支持的解决方案的相关openEuler版本')
    ovs_version = Column(String())
    product = Column(String(), comment='openeuler支持的解决方案的硬件型号')
    qemu_version = Column(String())
    raid = Column(String())
    ram = Column(String())
    server_model = Column(String(), comment='openeuler支持的解决方案的厂家')
    server_vendor = Column(String())
    solution = Column(String(), comment='openeuler支持的解决方案的解决方案类型')
    stratovirt_version = Column(String())
+
+
class OeCompatibilityOsv(Base):
    """ORM model for OSV (Operating System Vendor) records."""
    __tablename__ = 'oe_compatibility_osv'
    __table_args__ = {
        'comment': 'openEuler相关的OSV(Operating System Vendor,操作系统供应商)信息表,存储了openEuler相关的OSV(Operating System Vendor,操作系统供应商)的支持的架构、版本、名称、下载链接、详细信息的链接和相关的openEuler版本'}
    id = Column(String(), primary_key=True,)
    arch = Column(String())
    os_version = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的版本')
    osv_name = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的名称')
    date = Column(String())
    os_download_link = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的下载链接')
    type = Column(String())
    details = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的详细信息的链接')
    friendly_link = Column(String())
    total_result = Column(String())
    checksum = Column(String())
    openeuler_version = Column(String(), comment='openEuler相关的OSV(Operating System Vendor,操作系统供应商)的相关的openEuler版本')
+
+
class OeCompatibilitySecurityNotice(Base):
    """ORM model for security-notice records (one row per version/CVE pair
    after the spider's fan-out; *_list columns hold JSON-encoded strings)."""
    __tablename__ = 'oe_compatibility_security_notice'
    __table_args__ = {'comment': 'openEuler社区组安全公告信息表,存储了安全公告影响的openEuler版本、关联的cve漏洞的id、编号和详细信息的链接'}
    id = Column(String(), primary_key=True,)
    affected_component = Column(String())
    openeuler_version = Column(String(), comment='openEuler社区组安全公告信息影响的openEuler版本')
    announcement_time = Column(String())
    cve_id = Column(String(), comment='openEuler社区组安全公告关联的cve漏洞的id')
    description = Column(String())
    introduction = Column(String())
    package_name = Column(String())
    reference_documents = Column(String())
    revision_history = Column(String())
    security_notice_no = Column(String(), comment='openEuler社区组安全公告的编号')
    subject = Column(String())
    summary = Column(String())
    type = Column(String())
    notice_type = Column(String())
    cvrf = Column(String())
    package_helper_list = Column(String())
    package_hotpatch_list = Column(String())
    package_list = Column(String())
    reference_list = Column(String())
    cve_list = Column(String())
    details = Column(String(), comment='openEuler社区组安全公告的详细信息的链接')
+
+
class OeCompatibilityCveDatabase(Base):
    """ORM model for CVE records; NVD/openEuler metric pairs are kept in
    parallel *_nvd / *_oe columns mirroring the upstream API fields.

    NOTE(review): `cvsss_core_*` and `user_interactio_oe` look like upstream
    field-name typos; kept as-is to match the scraped payload.
    """
    __tablename__ = 'oe_compatibility_cve_database'
    __table_args__ = {'comment': 'openEuler社区组cve漏洞信息表,存储了cve漏洞的公告时间、id、关联的软件包名称、简介、cvss评分'}

    id = Column(String(), primary_key=True,)
    affected_product = Column(String())
    announcement_time = Column(DateTime(), comment='openEuler社区组cve漏洞的公告时间')
    attack_complexity_nvd = Column(String())
    attack_complexity_oe = Column(String())
    attack_vector_nvd = Column(String())
    attack_vector_oe = Column(String())
    availability_nvd = Column(String())
    availability_oe = Column(String())
    confidentiality_nvd = Column(String())
    confidentiality_oe = Column(String())
    cve_id = Column(String(), comment='openEuler社区组cve漏洞的id')
    cvsss_core_nvd = Column(String(), comment='openEuler社区组cve漏洞的cvss评分')
    cvsss_core_oe = Column(String())
    integrity_nvd = Column(String())
    integrity_oe = Column(String())
    national_cyberAwareness_system = Column(String())
    package_name = Column(String(), comment='openEuler社区组cve漏洞的关联的软件包名称')
    privileges_required_nvd = Column(String())
    privileges_required_oe = Column(String())
    scope_nvd = Column(String())
    scope_oe = Column(String())
    status = Column(String())
    summary = Column(String(), comment='openEuler社区组cve漏洞的简介')
    type = Column(String())
    user_interaction_nvd = Column(String())
    user_interactio_oe = Column(String())
    update_time = Column(DateTime())
    create_time = Column(DateTime())
    security_notice_no = Column(String())
    parser_bean = Column(String())
    cvrf = Column(String())
    package_list = Column(String())
    details = Column(String())
+
+
class OeOpeneulerSigGroup(Base):
    """ORM model for a SIG (Special Interest Group); linked to its repos and
    members through the OeSigGroupRepo / OeSigGroupMember association tables."""
    __tablename__ = 'oe_openeuler_sig_groups'
    __table_args__ = {'comment': 'openEuler社区特别兴趣小组信息表,存储了SIG的名称、描述等信息'}

    id = Column(BigInteger(), Sequence('sig_group_id'), primary_key=True)
    sig_name = Column(String(), comment='SIG的名称')
    description = Column(String(), comment='SIG的描述')
    created_at = Column(DateTime())
    is_sig_original = Column(String(), comment='是否为原始SIG')
    mailing_list = Column(String(), comment='SIG的邮件列表')
    # association-object links (see OeSigGroupRepo / OeSigGroupMember)
    repos_group = relationship("OeSigGroupRepo", back_populates="group")
    members_group = relationship("OeSigGroupMember", back_populates="group")
+
class OeOpeneulerSigMembers(Base):
    """ORM model for a SIG member; linked to groups and repos through the
    OeSigGroupMember / OeSigRepoMember association tables."""
    __tablename__ = 'oe_openeuler_sig_members'
    # BUGFIX: was `__table_args` (missing trailing underscores), so SQLAlchemy
    # silently ignored it and the table comment was never applied.
    __table_args__ = {'comment': 'openEuler社区特别兴趣小组成员信息表,储存了小组成员信息,如gitee用户名、所属组织等'}

    id = Column(BigInteger(), Sequence('sig_members_id'), primary_key=True)
    name = Column(String(), comment='成员名称')
    gitee_id = Column(String(), comment='成员gitee用户名')
    organization = Column(String(), comment='成员所属组织')
    email = Column(String(), comment='成员邮箱地址')

    groups_member = relationship("OeSigGroupMember", back_populates="member")
    repos_member = relationship("OeSigRepoMember", back_populates="member")
+
class OeOpeneulerSigRepos(Base):
    """ORM model for a SIG repository; linked to groups and members through
    the OeSigGroupRepo / OeSigRepoMember association tables.

    NOTE(review): the table comment duplicates the SIG-group table's text —
    likely a copy-paste slip; it is a runtime string so it is left unchanged.
    """
    __tablename__ = 'oe_openeuler_sig_repos'
    __table_args__ = {'comment': 'openEuler社区特别兴趣小组信息表,存储了SIG的名称、描述等信息'}

    id = Column(BigInteger(), Sequence('sig_repos_id'), primary_key=True)
    repo = Column(String(), comment='repo地址')
    url = Column(String(), comment='repo URL')
    groups_repo = relationship("OeSigGroupRepo", back_populates="repo")
    members_repo = relationship("OeSigRepoMember", back_populates="repo")
class OeSigGroupRepo(Base):
    """Association table: which repos belong to which SIG group (composite PK)."""
    __tablename__ = 'oe_sig_group_to_repos'
    __table_args__ = {'comment': 'SIG group id包含repos'}

    group_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_groups.id'), primary_key=True)
    repo_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_repos.id'), primary_key=True)
    group = relationship("OeOpeneulerSigGroup", back_populates="repos_group")
    repo = relationship("OeOpeneulerSigRepos", back_populates="groups_repo")
+
class OeSigRepoMember(Base):
    """Association table: which members maintain which repo, with their role
    ('maintainer'/'committer'); composite PK on (member_id, repo_id)."""
    __tablename__ = 'oe_sig_repos_to_members'
    __table_args__ = {'comment': 'SIG repo id对应维护成员'}

    member_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_members.id'), primary_key=True)
    repo_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_repos.id'), primary_key=True)
    member_role = Column(String(), comment='成员属性maintainer or committer')  # repo_member
    member = relationship("OeOpeneulerSigMembers", back_populates="repos_member")
    repo = relationship("OeOpeneulerSigRepos", back_populates="members_repo")
+
class OeSigGroupMember(Base):
    """Association table: which members belong to which SIG group, with their
    role; composite PK on (group_id, member_id)."""
    __tablename__ = 'oe_sig_group_to_members'
    __table_args__ = {'comment': 'SIG group id对应包含成员id'}

    group_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_groups.id'), primary_key=True)
    member_id = Column(BigInteger(), ForeignKey('oe_openeuler_sig_members.id'), primary_key=True)
    member_role = Column(String(), comment='成员属性maintainer or committer')
    group = relationship("OeOpeneulerSigGroup", back_populates="members_group")
    member = relationship("OeOpeneulerSigMembers", back_populates="groups_member")
+
class OeCommunityOrganizationStructure(Base):
    """ORM model for community organization-structure entries (committee,
    role, member name, personal info)."""
    __tablename__ = 'oe_community_organization_structure'
    __table_args__ = {'comment': 'openEuler社区组织架构信息表,存储了openEuler社区成员所属的委员会、职位、姓名和个人信息'}
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    committee_name = Column(String(), comment='openEuler社区成员所属的委员会')
    role = Column(String(), comment='openEuler社区成员的职位')
    name = Column(String(), comment='openEuler社区成员的姓名')
    personal_message = Column(String(), comment='openEuler社区成员的个人信息')
+
+
class OeCommunityOpenEulerVersion(Base):
    """ORM model for released openEuler versions (version, kernel, publish
    date, release type)."""
    __tablename__ = 'oe_community_openeuler_version'
    __table_args__ = {'comment': 'openEuler社区openEuler版本信息表,存储了openEuler版本的版本、内核版本、发布日期和版本类型'}
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    openeuler_version = Column(String(), comment='openEuler版本的版本号')
    kernel_version = Column(String(), comment='openEuler版本的内核版本')
    publish_time = Column(DateTime(), comment='openEuler版本的发布日期')
    version_type = Column(String(), comment='openEuler社区成员的版本类型')
+
+
+
class PostgresDBMeta(type):
    """Singleton metaclass: the first instantiation of each class is cached
    and returned for every later call (later constructor args are ignored).

    Fix: `_lock` was declared but never used, so two threads could race and
    each build an instance. Double-checked locking makes creation atomic
    while keeping the common path lock-free.
    """
    _instances = {}
    _lock: Lock = Lock()

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            with cls._lock:
                # Re-check inside the lock: another thread may have won the race.
                if cls not in cls._instances:
                    cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
+
+
class PostgresDB(metaclass=PostgresDBMeta):
    """Holds the SQLAlchemy engine for the spider's PostgreSQL database and
    ensures all declared tables exist.

    NOTE(review): PostgresDBMeta caches one instance per class, so the
    `pg_url` of any call after the first is silently ignored — confirm a
    single database URL per process is intended.
    """

    def __init__(self, pg_url):
        # pool_pre_ping revalidates pooled connections before each checkout
        self.engine = create_engine(
            pg_url,
            echo=False,
            pool_pre_ping=True)
        self.create_table()

    def create_table(self):
        # Creates any missing tables declared on Base; existing tables are untouched.
        Base.metadata.create_all(self.engine)

    def get_session(self):
        # Returns a new Session bound to this engine; the caller owns its lifecycle.
        return sessionmaker(bind=self.engine)()

    def close(self):
        # Dispose the connection pool (engine becomes unusable afterwards).
        self.engine.dispose()
diff --git a/spider/oe_spider/pull_message_from_oe_web.py b/spider/oe_spider/pull_message_from_oe_web.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ad7a30f6e3681e54e38788431af794f2edd6e06
--- /dev/null
+++ b/spider/oe_spider/pull_message_from_oe_web.py
@@ -0,0 +1,894 @@
+from multiprocessing import Process, set_start_method
+import time
+
+import requests
+import json
+import copy
+from datetime import datetime
+import os
+import yaml
+from yaml import SafeLoader
+from oe_message_manager import OeMessageManager
+import argparse
+import urllib.parse
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+
+class PullMessageFromOeWeb:
+ @staticmethod
+ def add_to_DB(pg_url, results, dataset_name):
+ dataset_name_map = {
+ 'overall_unit': OeMessageManager.add_oe_compatibility_overall_unit,
+ 'card': OeMessageManager.add_oe_compatibility_card,
+ 'open_source_software': OeMessageManager.add_oe_compatibility_open_source_software,
+ 'commercial_software': OeMessageManager.add_oe_compatibility_commercial_software,
+ 'solution': OeMessageManager.add_oe_compatibility_solution,
+ 'oepkgs': OeMessageManager.add_oe_compatibility_oepkgs,
+ 'osv': OeMessageManager.add_oe_compatibility_osv,
+ 'security_notice': OeMessageManager.add_oe_compatibility_security_notice,
+ 'cve_database': OeMessageManager.add_oe_compatibility_cve_database,
+ 'sig_group': OeMessageManager.add_oe_openeuler_sig_group,
+ 'sig_members': OeMessageManager.add_oe_openeuler_sig_members,
+ 'sig_repos': OeMessageManager.add_oe_openeuler_sig_repos,
+ }
+
+ for i in range(len(results)):
+ for key in results[i]:
+ if isinstance(results[i][key], (dict, list, tuple)):
+ results[i][key] = json.dumps(results[i][key])
+ dataset_name_map[dataset_name](pg_url, results[i])
+
+
+ @staticmethod
+ def save_result(pg_url, results, dataset_name):
+ prompt_map = {'card': 'openEuler支持的板卡信息',
+ 'commercial_software': 'openEuler支持的商业软件',
+ 'cve_database': 'openEuler的cve信息',
+ 'oepkgs': 'openEuler支持的软件包信息',
+ 'solution': 'openEuler支持的解决方案',
+ 'open_source_software': 'openEuler支持的开源软件信息',
+ 'overall_unit': 'openEuler支持的整机信息',
+ 'security_notice': 'openEuler官网的安全公告',
+ 'osv': 'openEuler相关的osv厂商',
+ 'openeuler_version_message': 'openEuler的版本信息',
+ 'organize_message': 'openEuler社区成员组织架构',
+ 'openeuler_sig': 'openeuler_sig(openEuler SIG组成员信息)',
+ 'sig_group': 'openEuler SIG组信息',
+ 'sig_members': 'openEuler SIG组成员信息',
+ 'sig_repos': 'openEuler SIG组仓库信息'}
+ dataset_name_map = {
+ 'overall_unit': OeMessageManager.clear_oe_compatibility_overall_unit,
+ 'card': OeMessageManager.clear_oe_compatibility_card,
+ 'open_source_software': OeMessageManager.clear_oe_compatibility_open_source_software,
+ 'commercial_software': OeMessageManager.clear_oe_compatibility_commercial_software,
+ 'solution': OeMessageManager.clear_oe_compatibility_solution,
+ 'oepkgs': OeMessageManager.clear_oe_compatibility_oepkgs,
+ 'osv': OeMessageManager.clear_oe_compatibility_osv,
+ 'security_notice': OeMessageManager.clear_oe_compatibility_security_notice,
+ 'cve_database': OeMessageManager.clear_oe_compatibility_cve_database,
+ 'sig_group': OeMessageManager.clear_oe_openeuler_sig_group,
+ 'sig_members': OeMessageManager.clear_oe_openeuler_sig_members,
+ 'sig_repos': OeMessageManager.clear_oe_openeuler_sig_repos,
+ }
+ process = []
+ num_cores = (os.cpu_count()+1) // 2
+ if num_cores > 8:
+ num_cores = 8
+
+ print(f"{prompt_map[dataset_name]}录入中")
+ dataset_name_map[dataset_name](pg_url)
+ step = len(results) // num_cores
+ process = []
+ for i in range(num_cores):
+ begin = i * step
+ end = begin + step
+ if i == num_cores - 1:
+ end = len(results)
+ sub_result = results[begin:end]
+ p = Process(target=PullMessageFromOeWeb.add_to_DB, args=(pg_url, sub_result, dataset_name))
+ process.append(p)
+ p.start()
+
+ for p in process:
+ p.join()
+ print('数据插入完成')
+
+ @staticmethod
+ def pull_oe_compatibility_overall_unit(pg_url, dataset_name):
+ url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/hardwarecomp/findAll'
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "pages": {
+ "page": 1,
+ "size": 10
+ },
+ "architecture": "",
+ "keyword": "",
+ "cpu": "",
+ "os": "",
+ "testOrganization": "",
+ "type": "",
+ "cardType": "",
+ "lang": "zh",
+ "dataSource": "assessment",
+ "solution": "",
+ "certificationType": ""
+ }
+
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results = []
+ total_num = response.json()["result"]["totalCount"]
+ for i in range(total_num // 10 + (total_num % 10 != 0)):
+ data = {
+ "pages": {
+ "page": i + 1,
+ "size": 10
+ },
+ "architecture": "",
+ "keyword": "",
+ "cpu": "",
+ "os": "",
+ "testOrganization": "",
+ "type": "",
+ "cardType": "",
+ "lang": "zh",
+ "dataSource": "assessment",
+ "solution": "",
+ "certificationType": ""
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results += response.json()["result"]["hardwareCompList"]
+
+ PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
+ @staticmethod
+ def pull_oe_compatibility_card(pg_url, dataset_name):
+ url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/drivercomp/findAll'
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "pages": {
+ "page": 1,
+ "size": 10
+ },
+ "architecture": "",
+ "keyword": "",
+ "cpu": "",
+ "os": "",
+ "testOrganization": "",
+ "type": "",
+ "cardType": "",
+ "lang": "zh",
+ "dataSource": "assessment",
+ "solution": "",
+ "certificationType": ""
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results = []
+ total_num = response.json()["result"]["totalCount"]
+ for i in range(total_num // 10 + (total_num % 10 != 0)):
+ data = {
+ "pages": {
+ "page": i + 1,
+ "size": 10
+ },
+ "architecture": "",
+ "keyword": "",
+ "cpu": "",
+ "os": "",
+ "testOrganization": "",
+ "type": "",
+ "cardType": "",
+ "lang": "zh",
+ "dataSource": "",
+ "solution": "",
+ "certificationType": ""
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results += response.json()["result"]["driverCompList"]
+ PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
+
+ @staticmethod
+ def pull_oe_compatibility_commercial_software(pg_url, dataset_name):
+ url = 'https://www.openeuler.org/api-certification/server/certification/software/communityChecklist'
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "pageSize": 1,
+ "pageNo": 1,
+ "testOrganization": "",
+ "osName": "",
+ "keyword": "",
+ "dataSource": [
+ "assessment"
+ ],
+ "productType": [
+ "软件"
+ ]
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results = []
+ total_num = response.json()["result"]["totalNum"]
+ for i in range(total_num // 10 + (total_num % 10 != 0)):
+ data = {
+ "pageSize": i + 1,
+ "pageNo": 10,
+ "testOrganization": "",
+ "osName": "",
+ "keyword": "",
+ "dataSource": [
+ "assessment"
+ ],
+ "productType": [
+ "软件"
+ ]
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results += response.json()["result"]["data"]
+ PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
+ @staticmethod
+ def pull_oe_compatibility_open_source_software(pg_url, dataset_name):
+ url = 'https://www.openeuler.org/compatibility/api/web_backend/compat_software_info'
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "page_size": 1,
+ "page_num": 1
+ }
+ response = requests.get(
+ url,
+ headers=headers,
+ data=data
+ )
+ results = []
+ total_num = response.json()["total"]
+ for i in range(total_num // 10 + (total_num % 10 != 0)):
+ data = {
+ "page_size": i + 1,
+ "page_num": 10
+ }
+ response = requests.get(
+ url,
+ headers=headers,
+ data=data
+ )
+ results += response.json()["info"]
+ PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
+ @staticmethod
+ def pull_oe_compatibility_oepkgs(pg_url, oepkgs_info, dataset_name):
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ params = {'user': oepkgs_info.get('oepkgs_user', ''),
+ 'password': oepkgs_info.get('oepkgs_pwd', ''),
+ 'expireInSeconds': 36000
+ }
+ url = 'https://search.oepkgs.net/api/search/openEuler/genToken'
+ response = requests.post(url, headers=headers, params=params)
+ token = response.json()['data']['Authorization']
+ headers = {
+ 'Content-Type': 'application/json',
+ 'Authorization': token
+ }
+ url = 'https://search.oepkgs.net/api/search/openEuler/scroll?scrollId='
+ response = requests.get(url, headers=headers)
+ data = response.json()['data']
+ scrollId = data['scrollId']
+ totalHits = data['totalHits']
+ data_list = data['list']
+ results = data_list
+ totalHits -= 1000
+
+ while totalHits > 0:
+ url = 'https://search.oepkgs.net/api/search/openEuler/scroll?scrollId=' + scrollId
+ response = requests.get(url, headers=headers)
+ data = response.json()['data']
+ data_list = data['list']
+ results += data_list
+ totalHits -= 1000
+
+ PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
+ @staticmethod
+ def pull_oe_compatibility_solution(pg_url, dataset_name):
+ url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/solutioncomp/findAll'
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "pages": {
+ "page": 1,
+ "size": 1
+ },
+ "architecture": "",
+ "keyword": "",
+ "cpu": "",
+ "os": "",
+ "testOrganization": "",
+ "type": "",
+ "cardType": "",
+ "lang": "zh",
+ "dataSource": "assessment",
+ "solution": "",
+ "certificationType": ""
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results = []
+ total_num = response.json()["result"]["totalCount"]
+ for i in range(total_num // 10 + (total_num % 10 != 0)):
+ data = {
+ "pages": {
+ "page": i + 1,
+ "size": total_num
+ },
+ "architecture": "",
+ "keyword": "",
+ "cpu": "",
+ "os": "",
+ "testOrganization": "",
+ "type": "",
+ "cardType": "",
+ "lang": "zh",
+ "dataSource": "assessment",
+ "solution": "",
+ "certificationType": ""
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results += response.json()["result"]["solutionCompList"]
+ PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
+ @staticmethod
+ def pull_oe_compatibility_osv(pg_url, dataset_name):
+ url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/osv/findAll'
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "pages": {
+ "page": 1,
+ "size": 10
+ },
+ "keyword": "",
+ "type": "",
+ "osvName": ""
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results = []
+ total_num = response.json()["result"]["totalCount"]
+ for i in range(total_num // 10 + (total_num % 10 != 0)):
+ data = {
+ "pages": {
+ "page": i + 1,
+ "size": 10
+ },
+ "keyword": "",
+ "type": "",
+ "osvName": ""
+ }
+ response = requests.post(
+ url,
+ headers=headers,
+ json=data
+ )
+ results += response.json()["result"]["osvList"]
+ for i in range(len(results)):
+ results[i]['details'] = 'https://www.openeuler.org/zh/approve/approve-info/?id=' + str(results[i]['id'])
+
+ PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
    @staticmethod
    def pull_oe_compatibility_security_notice(pg_url, dataset_name):
        """Download all CVE security notices, then fan each notice out into
        one record per (affected openEuler version, CVE id) pair before
        handing the expanded list to save_result."""
        url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/securitynotice/findAll'
        headers = {
            'Content-Type': 'application/json'
        }
        # Probe request (size 1): only the total count is used from this response.
        data = {
            "pages": {
                "page": 1,
                "size": 1
            },
            "keyword": "",
            "type": [],
            "date": [],
            "affectedProduct": [],
            "affectedComponent": "",
            "noticeType": "cve"
        }
        response = requests.post(
            url,
            headers=headers,
            json=data
        )
        results = []
        total_num = response.json()["result"]["totalCount"]
        # ceil(total_num / 10) pages of 10 notices each.
        for i in range(total_num // 10 + (total_num % 10 != 0)):
            data = {
                "pages": {
                    "page": i + 1,
                    "size": 10
                },
                "keyword": "",
                "type": [],
                "date": [],
                "affectedProduct": [],
                "affectedComponent": "",
                "noticeType": "cve"
            }
            response = requests.post(
                url,
                headers=headers,
                json=data
            )
            results += response.json()["result"]["securityNoticeList"]

        # Fan-out: 'affectedProduct' and 'cveId' are ';'-separated lists, so
        # emit one deep-copied record per non-empty (version, cve) pair.
        new_results = []
        for i in range(len(results)):
            openeuler_version_list = results[i]['affectedProduct'].split(';')
            cve_id_list = results[i]['cveId'].split(';')
            cnt = 0
            for openeuler_version in openeuler_version_list:
                if len(openeuler_version) > 0:
                    for cve_id in cve_id_list:
                        if len(cve_id) > 0:
                            tmp_dict = copy.deepcopy(results[i])
                            del tmp_dict['affectedProduct']
                            # Suffix a per-notice counter onto the original id
                            # to keep the expanded rows' ids distinct.
                            tmp_dict['id'] = int(str(tmp_dict['id']) + str(cnt))
                            tmp_dict['openeuler_version'] = openeuler_version
                            tmp_dict['cveId'] = cve_id
                            new_results.append(tmp_dict)
                            # Mutating after append is fine: the list holds a
                            # reference to this same dict.
                            tmp_dict[
                                'details'] = 'https://www.openeuler.org/zh/security/security-bulletins/detail/?id=' + \
                                tmp_dict['securityNoticeNo']
                            cnt += 1
        results = new_results
        PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
    @staticmethod
    def pull_oe_compatibility_cve_database(pg_url, dataset_name):
        """Download the full CVE database, enrich each CVE from the
        getCVEProductPackageList endpoint, fan records out per package entry,
        and hand the expanded list to save_result."""
        url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/cvedatabase/findAll'
        headers = {
            'Content-Type': 'application/json'
        }
        # Probe request (size 1): only the total count is used from this response.
        data = {
            "pages": {
                "page": 1,
                "size": 1
            },
            "keyword": "",
            "status": "",
            "year": "",
            "score": "",
            "noticeType": "cve"
        }
        response = requests.post(
            url,
            headers=headers,
            json=data
        )
        js = response.json()  # NOTE(review): unused; kept as-is
        results = []
        total_num = response.json()["result"]["totalCount"]
        # ceil(total_num / 10) pages of 10 CVEs each.
        for i in range(total_num // 10 + (total_num % 10 != 0)):
            data = {
                "pages": {
                    "page": i + 1,
                    "size": 10
                },
                "keyword": "",
                "status": "",
                "year": "",
                "score": "",
                "noticeType": "cve"
            }
            response = requests.post(
                url,
                headers=headers,
                json=data
            )
            results += response.json()["result"]["cveDatabaseList"]

        data_cnt = 0
        new_results = []
        for i in range(len(results)):
            id1 = results[i]['cveId']  # NOTE(review): unused; kept as-is
            results[i]['details'] = 'https://www.openeuler.org/zh/security/cve/detail/?cveId=' + \
                results[i]['cveId'] + '&packageName=' + results[i]['packageName']
            url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/cvedatabase/getCVEProductPackageList?cveId=' + \
                results[i]['cveId'] + '&packageName=' + results[i]['packageName']
            try:
                # First call: merge any missing/blank top-level fields into the record.
                response = requests.get(
                    url,
                    headers=headers
                )
                for key in response.json()["result"].keys():
                    if key not in results[i].keys() or results[i][key] == '':
                        results[i][key] = response.json()["result"][key]
                # Second call to the same URL; its payload is consumed below
                # as a list. NOTE(review): the same endpoint is treated as a
                # dict above and a list here — confirm the API's actual shape.
                url = 'https://www.openeuler.org/api-cve/cve-security-notice-server/cvedatabase/getCVEProductPackageList?cveId=' + \
                    results[i]['cveId'] + '&packageName=' + results[i]['packageName']
                response = requests.get(
                    url,
                    headers=headers
                )
            except:
                # Best-effort enrichment: network/shape errors are ignored.
                pass
            try:
                # Fan-out: one output row per package entry returned for this CVE.
                for j in range(len(response.json()["result"])):
                    tmp = copy.deepcopy(results[i])
                    for key in response.json()["result"][j].keys():
                        if key not in tmp.keys() or tmp[key] == '':
                            tmp[key] = response.json()["result"][j][key]
                    try:
                        # Derive announcementTime (YYYY-MM-DD) from updateTime.
                        date_str = tmp['updateTime']
                        date_format = "%Y-%m-%d %H:%M:%S"
                        dt = datetime.strptime(date_str, date_format)
                        year_month_day = dt.strftime("%Y-%m-%d")
                        tmp['announcementTime'] = year_month_day
                    except:
                        pass
                    # Assign a fresh sequential id to each fanned-out row.
                    tmp['id'] = str(data_cnt)
                    data_cnt += 1
                    new_results.append(tmp)
            except:
                pass
        results = new_results

        PullMessageFromOeWeb.save_result(pg_url, results, dataset_name)
+
+ @staticmethod
+ def pull_oe_openeuler_sig(pg_url):
+ sig_url_base = 'https://www.openeuler.org/api-dsapi/query/sig/info'
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ "community": "openeuler",
+ "page": 1,
+ "pageSize": 12,
+ "search": "fuzzy"
+ }
+ sig_url = sig_url_base + '?' + urllib.parse.urlencode(data)
+ response = requests.post(
+ sig_url,
+ headers=headers,
+ json=data
+ )
+ results_all = []
+ total_num = response.json()["data"][0]["total"]
+ for i in range(total_num // 12 + (total_num % 12 != 0)):
+ data = {
+ "community": "openeuler",
+ "page": i + 1,
+ "pageSize": 12,
+ "search": "fuzzy"
+ }
+ sig_url = sig_url_base + '?' + urllib.parse.urlencode(data)
+ response = requests.post(
+ sig_url,
+ headers=headers,
+ json=data
+ )
+ results_all += response.json()["data"][0]["data"]
+
+ results_members = []
+ results_repos = []
+ for i in range(len(results_all)):
+ temp_members = []
+ sig_name = results_all[i]['sig_name']
+ repos_url_base = 'https://www.openeuler.org/api-dsapi/query/sig/repo/committers'
+ repos_url = repos_url_base + '?community=openeuler&sig=' + sig_name
+ try:
+ response = requests.get(repos_url, timeout=2)
+ except Exception as e:
+ k = 0
+ while response.status_code != 200 and k < 5:
+ k = k + 1
+ try:
+ response = requests.get(repos_url, timeout=2)
+ except Exception as e:
+ print(e)
+ # continue
+ results = response.json()["data"]
+ committers = results['committers']
+ maintainers = results['maintainers']
+ if committers is None:
+ committers = []
+ if maintainers is None:
+ maintainers = []
+ results_all[i]['committers'] = committers
+ results_all[i]['maintainers'] = maintainers
+ for key in results_all[i]:
+ if key == "committer_info" or key == "maintainer_info": # solve with members
+ if results_all[i][key] is not None:
+ for member in results_all[i][key]:
+ member['sig_name'] = sig_name
+ temp_members.append(member)
+ elif key == "repos": # solve with repos
+ if results_all[i][key] is not None:
+ for repo in results_all[i][key]:
+ results_repos.append({'repo': repo, 'url': 'https://gitee.com/' + repo,
+ 'sig_name': sig_name,
+ 'committers': committers,
+ 'maintainers': maintainers})
+ else: # solve others
+ if isinstance(results_all[i][key], (dict, list, tuple)):
+ results_all[i][key] = json.dumps(results_all[i][key])
+
+ results_members += temp_members
+
+ temp_results_members = []
+ seen = set()
+ for d in results_members:
+ if d['gitee_id'] not in seen:
+ seen.add(d['gitee_id'])
+ temp_results_members.append(d)
+ results_members = temp_results_members
+
+ PullMessageFromOeWeb.save_result(pg_url, results_all, 'sig_group')
+ PullMessageFromOeWeb.save_result(pg_url, results_members, 'sig_members')
+ PullMessageFromOeWeb.save_result(pg_url, results_repos, 'sig_repos')
+
+ OeMessageManager.clear_oe_sig_group_to_repos(pg_url)
+ for i in range(len(results_repos)):
+ repo_name = results_repos[i]['repo']
+ group_name = results_repos[i]['sig_name']
+ OeMessageManager.add_oe_sig_group_to_repos(pg_url, group_name, repo_name)
+
+ OeMessageManager.clear_oe_sig_group_to_members(pg_url)
+ for i in range(len(results_all)):
+ committers = set(json.loads(results_all[i]['committers']))
+ maintainers = set(json.loads(results_all[i]['maintainers']))
+ group_name = results_all[i]['sig_name']
+
+ all_members = committers.union(maintainers)
+
+ for member_name in all_members:
+ if member_name in committers and member_name in maintainers:
+ role = 'committer & maintainer'
+ elif member_name in committers:
+ role = 'committer'
+ elif member_name in maintainers:
+ role = 'maintainer'
+
+ OeMessageManager.add_oe_sig_group_to_members(pg_url, group_name, member_name, role=role)
+
+ OeMessageManager.clear_oe_sig_repos_to_members(pg_url)
+ for i in range(len(results_repos)):
+ repo_name = results_repos[i]['repo']
+ committers = set(results_repos[i]['committers'])
+ maintainers = set(results_repos[i]['maintainers'])
+
+ all_members = committers.union(maintainers)
+
+ for member_name in all_members:
+ if member_name in committers and member_name in maintainers:
+ role = 'committer & maintainer'
+ elif member_name in committers:
+ role = 'committer'
+ elif member_name in maintainers:
+ role = 'maintainer'
+
+ OeMessageManager.add_oe_sig_repos_to_members(pg_url, repo_name, member_name, role=role)
+
+ @staticmethod
+ def oe_organize_message_handler(pg_url, oe_spider_method):
+ f = open('./docs/organize.txt', 'r', encoding='utf-8')
+ lines = f.readlines()
+ st = 0
+ en = 0
+ filed_list = ['role', 'name', 'personal_message']
+ OeMessageManager.clear_oe_community_organization_structure(pg_url)
+ while st < len(lines):
+ while en < len(lines) and lines[en] != '\n':
+ en += 1
+ lines[st] = lines[st].replace('\n', '')
+ committee_name = lines[st]
+ for i in range(st + 1, en):
+ data = lines[i].replace('\n', '').split(' ')
+ info = {'committee_name': committee_name}
+ for j in range(min(len(filed_list), len(data))):
+ data[j] = data[j].replace('\n', '')
+ if j == 2:
+ data[j] = data[j].replace('_', ' ')
+ info[filed_list[j]] = data[j]
+ OeMessageManager.add_oe_community_organization_structure(pg_url, info)
+ st = en + 1
+ en = st
+
+ @staticmethod
+ def oe_openeuler_version_message_handler(pg_url, oe_spider_method):
+ f = open('./docs/openeuler_version.txt', 'r', encoding='utf-8')
+ lines = f.readlines()
+ filed_list = ['openeuler_version', 'kernel_version', 'publish_time', 'version_type']
+ OeMessageManager.clear_oe_community_openEuler_version(pg_url)
+ for line in lines:
+ tmp_list = line.replace('\n', '').split(' ')
+ tmp_list[2] = datetime.strptime(tmp_list[2], "%Y-%m-%d")
+ tmp_dict = {}
+ for i in range(len(filed_list)):
+ tmp_dict[filed_list[i]] = tmp_list[i]
+ OeMessageManager.add_oe_community_openEuler_version(pg_url, tmp_dict)
+
+
def work(args):
    """Dispatch the spider script according to ``args['method']``.

    Modes:
        init_pg_info     -- persist postgres connection info to ./config/pg_info.yaml
        init_oepkgs_info -- persist oepkgs credentials to ./config/oepkgs_info.yaml
        update_oe_table  -- crawl the dataset(s) selected by args['oe_spider_method']
                            into the postgres database described by the saved config

    Args:
        args: dict produced by ``vars(init_args())``.
    """
    # Dataset name -> crawler handler.
    func_map = {'card': PullMessageFromOeWeb.pull_oe_compatibility_card,
                'commercial_software': PullMessageFromOeWeb.pull_oe_compatibility_commercial_software,
                'cve_database': PullMessageFromOeWeb.pull_oe_compatibility_cve_database,
                'oepkgs': PullMessageFromOeWeb.pull_oe_compatibility_oepkgs,
                'solution': PullMessageFromOeWeb.pull_oe_compatibility_solution,
                'open_source_software': PullMessageFromOeWeb.pull_oe_compatibility_open_source_software,
                'overall_unit': PullMessageFromOeWeb.pull_oe_compatibility_overall_unit,
                'security_notice': PullMessageFromOeWeb.pull_oe_compatibility_security_notice,
                'osv': PullMessageFromOeWeb.pull_oe_compatibility_osv,
                'openeuler_version_message': PullMessageFromOeWeb.oe_openeuler_version_message_handler,
                'organize_message': PullMessageFromOeWeb.oe_organize_message_handler,
                'openeuler_sig': PullMessageFromOeWeb.pull_oe_openeuler_sig}
    # Dataset name -> human-readable description used in progress messages.
    prompt_map = {'card': 'openEuler支持的板卡信息',
                  'commercial_software': 'openEuler支持的商业软件',
                  'cve_database': 'openEuler的cve信息',
                  'oepkgs': 'openEuler支持的软件包信息',
                  'solution': 'openEuler支持的解决方案',
                  'open_source_software': 'openEuler支持的开源软件信息',
                  'overall_unit': 'openEuler支持的整机信息',
                  'security_notice': 'openEuler官网的安全公告',
                  'osv': 'openEuler相关的osv厂商',
                  'openeuler_version_message': 'openEuler的版本信息',
                  'organize_message': 'openEuler社区成员组织架构',
                  'openeuler_sig': 'openeuler_sig(openEuler SIG组成员信息)'}
    method = args['method']
    pg_host = args['pg_host']
    pg_port = args['pg_port']
    pg_user = args['pg_user']
    pg_pwd = args['pg_pwd']
    pg_database = args['pg_database']
    oepkgs_user = args['oepkgs_user']
    oepkgs_pwd = args['oepkgs_pwd']
    oe_spider_method = args['oe_spider_method']
    if method == 'init_pg_info':
        # Fix: the original checked pg_pwd twice and never checked pg_user.
        if pg_host is None or pg_port is None or pg_user is None \
                or pg_pwd is None or pg_database is None:
            print('请输入完整pg配置信息')
            exit()
        pg_info = {'pg_host': pg_host, 'pg_port': pg_port,
                   'pg_user': pg_user, 'pg_pwd': pg_pwd, 'pg_database': pg_database}
        # makedirs(exist_ok=True) avoids the exists()/mkdir race.
        os.makedirs('./config', exist_ok=True)
        with open('./config/pg_info.yaml', 'w', encoding='utf-8') as f:
            yaml.dump(pg_info, f)
    elif method == 'init_oepkgs_info':
        if oepkgs_user is None or oepkgs_pwd is None:
            print('请输入完整oepkgs配置信息')
            exit()
        oepkgs_info = {'oepkgs_user': oepkgs_user, 'oepkgs_pwd': oepkgs_pwd}
        os.makedirs('./config', exist_ok=True)
        with open('./config/oepkgs_info.yaml', 'w', encoding='utf-8') as f:
            yaml.dump(oepkgs_info, f)
    elif method == 'update_oe_table':
        if not os.path.exists('./config/pg_info.yaml'):
            print('请先配置postgres数据库信息')
            exit()
        if oe_spider_method in ('all', 'oepkgs') and not os.path.exists(
                './config/oepkgs_info.yaml'):
            print('请先配置oepkgs用户信息')
            exit()
        pg_info = {}
        try:
            with open('./config/pg_info.yaml', 'r', encoding='utf-8') as f:
                pg_info = yaml.load(f, Loader=SafeLoader)
        except Exception:
            print('postgres数据库配置文件加载失败')
            exit()
        # str() because YAML may deserialize the port as an int.
        pg_host = pg_info.get('pg_host', '') + ':' + str(pg_info.get('pg_port', ''))
        pg_user = pg_info.get('pg_user', '')
        pg_pwd = pg_info.get('pg_pwd', '')
        pg_database = pg_info.get('pg_database', '')
        pg_url = f'postgresql+psycopg2://{pg_user}:{pg_pwd}@{pg_host}/{pg_database}'
        if oe_spider_method in ('all', 'oepkgs'):
            with open('./config/oepkgs_info.yaml', 'r', encoding='utf-8') as f:
                try:
                    oepkgs_info = yaml.load(f, Loader=SafeLoader)
                except Exception:
                    print('oepkgs配置信息读取失败')
                    exit()
        if oe_spider_method == 'all':
            # Crawl every dataset; a failure in one must not stop the rest.
            for func in func_map:
                print(f'开始爬取{prompt_map[func]}')
                try:
                    if func == 'oepkgs':
                        func_map[func](pg_url, oepkgs_info, func)
                    elif func == 'openeuler_sig':
                        func_map[func](pg_url)
                    else:
                        func_map[func](pg_url, func)
                    print(prompt_map[func] + '入库成功')
                except Exception as e:
                    import traceback
                    # Fix: original printed the function object because the
                    # call parentheses were missing (traceback.format_exc).
                    print(traceback.format_exc())
                    print(f'{prompt_map[func]}入库失败由于:{e}')
        else:
            try:
                print(f'开始爬取{prompt_map[oe_spider_method]}')
                if oe_spider_method == 'oepkgs':
                    func_map[oe_spider_method](pg_url, oepkgs_info, oe_spider_method)
                elif oe_spider_method == 'openeuler_sig':
                    func_map[oe_spider_method](pg_url)
                else:
                    func_map[oe_spider_method](pg_url, oe_spider_method)
                print(prompt_map[oe_spider_method] + '入库成功')
            except Exception as e:
                print(f'{prompt_map[oe_spider_method]}入库失败由于:{e}')
+
+
def init_args():
    """Build and parse the command-line arguments for the spider script.

    Returns:
        argparse.Namespace with ``method``, the pg connection fields, the
        oepkgs credential fields and ``oe_spider_method``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, required=True,
                        choices=['init_pg_info', 'init_oepkgs_info', 'update_oe_table'],
                        help='''脚本使用模式,有初始化pg信息、初始化oepkgs账号信息和爬取openEuler官网数据的功能''')
    # Fix: help texts misspelled "postgres" as "postres", and the oepkgs
    # arguments carried copy-pasted pg descriptions.
    parser.add_argument("--pg_host", default=None, required=False, help="postgres的ip")
    parser.add_argument("--pg_port", default=None, required=False, help="postgres的端口")
    parser.add_argument("--pg_user", default=None, required=False, help="postgres的用户")
    parser.add_argument("--pg_pwd", default=None, required=False, help="postgres的密码")
    parser.add_argument("--pg_database", default=None, required=False, help="postgres的数据库")
    parser.add_argument("--oepkgs_user", default=None, required=False, help="oepkgs的用户名")
    parser.add_argument("--oepkgs_pwd", default=None, required=False, help="oepkgs的密码")
    parser.add_argument(
        "--oe_spider_method", default='all', required=False,
        # Fix: 'cve_database' was listed twice in choices.
        choices=['all', 'card', 'commercial_software', 'cve_database', 'oepkgs', 'solution',
                 'open_source_software', 'overall_unit', 'security_notice', 'osv',
                 'openeuler_version_message', 'organize_message', 'openeuler_sig'],
        help="需要爬取的openEuler数据类型,有all(所有内容),card(openEuler支持的板卡信息),commercial_software(openEuler支持的商业软件),"
             "cve_database(openEuler的cve信息),oepkgs(openEuler支持的软件包信息),solution(openEuler支持的解决方案),open_source_software(openEuler支持的开源软件信息),"
             "overall_unit(openEuler支持的整机信息),security_notice(openEuler官网的安全公告),osv(openEuler相关的osv厂商),"
             "openeuler_version_message(openEuler的版本信息),"
             "organize_message(openEuler社区成员组织架构),openeuler_sig(openEuler SIG组成员信息)")
    args = parser.parse_args()
    return args
+
+
if __name__ == "__main__":
    # Force the 'spawn' start method for worker processes so children start
    # from a fresh interpreter rather than a fork of this one.
    # NOTE(review): presumably multiprocessing.set_start_method — its import
    # is not visible in this chunk; confirm at the top of the file.
    set_start_method('spawn')
    args = init_args()
    # argparse.Namespace -> plain dict, the shape work() expects.
    work(vars(args))