ChocolateBlack
基于Langchain-Chatchat和BERT-VITS2实现RAG多路召回+AI语音播报

基于Langchain-Chatchat和BERT-VITS2实现RAG多路召回+AI语音播报

前言

目前Langchain-Chatchat的0.2.x版本并不支持多路召回,着实有点遗憾,于是我个人在Langchain-Chatchat的基础上增加了可以融合多个召回算法的多路召回demo以及BM25检索算法,并且结合了BERT-VITS2,实现大模型回答的AI语音播报。

项目的简单演示视频:

https://www.bilibili.com/video/BV11w4m1R7H8

完整代码可以看我的Github项目:

https://github.com/Chocolate-Black/Langchain-MO-AI-Chat

下面简单介绍下我是怎么在Langchain-Chatchat上进行修改的。

多路召回部分

Langchain-Chatchat的知识库功能,主要通过KBService基类进行实现,子类负责实现具体方法。文件存放在server/knowledge_base/kb_service下。

这里以Langchain-Chatchat自带的DefaultKBService为例,对核心函数进行讲解:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class DefaultKBService(KBService):
def do_create_kb(self):
pass

def do_drop_kb(self):
pass

def do_add_doc(self, docs: List[Document]):
pass

def do_clear_vs(self):
pass

def vs_type(self) -> str:
return "default"

def do_init(self):
pass

def do_search(self):
pass

def do_insert_multi_knowledge(self): # 暂不清楚功能,非必需
pass

def do_insert_one_knowledge(self):# 暂不清楚功能,非必需
pass

def do_delete_doc(self):
pass

do_init

知识库的初始化操作,创建知识库后会执行。这里一般用来初始化文件路径等等。

在我实现的用于多路召回的MixKBService中,do_init会根据选择的KBService,依次进行初始化,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def get_vs_path(self):
return os.path.join(get_kb_path(self.kb_name), "vector_store")

def get_kb_path(self):
return get_kb_path(self.kb_name)

def init_mix_vs_types(self,vs_types: List[str] = MIX_VS_TYPES):
self.mix_vs_types = vs_types

def do_init(self):
self.kb_path = self.get_kb_path()
self.vs_path = self.get_vs_path()
self.init_mix_vs_types()
services = []
for vs_type in self.mix_vs_types:
kb = KBServiceFactory.get_service(self.kb_name,vector_store_type=vs_type,embed_model=self.embed_model)
services.append(kb)

MIX_VS_TYPES定义于configs/kb_config.py,是存放vs_types的列表。

KBServiceFactory.get_service用于获取对应vs_type的KBService。

vs_type

知识库的类别,用于区分不同的知识库类别。KBServiceFactory类会根据输入的类别(SupportedVSType)返回相应的知识库service。

因此,如果我们想要创建自己的知识库Service类,就需要在server/knowledge_base/kb_service/base.py里增加SupportedVSType:

1
2
3
4
5
6
7
8
9
10
class SupportedVSType:
FAISS = 'faiss'
MILVUS = 'milvus'
DEFAULT = 'default'
ZILLIZ = 'zilliz'
PG = 'pg'
ES = 'es'
CHROMADB = 'chromadb'
BM25 = 'bm25'
MIX = 'mix'

这里我增加了BM25和MIX两个类别,其中MIX为我实现多路召回的Service的vs_type。

然后,对该文件下的KBServiceFactory类的get_service增加elif选项:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class KBServiceFactory:

@staticmethod
def get_service(kb_name: str,
vector_store_type: Union[str, SupportedVSType],
embed_model: str = EMBEDDING_MODEL,
) -> KBService:
if isinstance(vector_store_type, str):
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
if SupportedVSType.FAISS == vector_store_type:
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
return FaissKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.PG == vector_store_type:
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
return PGKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name,embed_model=embed_model)
elif SupportedVSType.ZILLIZ == vector_store_type:
from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
return ZillizKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.DEFAULT == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name,
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.ES == vector_store_type:
from server.knowledge_base.kb_service.es_kb_service import ESKBService
return ESKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.CHROMADB == vector_store_type:
from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
return ChromaKBService(kb_name, embed_model=embed_model)
# BM25
elif SupportedVSType.BM25 == vector_store_type:
from server.knowledge_base.kb_service.bm25_kb_service import BM25KBService
return BM25KBService(kb_name)
# 多路召回
elif SupportedVSType.MIX == vector_store_type:
from server.knowledge_base.kb_service.mix_kb_service import MixKBService
return MixKBService(kb_name)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)

do_create_kb

创建知识库。该函数主要是用来创建vector_store文件夹,创建并存放索引文件。

1
2
3
def do_create_kb(self):
for service in self.kb_services:
service.do_create_kb() # 挨个创建

do_clear_vs

删除向量文件夹(一般都是vectore_store)内的索引文件。

1
2
3
def do_clear_vs(self):
for service in self.kb_services:
service.do_clear_vs() # 挨个清空

do_drop_kb

彻底删除知识库,包括存放知识库文件的文件夹(content)以及向量文件夹(vector_store)

