Files
ragflow_python/asr-monitor-test/app/asr_service.py

353 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# app/asr_service.py
import os
import json
import gzip
import uuid
import wave
import ssl
import logging
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.routing import APIRouter
import websockets
import asyncio
from starlette.websockets import WebSocketState
# Initialize the router that exposes this service's endpoints
asr_router = APIRouter()
# Configure logging (NOTE: basicConfig at import time configures the root logger for the whole process)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Binary protocol constants for the upstream ASR server.
# generate_header() packs two of these per header byte (value << 4 | value), so each is a 4-bit nibble.
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001  # header size in 4-byte words (parse_response slices res[header_size * 4:])
# Message types (high nibble of header byte 1)
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_REQUEST = 0b0010
FULL_SERVER_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
# Message-type-specific flags (low nibble of header byte 1)
NO_SEQUENCE = 0b0000 # no sequence number follows the header
POS_SEQUENCE = 0b0001  # bit 0: a 4-byte signed sequence number follows the header
NEG_SEQUENCE = 0b0010  # bit 1: last packet (parse_response sets is_last_package on this bit)
NEG_WITH_SEQUENCE = 0b0011  # both bits: sequence number present AND last packet
# Serialization methods (high nibble of header byte 2)
NO_SERIALIZATION = 0b0000
JSON = 0b0001
# Compression types (low nibble of header byte 2)
NO_COMPRESSION = 0b0000
GZIP = 0b0001
# URL of the upstream ASR WebSocket server that client audio is forwarded to
TARGET_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
reqId = 1  # NOTE(review): not referenced anywhere in the visible code
# Debug-recording state. These are module-level globals, so they are shared by every concurrent connection.
wav_file = None  # wave writer opened by websocket_endpoint when write_to_file is True
write_to_file = False  # debug switch: also record incoming PCM audio to test.wav
total_chunk_size = 0  # total bytes written to wav_file (debug counter)
total_chunks = 0  # number of chunks written to wav_file (debug counter)
def create_wav_file(file_name, sample_width, sample_rate, channels):
    """Create (or truncate) a WAV file and configure its format.

    Args:
        file_name: Path of the output WAV file. An existing file is overwritten.
        sample_width: Sample width in bytes (e.g. 2 for 16-bit audio).
        sample_rate: Frame rate in Hz (e.g. 16000).
        channels: Number of channels (1 = mono, 2 = stereo).

    Returns:
        An open ``wave.Wave_write`` object; the caller is responsible for closing it.

    Raises:
        wave.Error: If the ``wave`` module rejects one of the format parameters.
    """
    # Mode 'wb' both creates a missing file and truncates an existing one,
    # so no existence check is needed.
    wav_file = wave.open(file_name, 'wb')
    try:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
    except Exception:
        # Do not leak the file handle when a parameter is invalid. close()
        # itself complains about the still-unset parameters; that secondary
        # error is ignored so the original one propagates.
        try:
            wav_file.close()
        except wave.Error:
            pass
        raise
    return wav_file
def write_pcm_data_to_wav(wav_file, pcm_data):
    """Append raw PCM bytes to an already-open WAV writer.

    Uses ``writeframesraw``, which does not correct the frame count in the
    header on each call; ``Wave_write.close()`` fixes it up at the end.

    Args:
        wav_file: An open ``wave.Wave_write`` object.
        pcm_data: PCM audio bytes in the writer's configured format.
    """
    append_frames = wav_file.writeframesraw
    append_frames(pcm_data)
def close_wav_file(wav_file):
    """Finalize and close a WAV writer.

    ``None`` is accepted and ignored, so cleanup code can call this even when
    the file was never created (e.g. the upstream connection failed before
    recording started). Calling it again on an already-closed writer is harmless.

    Args:
        wav_file: An open ``wave.Wave_write`` object, or ``None``.
    """
    if wav_file is None:
        return
    # close() also corrects the header's frame count before releasing the file.
    wav_file.close()
# Build the fixed 4-byte protocol header
def generate_header(message_type=FULL_CLIENT_REQUEST, message_type_specific_flags=NO_SEQUENCE, serial_method=JSON,
                    compression_type=GZIP, reserved_data=0x00):
    """Pack the 4-byte frame header.

    Byte 0: protocol version (high nibble) | header size in 4-byte words (low nibble, always 1).
    Byte 1: message type (high nibble) | message-type-specific flags (low nibble).
    Byte 2: serialization method (high nibble) | compression type (low nibble).
    Byte 3: reserved.

    Nibble ranges are not checked; a packed byte outside 0-255 raises ValueError.
    Returns a new ``bytearray``.
    """
    words = DEFAULT_HEADER_SIZE
    packed = [
        (PROTOCOL_VERSION << 4) | words,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ]
    return bytearray(packed)
