diff --git a/.idea/icon.png b/.idea/icon.png deleted file mode 100644 index 7d9781edbb6a38cf876574b7ae0566b2dd45672d..0000000000000000000000000000000000000000 Binary files a/.idea/icon.png and /dev/null differ diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index 97da2964334ea67d98059c899f8ce63fb42774c0..1c472f2992c84023d6e197d849b694aa2cacdf09 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -51,6 +51,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): def _run(self, manage: PipelineManage): paragraph_list = self.execute(**self.context['step_args']) manage.context['paragraph_list'] = paragraph_list + print("查询结果:",paragraph_list) self.context['paragraph_list'] = paragraph_list @abstractmethod diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index eecfc6a0f16f9216d9d41fbb5875bfd53cc3e3af..e1c633f72ac10cd501396b9773a7d7ff54608020 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -461,10 +461,17 @@ class ApplicationSerializer(serializers.Serializer): self.data.get('similarity'), SearchMode(self.data.get('search_mode')), model) + # 获取段落标识 hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) + # 查询分段信息表 p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) - return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), - 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] + + print("p_list---->", p_list) + print("hit_dict---->", hit_dict) + return [{**p, + 'similarity': hit_dict.get(p.get('id')).get('similarity'), + 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score') + } for p in p_list] class Query(serializers.Serializer): name = serializers.CharField(required=False, error_messages=ErrMessage.char("应用名称")) diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index a6e9ab9aa9b3668c09d43def402abc966037f07c..2fb9062c7ab8e0d8e7d5b4c58e592859d7a727a3 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -48,19 +48,31 @@ class ModelManage: class VectorStore: - from embedding.vector.pg_vector import PGVector from embedding.vector.base_vector import BaseVectorStore - instance_map = { - 'pg_vector': PGVector, - } instance = None @staticmethod def get_embedding_vector() -> BaseVectorStore: - from embedding.vector.pg_vector import PGVector + if VectorStore.instance is None: from smartdoc.const import CONFIG - vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"), - PGVector) - VectorStore.instance = vector_store_class() + store_name = CONFIG.get("VECTOR_STORE_NAME") + if store_name == 'es_vector': + from embedding.vector.es_vector import ESVector + from elasticsearch import Elasticsearch + + es_host = CONFIG.get("VECTOR_STORE_ES_ENDPOINT") + es_username = CONFIG.get("VECTOR_STORE_ES_USERNAME") + es_password = CONFIG.get("VECTOR_STORE_ES_PASSWORD") + es_indexname = CONFIG.get("VECTOR_STORE_ES_INDEX") + http_auth = (es_username, es_password) + VectorStore.instance = ESVector(Elasticsearch(hosts=[es_host], + http_auth=http_auth) + , es_indexname) + if not VectorStore.instance.vector_is_create(): + VectorStore.instance.vector_create() + else: + from embedding.vector.pg_vector import PGVector + VectorStore.instance = PGVector() + return VectorStore.instance diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 537677ba2cbd6ed7a3e9cf8f32129e44d2bd941d..6b8d4ba4c218184f4edf4c3f0dc2a4fa9ae3768e 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -127,7 +127,6 @@ class ListenerManagement: def is_save_function(): return QuerySet(Paragraph).filter(id=paragraph_id).exists() - # 批量向量化 VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) except Exception as e: diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py index 5f954e36b6e14d0732b27b8e74060cd96c993d0d..c54144239b9e3527dc2843ad536d67556539a4e5 100644 --- a/apps/embedding/models/embedding.py +++ b/apps/embedding/models/embedding.py @@ -48,5 +48,19 @@ class Embedding(models.Model): meta = models.JSONField(verbose_name="元数据", default=dict) + # 将对象转换成es可用的json对象 + def to_dict(self): + return { + "id": str(self.id), + "document_id": self.document_id, + "paragraph_id": self.paragraph_id, + "dataset_id": self.dataset_id, + "is_active": self.is_active, + "source_id": self.source_id, + "source_type": self.source_type, + "embedding": self.embedding, + "search_vector": self.search_vector, + } + class Meta: db_table = "embedding" diff --git a/apps/embedding/vector/es_vector.py b/apps/embedding/vector/es_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..17cab8dee4b9cd0b2bfc82f8230830cf8880c1b6 --- /dev/null +++ b/apps/embedding/vector/es_vector.py @@ -0,0 +1,324 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:ilxch@163.com + @file: es_vector.py + @date:2024/09/04 15:28 + @desc: ES数据库操作API +""" +import uuid +from abc import ABC, abstractmethod +from typing import Dict, List +from langchain_core.embeddings import Embeddings +from common.util.ts_vecto_util import to_ts_vector, to_query +from embedding.models import Embedding, SourceType, SearchMode +from embedding.vector.base_vector import BaseVectorStore +from elasticsearch import helpers + + +class ESVector(BaseVectorStore): + def __init__(self, es_connection, index_name): + self.es = es_connection + self.index_name = index_name + + def search(self, body: Dict, size: int): + return self.es.search(index=self.index_name, body=body, size=size) + + def searchembed(self, query: Dict, vector: List[float], size, similarity): + return self.es.search(index=self.index_name, body={ + "min_score": similarity, + "size": size, + "query": query, + "knn": { + "field": "embedding", + "query_vector": vector, + "k": size, + "num_candidates": 2 * size # 默认为2倍size + }, + "_source": ["paragraph_id"] # 添加这一行 + }) + + def delete_by_query(self, query): + return self.es.search(index=self.index_name, body=query) + + def bulk(self, documents: List[Dict]): + actions = [ + { + "_index": self.index_name, + "_id": document.get('id'), + "_source": document + } + for document in documents + ] + return helpers.bulk(self.es, actions) + + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + if len(source_ids) == 0: + return + return self.delete_by_query(query={ + "query": { + "terms": { + "_id": source_ids + } + } + }) + + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + return self.bulk(source_ids, instance) + + def vector_is_create(self) -> bool: + return self.es.indices.exists(index=self.index_name) + + def vector_create(self): + return self.es.indices.create(index=self.index_name, ignore=400) + + def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + text_embedding = embedding.embed_query(text) + embedding = Embedding(id=uuid.uuid1(), + dataset_id=dataset_id, + document_id=document_id, + is_active=is_active, + paragraph_id=paragraph_id, + source_id=source_id, + embedding=text_embedding, + source_type=source_type, + search_vector=to_ts_vector(text)) + res = self.es.index(index=self.index_name, body=embedding.to_dict()) + print("res---->" + str(res)) + return True + + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): + texts = [row.get('text') for row in text_list] + embeddings = embedding.embed_documents(texts) + embedding_list = [Embedding(id=uuid.uuid1(), + document_id=text_list[index].get('document_id'), + paragraph_id=text_list[index].get('paragraph_id'), + dataset_id=text_list[index].get('dataset_id'), + is_active=text_list[index].get('is_active', True), + source_id=text_list[index].get('source_id'), + source_type=text_list[index].get('source_type'), + embedding=embeddings[index], + search_vector=to_ts_vector(text_list[index]['text'])) for index in + range(0, len(texts))] + if is_save_function(): + embeddings = [embedding.to_dict() for embedding in embedding_list] + self.bulk(embeddings) + return True + + def update_by_source_id(self, source_id: str, instance: Dict): + self.es.update(index=self.index_name, id=source_id, body={"doc": instance}) + + def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): + self.es.update(index=self.index_name, id=paragraph_id, body={"doc": instance}) + + def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict): + for paragraph_id in paragraph_ids: + self.es.update(index=self.index_name, id=paragraph_id, body={"doc": instance}) + + def delete_by_dataset_id(self, dataset_id: str): + self.es.delete_by_query(index=self.index_name, body={"query": {"match": {"dataset_id": dataset_id}}}) + + def delete_by_dataset_id_list(self, dataset_id_list: List[str]): + self.delete_by_query(body={"query": {"terms": {"dataset_id": dataset_id_list}}}) + + def delete_by_document_id(self, document_id: str): + self.delete_by_query(query={ + "query": { + "match": { + "document_id.keyword": document_id + } + } + }) + return True + + def delete_by_document_id_list(self, document_id_list: List[str]): + if len(document_id_list) == 0: + return True + self.delete_by_query(query={ + "query": { + "match": { + "document_id.keyword": document_id_list + } + } + }) + return True + + def delete_by_source_id(self, source_id: str, source_type: str): + self.delete_by_query(query={"query": { + "bool": {"must": [{"match": {"source_id.keyword": source_id}}, {"match": {"source_type": source_type}}]} + }}) + return True + + def delete_by_paragraph_id(self, paragraph_id: str): + self.delete_by_query(query={ + "query": { + "match": { + "paragraph_id.keyword": paragraph_id + } + } + }) + + def delete_by_paragraph_ids(self, paragraph_ids: List[str]): + self.delete_by_query(query={ + "query": { + "terms": { + "paragraph_id.keyword": paragraph_ids + } + } + }) + + def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + if dataset_id_list is None or len(dataset_id_list) == 0: + return [] + embedding_query = embedding.embed_query(query_text) + filter_query = { + "bool": { + "filter": [ + {"term": {"is_active": True}}, + {"terms": {"dataset_id.keyword": [str(uuid) for uuid in dataset_id_list]}}, + {"bool": {"must_not": [ + {"terms": {"document_id.keyword": [str(uuid) for uuid in exclude_document_id_list]}}]}} + ] + } + } + for search_handle in search_handle_list: + if search_handle.support(search_mode, self): + return search_handle.handle(filter_query, query_text, embedding_query, top_number, similarity, + search_mode) + + def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + search_mode: SearchMode): + if dataset_id_list is None or len(dataset_id_list) == 0: + return [] + filter_query = { + "bool": { + "filter": [ + {"term": {"is_active": "true"}}, + {"term": {"dataset_id.keyword": [str(uuid) for uuid in dataset_id_list]}}, + {"bool": {"must_not": [ + {"terms": {"document_id.keyword": [str(uuid) for uuid in exclude_document_id_list]}}]}}, + {"bool": {"must_not": [ + {"terms": {"paragraph_id.keyword": [str(uuid) for uuid in exclude_paragraph_list]}}]}} + ] + } + } + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(filter_query, query_text, query_embedding, top_n, similarity, search_mode) + + +class ISearch(ABC): + @abstractmethod + def support(self, search_mode: SearchMode, esVector: ESVector): + pass + + @abstractmethod + def handle(self, filter_query: Dict, query_text: str, query_embedding: List[float], top_number: int, + similarity: float, search_mode: SearchMode): + pass + + +# 向量检索 +class EmbeddingSearch(ISearch): + def __init__(self): + self.esVector = None + + def handle(self, + filter_query, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + result = self.esVector.searchembed(query=filter_query, vector=query_embedding, size=top_number, + similarity=similarity) + hits = result['hits']['hits'] + return [{'paragraph_id': uuid.UUID(hit['_source']['paragraph_id']), + 'similarity': hit['_score'], + 'comprehensive_score': hit['_source'].get('_score', 0)} + for hit in hits] + + def support(self, search_mode: SearchMode, esVector: ESVector): + self.esVector = esVector + return search_mode.value == SearchMode.embedding.value + + +# 关键词检索 +class KeywordsSearch(ISearch): + def __init__(self): + self.esVector = None + + def handle(self, + filter_query, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + filter_query['bool']['must'] = {"multi_match": { + "query": to_query(query_text), + "fields": ["search_vector"] + }} + result = self.esVector.search(body={ + "query": filter_query + }, size=top_number) + hits = [hit for hit in result['hits']['hits'] if hit['_score'] > similarity] + return [{'paragraph_id': uuid.UUID(hit['_source']['paragraph_id']), + 'similarity': hit['_score'], + 'comprehensive_score': hit['_source'].get('_score', 0)} + for hit in hits] + + def support(self, search_mode: SearchMode, esVector: ESVector): + self.esVector = esVector + return search_mode.value == SearchMode.keywords.value + + +# 混合检索 +class BlendSearch(ISearch): + def __init__(self): + self.esVector = None + + def handle(self, + filter_query, + query_text, + query_embedding, + top_number: int, + similarity: float, + search_mode: SearchMode): + knn_filter = filter_query + filter_query['bool']['must'] = {"multi_match": { + "query": to_query(query_text), + "fields": ["search_vector"] + }} + result = self.esVector.search(body={ + "min_score": similarity, + "query": filter_query, + "knn": { + "field": "embedding", + "query_vector": query_embedding, + "k": top_number, + "num_candidates": 2 * top_number, # 默认为2倍size + "filter": knn_filter + }, + "_source": ["paragraph_id"] # 添加这一行 + }, size=top_number) + + hits = [hit for hit in result['hits']['hits'] if hit['_score'] > similarity] + + return [{'paragraph_id': uuid.UUID(hit['_source']['paragraph_id']), + 'similarity': hit['_score'], + 'comprehensive_score': hit['_source'].get('_score', 0)} + for hit in hits] + + def support(self, search_mode: SearchMode, esVector: ESVector): + self.esVector = esVector + return search_mode.value == SearchMode.blend.value + +search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()] diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 9171e036cfcb92a724904df14ef7c7b29973cfd2..6b33139838fd21fbe0fc27ef7695e1513b7a2a0c 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -139,9 +139,6 @@ class BaseModelCredential(ABC): class ModelTypeConst(Enum): LLM = {'code': 'LLM', 'message': '大语言模型'} EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'} - STT = {'code': 'STT', 'message': '语音识别'} - TTS = {'code': 'TTS', 'message': '语音合成'} - RERANKER = {'code': 'RERANKER', 'message': '重排模型'} class ModelInfo: diff --git a/apps/setting/models_provider/impl/base_stt.py b/apps/setting/models_provider/impl/base_stt.py deleted file mode 100644 index aae72a559ebf1409855d3a4a833a78e2cc50efb1..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/base_stt.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 -from abc import abstractmethod - -from pydantic import BaseModel - - -class BaseSpeechToText(BaseModel): - @abstractmethod - def check_auth(self): - pass - - @abstractmethod - def speech_to_text(self, audio_file): - pass diff --git a/apps/setting/models_provider/impl/base_tts.py b/apps/setting/models_provider/impl/base_tts.py deleted file mode 100644 index 6311f26865391ed174dbb5f2147009596735e169..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/base_tts.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 -from abc import abstractmethod - -from pydantic import BaseModel - - -class BaseTextToSpeech(BaseModel): - @abstractmethod - def check_auth(self): - pass - - @abstractmethod - def text_to_speech(self, text): - pass diff --git a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py index 38cf1aeb98f3856fba83b7dc9e0d5b414f0529b5..b2c60f40bca62e5fbd279530792c224cc6a6b46f 100644 --- a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py @@ -33,6 +33,7 @@ class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/embed_query', {'text': text}) result = res.json() + #print("查询内容:",result) if result.get('code', 500) == 200: return result.get('data') raise Exception(result.get('msg')) diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py b/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py deleted file mode 100644 index 59506311c21e0b9d402a2b942b8d1ad02482b873..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py +++ /dev/null @@ -1,42 +0,0 @@ -# coding=utf-8 -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class OpenAISTTModelCredential(BaseForm, BaseModelCredential): - api_base = forms.TextInputField('API 域名', required=True) - api_key = forms.PasswordInputField('API Key', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['api_base', 'api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'api_key': super().encryption(model.get('api_key', ''))} - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/stt.py b/apps/setting/models_provider/impl/openai_model_provider/model/stt.py deleted file mode 100644 index 66b2daeda0c82b4e12f6bd8f6b170774c36f54fc..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/openai_model_provider/model/stt.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Dict - -from openai import OpenAI - -from common.config.tokenizer_manage_config import TokenizerManage -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_stt import BaseSpeechToText - - -def custom_get_token_ids(text: str): - tokenizer = TokenizerManage.get_tokenizer() - return tokenizer.encode(text) - - -class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText): - api_base: str - api_key: str - model: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.api_key = kwargs.get('api_key') - self.api_base = kwargs.get('api_base') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return OpenAISpeechToText( - model=model_name, - api_base=model_credential.get('api_base'), - api_key=model_credential.get('api_key'), - **optional_params, - ) - - def check_auth(self): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - response_list = client.models.with_raw_response.list() - # print(response_list) - - def speech_to_text(self, audio_file): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - res = client.audio.transcriptions.create(model=self.model, language="zh", file=audio_file) - return res.text diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/tts.py b/apps/setting/models_provider/impl/openai_model_provider/model/tts.py deleted file mode 100644 index c0975484012da60ad66e6f45b9acaecfabcc3888..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/openai_model_provider/model/tts.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Dict - -from openai import OpenAI - -from common.config.tokenizer_manage_config import TokenizerManage -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_tts import BaseTextToSpeech - - -def custom_get_token_ids(text: str): - tokenizer = TokenizerManage.get_tokenizer() - return tokenizer.encode(text) - - -class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech): - api_base: str - api_key: str - model: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.api_key = kwargs.get('api_key') - self.api_base = kwargs.get('api_base') - self.model = kwargs.get('model') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return OpenAITextToSpeech( - model=model_name, - api_base=model_credential.get('api_base'), - api_key=model_credential.get('api_key'), - **optional_params, - ) - - def check_auth(self): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - response_list = client.models.with_raw_response.list() - # print(response_list) - - def text_to_speech(self, text): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - with client.audio.speech.with_streaming_response.create( - model=self.model, - voice="alloy", - input=text, - ) as response: - return response.read() diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index cbec2d504e4fff0e5535f82465e56ba1c34a807b..fb4c89d7b846d4eef411851f1d66c672d2dc51b4 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -13,15 +13,11 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro ModelTypeConst, ModelInfoManage from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential -from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel -from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText -from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech from smartdoc.conf import PROJECT_DIR openai_llm_model_credential = OpenAILLMModelCredential() -openai_stt_model_credential = OpenAISTTModelCredential() model_info_list = [ ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel @@ -62,13 +58,7 @@ model_info_list = [ OpenAIChatModel), ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens', ModelTypeConst.LLM, openai_llm_model_credential, - OpenAIChatModel), - ModelInfo('whisper-1', '', - ModelTypeConst.STT, openai_stt_model_credential, - OpenAISpeechToText), - ModelInfo('tts-1', '', - ModelTypeConst.TTS, openai_stt_model_credential, - OpenAITextToSpeech) + OpenAIChatModel) ] open_ai_embedding_credential = OpenAIEmbeddingCredential() model_info_embedding_list = [ diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py deleted file mode 100644 index 94e656162d66c39d79d08dda031977150cb3979f..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py +++ /dev/null @@ -1,45 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): - volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v2/asr') - volcanic_app_id = forms.TextInputField('App ID', required=True) - volcanic_token = forms.PasswordInputField('Token', required=True) - volcanic_cluster = forms.TextInputField('Cluster', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py deleted file mode 100644 index 353740a428e0e9f9e0115444d06b0f2f6bdfe407..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py +++ /dev/null @@ -1,45 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): - volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v1/tts/ws_binary') - volcanic_app_id = forms.TextInputField('App ID', required=True) - volcanic_token = forms.PasswordInputField('Token', required=True) - volcanic_cluster = forms.TextInputField('Cluster', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py deleted file mode 100644 index 36aed64263c037a1a055b580d1a0b03665095f6d..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py +++ /dev/null @@ -1,345 +0,0 @@ -# coding=utf-8 - -""" -requires Python 3.6 or later - -pip install asyncio -pip install websockets -""" -import asyncio -import base64 -import gzip -import hmac -import json -import uuid -import wave -from enum import Enum -from hashlib import sha256 -from io import BytesIO -from typing import Dict -from urllib.parse import urlparse -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_stt import BaseSpeechToText - -audio_format = "mp3" # wav 或者 mp3,根据实际音频格式设置 - -PROTOCOL_VERSION = 0b0001 -DEFAULT_HEADER_SIZE = 0b0001 - -PROTOCOL_VERSION_BITS = 4 -HEADER_BITS = 4 -MESSAGE_TYPE_BITS = 4 -MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4 -MESSAGE_SERIALIZATION_BITS = 4 -MESSAGE_COMPRESSION_BITS = 4 -RESERVED_BITS = 8 - -# Message Type: -CLIENT_FULL_REQUEST = 0b0001 -CLIENT_AUDIO_ONLY_REQUEST = 0b0010 -SERVER_FULL_RESPONSE = 0b1001 -SERVER_ACK = 0b1011 -SERVER_ERROR_RESPONSE = 0b1111 - -# Message Type Specific Flags -NO_SEQUENCE = 0b0000 # no check sequence -POS_SEQUENCE = 0b0001 -NEG_SEQUENCE = 0b0010 -NEG_SEQUENCE_1 = 0b0011 - -# Message Serialization -NO_SERIALIZATION = 0b0000 -JSON = 0b0001 -THRIFT = 0b0011 -CUSTOM_TYPE = 0b1111 - -# Message Compression -NO_COMPRESSION = 0b0000 -GZIP = 0b0001 -CUSTOM_COMPRESSION = 0b1111 - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -def generate_header( - version=PROTOCOL_VERSION, - message_type=CLIENT_FULL_REQUEST, - message_type_specific_flags=NO_SEQUENCE, - serial_method=JSON, - compression_type=GZIP, - reserved_data=0x00, - extension_header=bytes() -): - """ - protocol_version(4 bits), header_size(4 bits), - message_type(4 bits), message_type_specific_flags(4 bits) - serialization_method(4 bits) message_compression(4 bits) - reserved (8bits) 保留字段 - header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) - """ - header = bytearray() - header_size = int(len(extension_header) / 4) + 1 - header.append((version << 4) | header_size) - header.append((message_type << 4) | message_type_specific_flags) - header.append((serial_method << 4) | compression_type) - header.append(reserved_data) - header.extend(extension_header) - return header - - -def generate_full_default_header(): - return generate_header() - - -def generate_audio_default_header(): - return generate_header( - message_type=CLIENT_AUDIO_ONLY_REQUEST - ) - - -def generate_last_audio_default_header(): - return generate_header( - message_type=CLIENT_AUDIO_ONLY_REQUEST, - message_type_specific_flags=NEG_SEQUENCE - ) - - -def parse_response(res): - """ - protocol_version(4 bits), header_size(4 bits), - message_type(4 bits), message_type_specific_flags(4 bits) - serialization_method(4 bits) message_compression(4 bits) - reserved (8bits) 保留字段 - header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) - payload 类似与http 请求体 - """ - protocol_version = res[0] >> 4 - header_size = res[0] & 0x0f - message_type = res[1] >> 4 - message_type_specific_flags = res[1] & 0x0f - serialization_method = res[2] >> 4 - message_compression = res[2] & 0x0f - reserved = res[3] - header_extensions = res[4:header_size * 4] - payload = res[header_size * 4:] - result = {} - payload_msg = None - payload_size = 0 - if message_type == SERVER_FULL_RESPONSE: - payload_size = int.from_bytes(payload[:4], "big", signed=True) - payload_msg = payload[4:] - elif message_type == SERVER_ACK: - seq = int.from_bytes(payload[:4], "big", signed=True) - result['seq'] = seq - if len(payload) >= 8: - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - payload_msg = payload[8:] - elif message_type == SERVER_ERROR_RESPONSE: - code = int.from_bytes(payload[:4], "big", signed=False) - result['code'] = code - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - payload_msg = payload[8:] - if payload_msg is None: - return result - if message_compression == GZIP: - payload_msg = gzip.decompress(payload_msg) - if serialization_method == JSON: - payload_msg = json.loads(str(payload_msg, "utf-8")) - elif serialization_method != NO_SERIALIZATION: - payload_msg = str(payload_msg, "utf-8") - result['payload_msg'] = payload_msg - result['payload_size'] = payload_size - return result - - -def read_wav_info(data: bytes = None) -> (int, int, int, int, int): - with BytesIO(data) as _f: - wave_fp = wave.open(_f, 'rb') - nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4] - wave_bytes = wave_fp.readframes(nframes) - return nchannels, sampwidth, framerate, nframes, len(wave_bytes) - - -class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): - workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate" - show_language: bool = False - show_utterances: bool = False - result_type: str = "full" - format: str = "mp3" - rate: int = 16000 - language: str = "zh-CN" - bits: int = 16 - channel: int = 1 - codec: str = "raw" - audio_type: int = 1 - secret: str = "access_secret" - auth_method: str = "token" - mp3_seg_size: int = 10000 - success_code: int = 1000 # success code, default is 1000 - seg_duration: int = 15000 - nbest: int = 1 - - volcanic_app_id: str - volcanic_cluster: str - volcanic_api_url: str - volcanic_token: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.volcanic_api_url = kwargs.get('volcanic_api_url') - self.volcanic_token = kwargs.get('volcanic_token') - self.volcanic_app_id = kwargs.get('volcanic_app_id') - self.volcanic_cluster = kwargs.get('volcanic_cluster') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return VolcanicEngineSpeechToText( - volcanic_api_url=model_credential.get('volcanic_api_url'), - volcanic_token=model_credential.get('volcanic_token'), - volcanic_app_id=model_credential.get('volcanic_app_id'), - volcanic_cluster=model_credential.get('volcanic_cluster'), - **optional_params - ) - - def construct_request(self, reqid): - req = { - 'app': { - 'appid': self.volcanic_app_id, - 'cluster': self.volcanic_cluster, - 'token': self.volcanic_token, - }, - 'user': { - 'uid': 'uid' - }, - 'request': { - 'reqid': reqid, - 'nbest': self.nbest, - 'workflow': self.workflow, - 'show_language': self.show_language, - 'show_utterances': self.show_utterances, - 'result_type': self.result_type, - "sequence": 1 - }, - 'audio': { - 'format': self.format, - 'rate': self.rate, - 'language': self.language, - 'bits': self.bits, - 'channel': self.channel, - 'codec': self.codec - } - } - return req - - @staticmethod - def slice_data(data: bytes, chunk_size: int) -> (list, bool): - """ - slice data - :param data: wav data - :param chunk_size: the segment size in one request - :return: segment data, last flag - """ - data_len = len(data) - offset = 0 - while offset + chunk_size < data_len: - yield data[offset: offset + chunk_size], False - offset += chunk_size - else: - yield data[offset: data_len], True - - def _real_processor(self, request_params: dict) -> dict: - pass - - def token_auth(self): - return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)} - - def signature_auth(self, data): - header_dicts = { - 'Custom': 'auth_custom', - } - - url_parse = urlparse(self.volcanic_api_url) - input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path) - auth_headers = 'Custom' - for header in auth_headers.split(','): - input_str += '{}\n'.format(header_dicts[header]) - input_data = bytearray(input_str, 'utf-8') - input_data += data - mac = base64.urlsafe_b64encode( - hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest()) - header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token, - str(mac, 'utf-8'), - auth_headers) - return header_dicts - - async def segment_data_processor(self, wav_data: bytes, segment_size: int): - reqid = str(uuid.uuid4()) - # 构建 full client request,并序列化压缩 - request_params = self.construct_request(reqid) - payload_bytes = str.encode(json.dumps(request_params)) - payload_bytes = gzip.compress(payload_bytes) - full_client_request = bytearray(generate_full_default_header()) - full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - full_client_request.extend(payload_bytes) # payload - header = None - if self.auth_method == "token": - header = self.token_auth() - elif self.auth_method == "signature": - header = self.signature_auth(full_client_request) - async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000, - ssl=ssl_context) as ws: - # 发送 full client request - await ws.send(full_client_request) - res = await ws.recv() - result = parse_response(res) - if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: - return result - for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1): - # if no compression, comment this line - payload_bytes = gzip.compress(chunk) - audio_only_request = bytearray(generate_audio_default_header()) - if last: - audio_only_request = bytearray(generate_last_audio_default_header()) - audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - audio_only_request.extend(payload_bytes) # payload - # 发送 audio-only client request - await ws.send(audio_only_request) - res = await ws.recv() - result = parse_response(res) - if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: - return result - return result['payload_msg']['result'][0]['text'] - - def check_auth(self): - header = self.token_auth() - - async def check(): - async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000, - ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def speech_to_text(self, file): - data = file.read() - audio_data = bytes(data) - if self.format == "mp3": - segment_size = self.mp3_seg_size - return asyncio.run(self.segment_data_processor(audio_data, segment_size)) - if self.format != "wav": - raise Exception("format should in wav or mp3") - nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info( - audio_data) - size_per_sec = nchannels * sampwidth * framerate - segment_size = int(size_per_sec * self.seg_duration / 1000) - return asyncio.run(self.segment_data_processor(audio_data, segment_size)) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py deleted file mode 100644 index 3a5e0afb398cf3bd11ed134e4434058e85e1548b..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py +++ /dev/null @@ -1,170 +0,0 @@ -# coding=utf-8 - -''' -requires Python 3.6 or later - -pip install asyncio -pip install websockets - -''' - -import asyncio -import copy -import gzip -import json -import uuid -from typing import Dict -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_tts import BaseTextToSpeech - -MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"} -MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0", - 2: "last message from server (seq < 0)", 3: "sequence number < 0"} -MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"} -MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"} - -# version: b0001 (4 bits) -# header size: b0001 (4 bits) -# message type: b0001 (Full client request) (4bits) -# message type specific flags: b0000 (none) (4bits) -# message serialization method: b0001 (JSON) (4 bits) -# message compression: b0001 (gzip) (4bits) -# reserved data: 0x00 (1 byte) -default_header = bytearray(b'\x11\x10\x11\x00') - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): - volcanic_app_id: str - volcanic_cluster: str - volcanic_api_url: str - volcanic_token: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.volcanic_api_url = kwargs.get('volcanic_api_url') - self.volcanic_token = kwargs.get('volcanic_token') - self.volcanic_app_id = kwargs.get('volcanic_app_id') - self.volcanic_cluster = kwargs.get('volcanic_cluster') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return VolcanicEngineTextToSpeech( - volcanic_api_url=model_credential.get('volcanic_api_url'), - volcanic_token=model_credential.get('volcanic_token'), - volcanic_app_id=model_credential.get('volcanic_app_id'), - volcanic_cluster=model_credential.get('volcanic_cluster'), - **optional_params - ) - - def check_auth(self): - header = self.token_auth() - - async def check(): - async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None, - ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def text_to_speech(self, text): - request_json = { - "app": { - "appid": self.volcanic_app_id, - "token": "access_token", - "cluster": self.volcanic_cluster - }, - "user": { - "uid": "uid" - }, - "audio": { - "voice_type": "BV002_streaming", - "encoding": "mp3", - "speed_ratio": 1.0, - "volume_ratio": 1.0, - "pitch_ratio": 1.0, - }, - "request": { - "reqid": str(uuid.uuid4()), - "text": text, - "text_type": "plain", - "operation": "xxx" - } - } - - return asyncio.run(self.submit(request_json)) - - def token_auth(self): - return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)} - - async def submit(self, request_json): - submit_request_json = copy.deepcopy(request_json) - submit_request_json["request"]["reqid"] = str(uuid.uuid4()) - submit_request_json["request"]["operation"] = "submit" - payload_bytes = str.encode(json.dumps(submit_request_json)) - payload_bytes = gzip.compress(payload_bytes) # if no compression, comment this line - full_client_request = bytearray(default_header) - full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) - full_client_request.extend(payload_bytes) # payload - header = {"Authorization": f"Bearer; {self.volcanic_token}"} - async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None, - ssl=ssl_context) as ws: - await ws.send(full_client_request) - return await self.parse_response(ws) - - @staticmethod - async def parse_response(ws): - result = b'' - while True: - res = await ws.recv() - protocol_version = res[0] >> 4 - header_size = res[0] & 0x0f - message_type = res[1] >> 4 - message_type_specific_flags = res[1] & 0x0f - serialization_method = res[2] >> 4 - message_compression = res[2] & 0x0f - reserved = res[3] - header_extensions = res[4:header_size * 4] - payload = res[header_size * 4:] - if header_size != 1: - # print(f" Header extensions: {header_extensions}") - pass - if message_type == 0xb: # audio-only server response - if message_type_specific_flags == 0: # no sequence number as ACK - continue - else: - sequence_number = int.from_bytes(payload[:4], "big", signed=True) - payload_size = int.from_bytes(payload[4:8], "big", signed=False) - payload = payload[8:] - result += payload - if sequence_number < 0: - break - else: - continue - elif message_type == 0xf: - code = int.from_bytes(payload[:4], "big", signed=False) - msg_size = int.from_bytes(payload[4:8], "big", signed=False) - error_msg = payload[8:] - if message_compression == 1: - error_msg = gzip.decompress(error_msg) - error_msg = str(error_msg, "utf-8") - break - elif message_type == 0xc: - msg_size = int.from_bytes(payload[:4], "big", signed=False) - payload = payload[4:] - if message_compression == 1: - payload = gzip.decompress(payload) - else: - break - return result diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index 1a0e17d8bb3b287a220e2e9ee544db2eab1e86e8..48802f6b83edebb5aec3dcd6c5578730de7662e7 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -14,34 +14,18 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel -from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel -from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential -from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText -from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech from smartdoc.conf import PROJECT_DIR volcanic_engine_llm_model_credential = OpenAILLMModelCredential() -volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() -volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential() model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', ModelTypeConst.LLM, volcanic_engine_llm_model_credential, VolcanicEngineChatModel - ), - ModelInfo('asr', - '', - ModelTypeConst.STT, - volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText - ), - ModelInfo('tts', - '', - ModelTypeConst.TTS, - volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech - ), + ) ] open_ai_embedding_credential = OpenAIEmbeddingCredential() diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py deleted file mode 100644 index bf051c18a57750929a12a45b2d8d8dbc457e53e5..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py +++ /dev/null @@ -1,46 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): - spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat') - spark_app_id = forms.TextInputField('APP ID', required=True) - spark_api_key = forms.PasswordInputField("API Key", required=True) - spark_api_secret = forms.PasswordInputField('API Secret', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} - - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py deleted file mode 100644 index 309f882b9b07ee46fd2f7cfad16bd126629db2ea..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py +++ /dev/null @@ -1,46 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common import forms -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode - - -class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): - spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts') - spark_app_id = forms.TextInputField('APP ID', required=True) - spark_api_key = forms.PasswordInputField("API Key", required=True) - spark_api_secret = forms.PasswordInputField('API Secret', required=True) - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: - if key not in model_credential: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - else: - return False - try: - model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() - except Exception as e: - if isinstance(e, AppApiException): - raise e - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - else: - return False - return True - - def encryption_dict(self, model: Dict[str, object]): - return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} - - - def get_model_params_setting_form(self, model_name): - pass diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py deleted file mode 100644 index f57e6bf13b78782386c2ce49a99785a15634bd4c..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py +++ /dev/null @@ -1,169 +0,0 @@ -# -*- coding:utf-8 -*- -# -# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -import asyncio -import base64 -import datetime -import hashlib -import hmac -import json -from datetime import datetime -from typing import Dict -from urllib.parse import urlencode, urlparse -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_stt import BaseSpeechToText - -STATUS_FIRST_FRAME = 0 # 第一帧的标识 -STATUS_CONTINUE_FRAME = 1 # 中间帧标识 -STATUS_LAST_FRAME = 2 # 最后一帧的标识 - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): - spark_app_id: str - spark_api_key: str - spark_api_secret: str - spark_api_url: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.spark_api_url = kwargs.get('spark_api_url') - self.spark_app_id = kwargs.get('spark_app_id') - self.spark_api_key = kwargs.get('spark_api_key') - self.spark_api_secret = kwargs.get('spark_api_secret') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return XFSparkSpeechToText( - spark_app_id=model_credential.get('spark_app_id'), - spark_api_key=model_credential.get('spark_api_key'), - spark_api_secret=model_credential.get('spark_api_secret'), - spark_api_url=model_credential.get('spark_api_url'), - **optional_params - ) - - # 生成url - def create_url(self): - url = self.spark_api_url - host = urlparse(url).hostname - # 生成RFC1123格式的时间戳 - gmt_format = '%a, %d %b %Y %H:%M:%S GMT' - date = datetime.utcnow().strftime(gmt_format) - - # 拼接字符串 - signature_origin = "host: " + host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + "/v2/iat " + "HTTP/1.1" - # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( - self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - # 将请求的鉴权参数组合为字典 - v = { - "authorization": authorization, - "date": date, - "host": host - } - # 拼接鉴权参数,生成url - url = url + '?' + urlencode(v) - # print("date: ",date) - # print("v: ",v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - # print('websocket url :', url) - return url - - def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def speech_to_text(self, file): - async def handle(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - # 发送 full client request - await self.send(ws, file) - return await self.handle_message(ws) - - return asyncio.run(handle()) - - @staticmethod - async def handle_message(ws): - res = await ws.recv() - message = json.loads(res) - code = message["code"] - sid = message["sid"] - if code != 0: - errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) - return errMsg - else: - data = message["data"]["result"]["ws"] - result = "" - for i in data: - for w in i["cw"]: - result += w["w"] - # print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False))) - return result - - # 收到websocket连接建立的处理 - async def send(self, ws, file): - frameSize = 8000 # 每一帧的音频大小 - status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 - - while True: - buf = file.read(frameSize) - # 文件结束 - if not buf: - status = STATUS_LAST_FRAME - # 第一帧处理 - # 发送第一帧音频,带business 参数 - # appid 必须带上,只需第一帧发送 - if status == STATUS_FIRST_FRAME: - d = { - "common": {"app_id": self.spark_app_id}, - "business": { - "domain": "iat", - "language": "zh_cn", - "accent": "mandarin", - "vinfo": 1, - "vad_eos": 10000 - }, - "data": { - "status": 0, "format": "audio/L16;rate=16000", - "audio": str(base64.b64encode(buf), 'utf-8'), - "encoding": "lame"} - } - d = json.dumps(d) - await ws.send(d) - status = STATUS_CONTINUE_FRAME - # 中间帧处理 - elif status == STATUS_CONTINUE_FRAME: - d = {"data": {"status": 1, "format": "audio/L16;rate=16000", - "audio": str(base64.b64encode(buf), 'utf-8'), - "encoding": "lame"}} - await ws.send(json.dumps(d)) - # 最后一帧处理 - elif status == STATUS_LAST_FRAME: - d = {"data": {"status": 2, "format": "audio/L16;rate=16000", - "audio": str(base64.b64encode(buf), 'utf-8'), - "encoding": "lame"}} - await ws.send(json.dumps(d)) - break diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py deleted file mode 100644 index 8d775583265f474cdc5d1a32506d8943eb50545d..0000000000000000000000000000000000000000 --- a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py +++ /dev/null @@ -1,142 +0,0 @@ -# -*- coding:utf-8 -*- -# -# author: iflytek -# -# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) -# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # -import asyncio -import base64 -import datetime -import hashlib -import hmac -import json -import os -from datetime import datetime -from typing import Dict -from urllib.parse import urlencode, urlparse -import ssl -import websockets - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_tts import BaseTextToSpeech - -STATUS_FIRST_FRAME = 0 # 第一帧的标识 -STATUS_CONTINUE_FRAME = 1 # 中间帧标识 -STATUS_LAST_FRAME = 2 # 最后一帧的标识 - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE - - -class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): - spark_app_id: str - spark_api_key: str - spark_api_secret: str - spark_api_url: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.spark_api_url = kwargs.get('spark_api_url') - self.spark_app_id = kwargs.get('spark_app_id') - self.spark_api_key = kwargs.get('spark_api_key') - self.spark_api_secret = kwargs.get('spark_api_secret') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return XFSparkTextToSpeech( - spark_app_id=model_credential.get('spark_app_id'), - spark_api_key=model_credential.get('spark_api_key'), - spark_api_secret=model_credential.get('spark_api_secret'), - spark_api_url=model_credential.get('spark_api_url'), - **optional_params - ) - - # 生成url - def create_url(self): - url = self.spark_api_url - host = urlparse(url).hostname - # 生成RFC1123格式的时间戳 - gmt_format = '%a, %d %b %Y %H:%M:%S GMT' - date = datetime.utcnow().strftime(gmt_format) - - # 拼接字符串 - signature_origin = "host: " + host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" - # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( - self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - # 将请求的鉴权参数组合为字典 - v = { - "authorization": authorization, - "date": date, - "host": host - } - # 拼接鉴权参数,生成url - url = url + '?' + urlencode(v) - # print("date: ",date) - # print("v: ",v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - # print('websocket url :', url) - return url - - def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - pass - - asyncio.run(check()) - - def text_to_speech(self, text): - - # 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"” - # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")} - async def handle(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - # 发送 full client request - await self.send(ws, text) - return await self.handle_message(ws) - - return asyncio.run(handle()) - - @staticmethod - async def handle_message(ws): - audio_bytes: bytes = b'' - while True: - res = await ws.recv() - message = json.loads(res) - # print(message) - code = message["code"] - sid = message["sid"] - audio = message["data"]["audio"] - audio = base64.b64decode(audio) - - if code != 0: - errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) - else: - audio_bytes += audio - # 退出 - if message["data"]["status"] == 2: - break - return audio_bytes - - async def send(self, ws, text): - d = { - "common": {"app_id": self.spark_app_id}, - "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}, - "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")}, - } - d = json.dumps(d) - await ws.send(d) diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index 8c60083cf6f18bf2314a810e61afdb89d2b3a9aa..d33b944e375e64e45314dc86ce4f80cba770ef7c 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -13,25 +13,16 @@ from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential -from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential -from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM -from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText -from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech from smartdoc.conf import PROJECT_DIR ssl._create_default_https_context = ssl.create_default_context() qwen_model_credential = XunFeiLLMModelCredential() -stt_model_credential = XunFeiSTTModelCredential() -tts_model_credential = XunFeiTTSModelCredential() -model_info_list = [ - ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), - ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech), -] +model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM) + ] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build() diff --git a/apps/setting/serializers/model_apply_serializers.py b/apps/setting/serializers/model_apply_serializers.py index fd4186987050ff0bcf121759e47915bc50d29472..8f56af3b1a6605b4a753f434137f301228844425 100644 --- a/apps/setting/serializers/model_apply_serializers.py +++ b/apps/setting/serializers/model_apply_serializers.py @@ -46,14 +46,16 @@ class CompressDocuments(serializers.Serializer): class ModelApplySerializers(serializers.Serializer): model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) - def embed_documents(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EmbedDocuments(data=instance).is_valid(raise_exception=True) model = get_embedding_model(self.data.get('model_id')) - return model.embed_documents(instance.getlist('texts')) + texts = instance.getlist('texts') + for text in texts: + print("to embed texts:",text) + return model.embed_documents(texts) def embed_query(self, instance, with_valid=True): if with_valid: @@ -61,13 +63,6 @@ class ModelApplySerializers(serializers.Serializer): EmbedQuery(data=instance).is_valid(raise_exception=True) model = get_embedding_model(self.data.get('model_id')) - return model.embed_query(instance.get('text')) - - def compress_documents(self, instance, with_valid=True): - if with_valid: - self.is_valid(raise_exception=True) - CompressDocuments(data=instance).is_valid(raise_exception=True) - model = get_embedding_model(self.data.get('model_id')) - return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( - [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in - instance.get('documents')], instance.get('query'))] + text = instance.get('text') + query = model.embed_query(text) + return query diff --git a/apps/smartdoc/const.py b/apps/smartdoc/const.py index 9b1159aa484a0cb37f6e24656f4ab0f147923db0..c1fba498c3d5eaba9c8770ed67f51a41b52de1e2 100644 --- a/apps/smartdoc/const.py +++ b/apps/smartdoc/const.py @@ -9,4 +9,4 @@ __all__ = ['BASE_DIR', 'PROJECT_DIR', 'VERSION', 'CONFIG'] BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) PROJECT_DIR = os.path.dirname(BASE_DIR) VERSION = '1.0.0' -CONFIG = ConfigManager.load_user_config(root_path=os.path.abspath('/opt/maxkb/conf')) +CONFIG = ConfigManager.load_user_config(root_path=os.path.abspath('J:/workspace/openai/MaxKB/conf')) diff --git a/conf/config_example.yml b/conf/config_example.yml new file mode 100644 index 0000000000000000000000000000000000000000..e47f57819191204df3c187b32e0aa36d5169b738 --- /dev/null +++ b/conf/config_example.yml @@ -0,0 +1,29 @@ +# 数据库链接信息 +DB_NAME: maxkb +DB_HOST: 192.168.10.245 +DB_PORT: 5432 +DB_USER: postgres +DB_PASSWORD: nmghl@123 +DB_ENGINE: django.db.backends.postgresql_psycopg2 + +DEBUG: false + +TIME_ZONE: Asia/Shanghai + + +EMBEDDING_MODEL_NAME: BAAI/bge-m3 +EMBEDDING_DEVICE: cpu + +LOCAL_MODEL_HOST: 192.168.10.217 +LOCAL_MODEL_PORT: 11636 +LOCAL_MODEL_PROTOCOL: http + +# 使用的向量数据库 pg_vector | es_vector +VECTOR_STORE_NAME : es_vector +# es链接地址 +VECTOR_STORE_ES_ENDPOINT: http://192.168.10.252:9200 +# es库用户名和密码,默认不开启用于验证 +VECTOR_STORE_ES_USERNAME: default +VECTOR_STORE_ES_PASSWORD: default +# es索引库名称 +VECTOR_STORE_ES_INDEX: maxkb_vector \ No newline at end of file diff --git a/main.py b/main.py index a8bd74af453d624f91b477a7f874fd6b075bc2ad..b27e7ae85938da4d8f5c70dd7c161b37adee99e5 100644 --- a/main.py +++ b/main.py @@ -70,7 +70,7 @@ def start_services(): def dev(): services = args.services if isinstance(args.services, list) else args.services if services.__contains__('web'): - management.call_command('runserver', "0.0.0.0:8080") + management.call_command('runserver', "192.168.10.217:8080") elif services.__contains__('celery'): management.call_command('celery', 'celery') elif services.__contains__('local_model'): @@ -81,7 +81,8 @@ def dev(): if __name__ == '__main__': - os.environ['HF_HOME'] = '/opt/maxkb/model/base' + os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + os.environ['HF_HOME'] = 'J:/workspace/openai/maxkb/model/base' parser = argparse.ArgumentParser( description=""" qabot service control tools; diff --git a/model/base/hub/version.txt b/model/base/hub/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..56a6051ca2b02b04ef92d5150c9ef600403cb1de --- /dev/null +++ b/model/base/hub/version.txt @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 37d1f61dfdd269eb2886ed55fe67abf4c8045119..c93ed0be0cd18ca7be281af2fcb5a4204d005d93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dashscope = "^1.17.0" zhipuai = "^2.0.1" httpx = "^0.27.0" httpx-sse = "^0.4.0" -websockets = "^13.0" +websocket-client = "^1.7.0" langchain-google-genai = "^1.0.3" openpyxl = "^3.1.2" xlrd = "^2.0.1" @@ -60,6 +60,8 @@ celery = { extras = ["sqlalchemy"], version = "^5.4.0" } eventlet = "^0.36.1" django-celery-beat = "^2.6.0" celery-once = "^3.0.1" +# 加入elastic依赖 +elasticsearch = "^8.15.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/ui/src/enums/model.ts b/ui/src/enums/model.ts index 52ff935d8ad0620c7770c0192ef1f6f85192c458..c27fc61fd3ac6f8dc72cbf2c256fe6237d3bc81b 100644 --- a/ui/src/enums/model.ts +++ b/ui/src/enums/model.ts @@ -9,9 +9,7 @@ export enum PermissionDesc { export enum modelType { EMBEDDING = '向量模型', - LLM = '大语言模型', - STT = '语音识别', - TTS = '语音合成', + LLM = '大语言模型' }