1
2
3
4
5
6
def do_drop_kb(self):
self.do_clear_vs()
try:
shutil.rmtree(self.kb_path)
except Exception:
...

do_add_doc

该函数用于向知识库添加文档。对于向量数据库来说(faiss),还需要将文档先转变为embedding。

添加完成之后需要返回文档的信息,用于存储到info.db中,这里以FaissKBService的do_add_doc函数为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def do_add_doc(self,
docs: List[Document],
**kwargs,
) -> List[Dict]:
data = self._docs_to_embeddings(docs) # 将向量化单独出来可以减少向量库的锁定时间

with self.load_vector_store().acquire() as vs:
ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
metadatas=data["metadatas"],
ids=kwargs.get("ids"))
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc()
return doc_infos

在我实现的MixKBService类中,只需要挨个添加即可,然后返回其中一个doc_infos:

1
2
3
4
5
6
def do_add_doc(self, docs: List[Document],**kwargs)-> List[Dict]:
doc_infos = []
for service in self.kb_services:
doc_infos = service.do_add_doc(docs,**kwargs)

return doc_infos

do_delete_doc

删除知识库内某个文件下的所有文档。也是一样的,挨个删除。

1
2
3
4
5
def do_delete_doc(self,kb_file: KnowledgeFile,**kwargs):
ids = []
for service in self.kb_services:
ids = service.do_delete_doc(kb_file,**kwargs)
return ids

save_vector_store/load_vector_store

虽然DefaultKBService没有这俩函数,但一般还是需要实现。

save/load主要用于保存/加载索引文件,根据算法/向量库的不同,这里的写法也会不同。这里不再细说,具体可参考我github上的代码。

do_search

重点来了。这一函数用于搜索与query最相关的top_k个文档,也是多路召回算法的关键处。

这里我采用的是比较简单的多路召回策略:每个service搜索top_k个文档,然后取并集,再根据score进行排序,选取前top_k个。其中召回用到的算法有:向量最大相似度、BM25。

另外最大相似度的得分标准(越小越相似)和BM25的得分标准(越大越相似)相反,因此我对BM25的得分进行倒数处理并做了简单的权重平衡操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def do_search(self,
query: str,
top_k: int,
score_threshold: float = SCORE_THRESHOLD)-> List[Tuple[Document, float]]:
docs_info = dict()
results = []
for service in self.kb_services:
docs = service.do_search(query,top_k,score_threshold)
for doc in docs:
md5 = hashlib.md5()
md5.update(doc[0].page_content.encode('utf-8'))
key = md5.hexdigest()
if service.vs_type() == 'bm25':
score = 4/doc[1] - 0.01
else:
score = doc[1]
if key not in docs_info:
docs_info[key] = {'doc':doc[0],'score':score,'cnt':1}
else:
docs_info[key]['cnt'] += 1
docs_info[key]['score'] += score

for key in docs_info.keys():
score = docs_info[key]['score']/docs_info[key]['cnt']
pair = (docs_info[key]['doc'],score)
results.append(pair)

results = sorted(results,key=lambda t: t[1])
if len(results) > top_k:
return results[:top_k]
else:
return results

关于BM25Service的实现方法,这里不再讲解,具体可以参考我github上的代码。

AI语音部分

关于BERT-VITS2怎么训练自己的语音模型,限于篇幅不进行讲解。有兴趣可以去看下BERT-VITS2的教程,一抓一大把。

这里主要讲解如何将AI语音模型结合到Langchain-Chatchat的WebUI中,实现大模型回答的语音播报。

WebUI对话功能与界面,主要通过webui_pages/dialogue/dialogue.py实现。

首先,BERT-VITS2需要确定模型文件路径、说话人、以及sdp_ratio/noise_scale/noise_scale_w/length_scale四项参数。所以我们需要在WebUI中让用户进行选择:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def on_project_change():
project = st.session_state.dialogue_project
config_path = PROJECTS[project]['config_path']
st.session_state["audio_generator"] = AudioGenerator(project,config_path)
st.toast(f"切换Project为:{project}")

dialogue_projects = list(PROJECTS.keys())
default_index = dialogue_projects.index(DEFAULT_PROJECT)

dialogue_project = st.selectbox("请选择Project:",
dialogue_projects,
index=default_index,
on_change=on_project_change,
key="dialogue_project",
)

def on_speaker_change():
speaker = st.session_state.dialogue_speaker
text = f"当前语音助手: {speaker}。"
st.toast(text)

if PROJECTS[dialogue_project]['config_path'] == "":
hps = None
speakers = []
else:
hps = get_hparams_from_file(PROJECTS[dialogue_project]['config_path'])
speaker_ids = hps.data.spk2id
speakers = list(speaker_ids.keys())
dialogue_speaker = st.selectbox("请选择语音助手:",
speakers,
index=0,
on_change=on_speaker_change,
key="dialogue_speaker",
)