# Build the bytes placed between the header and the payload size: the sequence number
def generate_before_payload(sequence: int):
    """Encode *sequence* as a 4-byte big-endian signed integer.

    Returns a new ``bytearray`` so callers can keep extending it. Raises
    OverflowError if *sequence* does not fit in 32 signed bits.
    """
    encoded = sequence.to_bytes(4, byteorder='big', signed=True)
    return bytearray(encoded)
# Build the JSON body of the full client request
def construct_request(reqid):
    """Return the session parameters sent in the full client request.

    Describes 16 kHz, 16-bit, mono raw PCM audio and enables punctuation and
    DDC in the recognition request. ``reqid`` is accepted but currently unused.
    A fresh dict is returned on every call.
    """
    user_section = {"uid": "test"}
    audio_section = dict(format="pcm", sample_rate=16000, bits=16, channel=1, codec="raw")
    request_section = dict(model_name="bigmodel", enable_punc=True, enable_ddc=True, show_utterances=False)
    return {"user": user_section, "audio": audio_section, "request": request_section}
def construct_full_request(reqid, seq):
    """Build the initial "full client request" frame for an ASR session.

    Frame layout: 4-byte header (FULL_CLIENT_REQUEST, POS_SEQUENCE, JSON, GZIP)
    + 4-byte signed sequence number + 4-byte payload length
    + gzip-compressed JSON produced by ``construct_request``.

    Args:
        reqid: Request id, passed through to ``construct_request``.
        seq: Sequence number to embed. It is NOT advanced here; callers track
            the next sequence number themselves.

    Returns:
        The encoded frame as a ``bytearray``.
    """
    payload_bytes = gzip.compress(json.dumps(construct_request(reqid)).encode())
    frame = bytearray(generate_header(message_type=FULL_CLIENT_REQUEST,
                                      message_type_specific_flags=POS_SEQUENCE,
                                      serial_method=JSON,
                                      compression_type=GZIP))
    frame.extend(generate_before_payload(sequence=seq))
    frame.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size
    frame.extend(payload_bytes)  # payload
    return frame
# Build an audio-only request frame
def construct_audioData_request(audio_data, seq, is_last=False):
    """Wrap one chunk of raw audio in an audio-only request frame.

    Frame layout: 4-byte header (AUDIO_ONLY_REQUEST, default JSON/GZIP nibbles)
    + 4-byte signed sequence number + 4-byte payload length
    + gzip-compressed audio bytes.

    Args:
        audio_data: Raw audio bytes for this chunk (may be empty for the final packet).
        seq: Sequence number of this packet. It is NOT advanced here; callers
            track the next sequence number themselves.
        is_last: When False (the default, original behavior) the packet is
            flagged POS_SEQUENCE and ``seq`` is sent as-is. When True it is
            flagged NEG_WITH_SEQUENCE and the sequence number is sent negated
            to mark the end of the audio stream.
            NOTE(review): confirm against the upstream protocol documentation.

    Returns:
        The encoded frame as a ``bytearray``.
    """
    if is_last:
        flags, wire_seq = NEG_WITH_SEQUENCE, -abs(seq)
    else:
        flags, wire_seq = POS_SEQUENCE, seq
    payload_bytes = gzip.compress(audio_data)
    frame = bytearray(generate_header(message_type=AUDIO_ONLY_REQUEST, message_type_specific_flags=flags))
    frame.extend(generate_before_payload(sequence=wire_seq))
    frame.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size
    frame.extend(payload_bytes)  # payload
    return frame
# Parse a binary response frame from the ASR server
def parse_response(res):
    """Decode one binary frame received from the ASR server.

    Header fields: byte 0 low nibble = header size in 4-byte words;
    byte 1 = message type (high nibble) | flags (low nibble);
    byte 2 = serialization method (high nibble) | compression (low nibble).

    Returns:
        A dict that always has ``is_last_package``. Depending on the flags and
        message type it may also contain:

        - ``payload_sequence``: present when flag bit 0 is set.
        - ``seq``: for SERVER_ACK.
        - ``code``: for SERVER_ERROR_RESPONSE.
        - ``payload_msg`` and ``payload_size``.

        ``payload_msg`` is decoded JSON for JSON serialization; ``None`` if
        the JSON body is empty. With NO_SERIALIZATION it is raw bytes, and it
        is a str for any other serialization method.
    """
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    payload = res[header_size * 4:]

    result = {'is_last_package': False}
    payload_msg = None
    payload_size = 0

    # Flag bit 0: a signed 4-byte sequence number precedes the message body.
    if message_type_specific_flags & 0x01:
        result['payload_sequence'] = int.from_bytes(payload[:4], "big", signed=True)
        payload = payload[4:]
    # Flag bit 1: this is the last packet of the stream.
    if message_type_specific_flags & 0x02:
        result['is_last_package'] = True

    if message_type == FULL_SERVER_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        result['seq'] = int.from_bytes(payload[:4], "big", signed=True)
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        result['code'] = int.from_bytes(payload[:4], "big", signed=False)
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]

    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        # An empty body would make json.loads raise; report it as None instead.
        payload_msg = json.loads(str(payload_msg, "utf-8")) if payload_msg else None
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result
# Forward audio from the front-end client to the upstream ASR server
def _build_last_audio_request(seq):
    """Build the empty final audio packet: NEG_WITH_SEQUENCE flag and negated sequence number."""
    payload_bytes = gzip.compress(b"")
    frame = bytearray(generate_header(message_type=AUDIO_ONLY_REQUEST,
                                      message_type_specific_flags=NEG_WITH_SEQUENCE))
    frame.extend(generate_before_payload(sequence=-seq))
    frame.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size
    frame.extend(payload_bytes)  # payload
    return frame


async def forward_to_server(websocket: WebSocket, client_ws: websockets.WebSocketClientProtocol, seq: int):
    """Relay raw PCM chunks from the front-end client to the ASR server.

    Each binary frame received on ``websocket`` is wrapped in an audio-only
    request and sent on ``client_ws``.

    ``seq`` is the sequence number already used by the full client request;
    each audio packet uses the next number.

    An empty frame from the client ends the stream. A final, empty packet
    flagged NEG_WITH_SEQUENCE with a negated sequence number is sent, so the
    server knows the audio is complete and can deliver its final result.

    When ``write_to_file`` is enabled, each chunk is also appended to the
    module-level ``wav_file``.

    If the client disconnects or any error occurs, the upstream connection
    is closed.
    """
    global wav_file, total_chunk_size, total_chunks
    try:
        while True:
            audio_data = await websocket.receive_bytes()
            seq += 1  # sequence number of the packet about to be sent
            if not audio_data:
                # End of stream: without an explicit last packet the server
                # never finalizes and forward_to_client would wait forever.
                await client_ws.send(_build_last_audio_request(seq))
                break
            if write_to_file:
                write_pcm_data_to_wav(wav_file, audio_data)
                total_chunk_size += len(audio_data)
                total_chunks += 1
                logger.debug("write %d to file,total=%d %d", len(audio_data), total_chunk_size, total_chunks)
            await client_ws.send(construct_audioData_request(audio_data, seq))
    except WebSocketDisconnect:
        logger.info("Client disconnected, closing connection to target server.")
        await client_ws.close()
    except Exception as e:
        logger.error(f"Error in forward_to_server: {e}")
        await client_ws.close()
# Forward recognition results from the upstream ASR server to the front-end client
async def forward_to_client(websocket: WebSocket, client_ws: websockets.WebSocketClientProtocol):
    """Relay parsed ASR server responses to the front-end client as JSON text.

    Runs until the upstream connection closes or an error occurs, then closes
    the front-end websocket if it is still open.

    The debug WAV file is NOT closed here. ``forward_to_server`` may still be
    writing to it, so closing it is left to ``websocket_endpoint``.
    """
    try:
        while True:
            response = await client_ws.recv()
            if response:
                data = parse_response(response)
                await websocket.send_text(json.dumps(data))
    except websockets.exceptions.ConnectionClosed:
        logger.warning(f"Target front client connection closed. {websocket.client_state}")
    except Exception as e:
        logger.error(f"Error in forward_to_client: {e} ")
    # Starlette's WebSocketState defines no CLOSING member, so the former
    # membership test raised AttributeError. Check both sides instead: don't
    # close a socket the client already dropped, and don't send a second
    # close frame after the application has already closed it.
    try:
        if (websocket.client_state != WebSocketState.DISCONNECTED
                and websocket.application_state != WebSocketState.DISCONNECTED):
            await websocket.close()
    except Exception as e:
        logger.error(f"Error while closing websocket: {e}")