sdp_ratio = st.slider("sdp ratio:", 0.0, 1.0, SDP_RATIO, 0.1)
noise_scale = st.slider("noise scale:", 0.1, 2.0, NOISE_SCALE, 0.1)
noise_scale_w = st.slider("noise scale w:", 0.1, 2.0, NOISE_SCALE_W, 0.1)
length_scale = st.slider("length scale:", 0.1, 2.0, LENGTH_SCALE, 0.1)

其中,PROJECTS/SDP_RATIO/NOISE_SCALE/NOISE_SCALE_W/LENGTH_SCALE在configs/audio_config.py(我自己额外增加的配置文件)定义,格式如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 请填写/修改对应文件的路径,其中xxx.pth为BERT-VITS2训练得到的模型文件,config.json文件为训练过程中用到的配置文件。具体请参考BERT-VITS2项目的说明。
PROJECTS = {
"无":{
"model_path":"",
"config_path":"",
},
"白河萤":{
"model_path":"./audio/Data/hotaru/models/hotaru_G_900.pth",
"config_path":"./audio/Data/hotaru/config.json",
},
"稻穗信":{
"model_path":"./audio/Data/shin/models/shin_G_600.pth",
"config_path":"./audio/Data/shin/config.json",
},
"嘉神川克罗艾":{
"model_path":"./audio/Data/chloe/models/chloe_G_1500.pth",
"config_path":"./audio/Data/chloe/config.json",
},
"嘉神川诺艾儿":{
"model_path":"./audio/Data/noelle/models/noelle_G_500.pth",
"config_path":"./audio/Data/noelle/config.json",
},
"三城柚莉":{
"model_path":"./audio/Data/yuzuri/models/yuzuri_G_1000.pth",
"config_path":"./audio/Data/yuzuri/config.json",
},
}

DEFAULT_PROJECT = "无"

SDP_RATIO = 0.5
NOISE_SCALE = 0.6
NOISE_SCALE_W = 0.9
LENGTH_SCALE = 1.0

DEFAULT_EMOTION = 'Normal'
DEVICE = 'cuda'

AudioGenerator是我实现的用于将文本转换为语音的类,get_hparams_from_file为BERT-VITS2中的函数,这里不再细讲,详情请参考我github上的代码。

然后是实现文本转音频的函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def convert_text_to_audio(text:str):
flag = True
params = dict()
try:
text_chunks = text.split("。")
audio_data = []
for t in text_chunks:
if t == "":
continue
audio_data_part,params['sample_rate'] = st.session_state['audio_generator'].convert_text_to_audio(t, dialogue_speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, emotion=DEFAULT_EMOTION, prompt_mode= 'Text prompt')

audio_data.append(audio_data_part)
audio_data = np.concatenate(audio_data)
except Exception as e:
print(type(e))
print(str(e))
traceback.print_exc()
traceback.print_exc(file=open('log.txt', 'a'))
flag = False

if flag == True:
audio_element = Audio(content=audio_data,**params)
return audio_element
else:
return

convert_text_to_audio负责将文本转变成Audio类(其中Audio为Langchain-Chatchat自己实现的类,为OutputElement的子类)。转为Audio类之后,后续就可以很方便地展示在网页中了。

以LLM对话模式为例,如果我们想要展示音频,首先需要在chat_box.ai_say中增加一个额外的Element,代码如下:

1
2
3
4
5
6
if dialogue_mode == "LLM 对话":
if dialogue_project != "无":
chat_box.ai_say(["正在思考...",
"请等待语音生成..."])
else:
chat_box.ai_say(["正在思考..."])

这里新增的元素(“请等待语音生成…”)后续可以被替换成Audio。

大模型在得到回答后,首先通过convert_text_to_audio把回答的文本转换为Audio,然后再通过chat_box.update_msg将对应element_index下的元素替换成Audio:

1
2
3
4
5
6
7
chat_box.update_msg(text, element_index=0,streaming=False, metadata=metadata)  # 更新最终的字符串,去除光标
if dialogue_project != "无":
audio_element = convert_text_to_audio(text)
if audio_element != None:
chat_box.update_msg(audio_element, element_index=1,streaming=False)
else:
chat_box.update_msg("语音生成失败! 请检查报错文件log.txt!", element_index=1,streaming=False)

这里替换的是”请等待语音生成…”,其element_index=1。另外记得把streaming设置为False。

其他对话模式也是同理,这里就不再重复。

另外,这里面有个大坑!虽然Langchain-Chatchat内有Audio/Image/Video等类,但是Streamlit的的ChatBox类下的export2md函数并不支持音频格式,且未做区分。直接使用会导致报错。因此需要对环境里的streamlit_chatbox/messages.py进行修改,请把182行的代码修改为:

1
contents = [e.content for e in msg["elements"] if type(e.content) == str]

最后

未来打算尝试将本地AI绘画的lora模型融入进去。另外还有一些东西没细讲,例如BM25的实现、Embedding模型的选取等等,各位可以参考下我github上的代码。

本文作者:ChocolateBlack
本文链接:http://chocolateblack.club/2024/03/20/基于Langchain-Chatchat和BERT-VITS2实现RAG多路召回+AI语音生成/
版权声明:本文采用 CC BY-NC-SA 3.0 CN 协议进行许可