async def send_full_request(websocket=None):
    """Open a one-off session with the ASR server and perform the handshake.

    Sends the full client request, waits for one response and returns it
    parsed. Nothing in the visible code calls this; it looks like a
    connectivity probe (NOTE(review): confirm whether it is still needed).

    Args:
        websocket: Optional front-end Starlette ``WebSocket`` to close once the
            probe finishes. With the no-argument call nothing is closed.

    Returns:
        The parsed handshake response dict, or ``None`` if connecting or the
        handshake failed.
    """
    reqid = str(uuid.uuid4())
    # SECURITY(review): certificate and hostname verification are disabled, so
    # the credentials below travel over an unauthenticated TLS channel.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    # SECURITY(review): credentials are hard-coded in source; move them to
    # configuration / secret storage and rotate them.
    header = {
        "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
        "X-Api-Access-Key": "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9",
        "X-Api-App-Key": "7282190702",
        "X-Api-Request-Id": reqid
    }
    seq = 1
    result = None
    try:
        # Connecting with extra_headers requires websockets 12.0: pip install websockets==12.0
        async with websockets.connect(TARGET_WS_URL, extra_headers=header, ssl=ssl_context) as client_ws:
            await client_ws.send(construct_full_request(reqid, seq))
            logger.info("send full_request to doubao server")
            res = await client_ws.recv()
            result = parse_response(res)
            logger.info("connect response %s", result)
    except WebSocketDisconnect:
        logger.info("Frontend client disconnected.")
    except Exception as e:
        logger.warning(f"Error in WebSocket connection: {e}")
    finally:
        # The former body referenced an undefined `websocket` here (NameError);
        # only close a socket the caller actually supplied.
        if websocket is not None:
            try:
                if websocket.client_state != WebSocketState.DISCONNECTED:
                    await websocket.close()
            except Exception as e:
                logger.error(f"Error while closing websocket: {e}")
    return result
@asr_router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Proxy one front-end WebSocket session to the upstream ASR server.

    Steps:

    1. Accept the client and open an upstream connection.
    2. Send the full client request with sequence number 1.
    3. Run two relay tasks concurrently:
       - ``forward_to_server``: client audio -> server.
       - ``forward_to_client``: server results -> client.

    If waiting on the tasks raises, any task still running is cancelled
    before the upstream connection is closed.

    On exit, the debug WAV file (when ``write_to_file`` is enabled and the
    file was created) and the front-end socket are closed.
    """
    global wav_file
    reqid = str(uuid.uuid4())
    # SECURITY(review): certificate and hostname verification are disabled, so
    # the credentials below travel over an unauthenticated TLS channel.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    # SECURITY(review): credentials are hard-coded in source; move them to
    # configuration / secret storage and rotate them.
    header = {
        "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
        "X-Api-Access-Key": "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9",
        "X-Api-App-Key": "7282190702",
        "X-Api-Request-Id": reqid
    }
    seq = 1
    await websocket.accept()
    try:
        async with websockets.connect(TARGET_WS_URL, extra_headers=header, ssl=ssl_context) as client_ws:
            await client_ws.send(construct_full_request(reqid, seq))
            logger.info("send full_request to doubao server")
            res = await client_ws.recv()
            result = parse_response(res)
            logger.info("connect response %s", result)
            if write_to_file:
                wav_file = create_wav_file('test.wav', 2, 16000, 1)
            logger.info("start two task")
            tasks = [
                asyncio.create_task(forward_to_server(websocket, client_ws, seq)),
                asyncio.create_task(forward_to_client(websocket, client_ws)),
            ]
            try:
                await asyncio.gather(*tasks)
            finally:
                # gather() does not cancel the sibling when one task raises;
                # stop it so it cannot outlive this session.
                for task in tasks:
                    if not task.done():
                        task.cancel()
    except WebSocketDisconnect:
        logger.info("Frontend client disconnected.")
    except Exception as e:
        logger.warning(f"front WebSocket client connection: {e}")
    finally:
        # The WAV file may never have been created (e.g. connect failed), and a
        # failure here must not prevent the websocket from being closed.
        if write_to_file and wav_file is not None:
            try:
                close_wav_file(wav_file)
                logger.info("close wav file")
            except Exception as e:
                logger.error(f"Error while closing wav file: {e}")
            wav_file = None
        try:
            # Skip if either side already closed, to avoid a second close frame.
            if (websocket.client_state != WebSocketState.DISCONNECTED
                    and websocket.application_state != WebSocketState.DISCONNECTED):
                await websocket.close()
        except Exception as e:
            logger.error(f"Error while closing websocket: {e}")
# REST API route: liveness check
@asr_router.get("/health")
async def health():
    """Report that the ASR service is up and responding."""
    status = {"message": "I am ASR service;I am OK"}
    return status