# app/asr_service.py
|
||
import os
|
||
import json
|
||
import gzip
|
||
import uuid
|
||
import wave
|
||
import ssl
|
||
import logging
|
||
from fastapi import WebSocket, WebSocketDisconnect
|
||
from fastapi.routing import APIRouter
|
||
import websockets
|
||
import asyncio
|
||
from starlette.websockets import WebSocketState
|
||
|
||
# Router on which this module registers its WebSocket and REST endpoints.
asr_router = APIRouter()

# Logging: configure the root logger at INFO level and create the
# module-scoped logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||
|
||
# Protocol constants. generate_header() packs these as 4-bit fields, two per
# header byte.
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Message types
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_REQUEST = 0b0010
FULL_SERVER_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message-type-specific flags
NO_SEQUENCE = 0b0000  # no sequence check
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011

# Message serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Message compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001

# URL of the target WebSocket server that audio is forwarded to.
TARGET_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
reqId = 1

# Optional debug dump of the received audio. These are module-level and are
# therefore shared by every connection handled by this process.
wav_file = None          # open WAV writer while a dump is in progress
write_to_file = False    # set to True to dump incoming PCM audio
total_chunk_size = 0     # bytes written to the dump so far
total_chunks = 0         # number of chunks written to the dump so far
|
||
|
||
|
||
def create_wav_file(file_name, sample_width, sample_rate, channels):
    """Create (or truncate) a WAV file and set its audio parameters.

    Args:
        file_name: Path of the output WAV file.
        sample_width: Sample width in bytes (2 means 16-bit samples).
        sample_rate: Sampling rate in Hz (for example 16000).
        channels: Number of channels (1 for mono, 2 for stereo).

    Returns:
        The open ``wave`` writer object. The caller is responsible for
        closing it.

    Raises:
        wave.Error: If one of the audio parameters is rejected.
        OSError: If the file cannot be opened for writing.
    """
    # 'wb' creates a missing file and truncates an existing one, so no
    # existence check is needed.
    wav_file = wave.open(file_name, 'wb')
    try:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
    except Exception:
        # Do not leak the file handle when a parameter is invalid. close()
        # itself complains about the incomplete header but still releases
        # the underlying file, so that secondary error is ignored and the
        # original one is re-raised.
        try:
            wav_file.close()
        except wave.Error:
            pass
        raise
    return wav_file
|
||
|
||
|
||
def write_pcm_data_to_wav(wav_file, pcm_data):
    """Append raw PCM frames to an already opened WAV writer.

    Args:
        wav_file: Open WAV writer object.
        pcm_data: Bytes of PCM audio to append.
    """
    append_frames = wav_file.writeframesraw
    append_frames(pcm_data)
|
||
|
||
|
||
def close_wav_file(wav_file):
    """Close a WAV writer; do nothing when there is none.

    The module-level ``wav_file`` starts out as ``None`` and is only set once
    a dump has been created, so cleanup code can reach this function before
    any file exists. Treating ``None`` as "nothing to close" keeps that
    cleanup from raising.

    Args:
        wav_file: Open WAV writer object, or ``None``.
    """
    if wav_file is None:
        return
    wav_file.close()
|
||
|
||
|
||
# Build the request header.
def generate_header(message_type=FULL_CLIENT_REQUEST, message_type_specific_flags=NO_SEQUENCE, serial_method=JSON,
                    compression_type=GZIP, reserved_data=0x00):
    """Build the 4-byte frame header.

    Each of the first three bytes packs two 4-bit fields (high nibble first):

        byte 0: protocol version | header size
        byte 1: message type | message-type-specific flags
        byte 2: serialization method | compression type
        byte 3: reserved

    Args:
        message_type: Message type nibble.
        message_type_specific_flags: Flags nibble for the message type.
        serial_method: Serialization method nibble.
        compression_type: Compression nibble.
        reserved_data: Value of the reserved fourth byte.

    Returns:
        A ``bytearray`` of length 4.
    """
    header_size = 1
    return bytearray((
        (PROTOCOL_VERSION << 4) | header_size,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
|
||
|
||
|
||
# Build the data that precedes the payload.
def generate_before_payload(sequence: int):
    """Encode the sequence number that is placed in front of the payload.

    Args:
        sequence: Sequence number; must fit in a signed 32-bit integer.

    Returns:
        A 4-byte ``bytearray`` holding ``sequence`` as a big-endian signed
        integer.
    """
    return bytearray(sequence.to_bytes(4, 'big', signed=True))
|
||
|
||
|
||
# Build the parameters of the full client request.
def construct_request(reqid):
    """Return the JSON-serializable parameters of the initial request.

    Args:
        reqid: Request id. Accepted for the caller's convenience; it is not
            included in the returned parameters.

    Returns:
        A new dict with the ``user``, ``audio`` and ``request`` sections.
    """
    user_section = {"uid": "test"}
    audio_section = {
        "format": "pcm",
        "sample_rate": 16000,
        "bits": 16,
        "channel": 1,
        "codec": "raw",
    }
    request_section = {
        "model_name": "bigmodel",
        "enable_punc": True,
        "enable_ddc": True,
        "show_utterances": False,
    }
    return {
        "user": user_section,
        "audio": audio_section,
        "request": request_section,
    }
|
||
|
||
|
||
def construct_full_request(reqid, seq):
    """Build the binary "full client request" frame that opens a session.

    Frame layout, in order: 4-byte header (POS_SEQUENCE flag set), 4-byte
    big-endian signed sequence number, 4-byte big-endian payload length,
    gzip-compressed JSON payload.

    Args:
        reqid: Request id, passed through to ``construct_request``.
        seq: Sequence number written into the frame. The caller owns the
            counter; this function does not (and cannot) advance it.

    Returns:
        The complete frame as a ``bytearray``.
    """
    request_params = construct_request(reqid)
    payload_bytes = gzip.compress(json.dumps(request_params).encode())

    full_client_request = bytearray(generate_header(message_type_specific_flags=POS_SEQUENCE))
    full_client_request.extend(generate_before_payload(sequence=seq))
    full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size
    full_client_request.extend(payload_bytes)  # payload
    return full_client_request
|
||
|
||
|
||
# Build an audio-data request.
def construct_audioData_request(audio_data, seq):
    """Build the binary "audio only" frame for one chunk of audio.

    Frame layout, in order: 4-byte header (AUDIO_ONLY_REQUEST type,
    POS_SEQUENCE flag), 4-byte big-endian signed sequence number, 4-byte
    big-endian payload length, gzip-compressed audio bytes.

    Args:
        audio_data: Raw audio bytes of this chunk.
        seq: Sequence number written into the frame. The caller owns the
            counter; this function does not (and cannot) advance it.

    Returns:
        The complete frame as a ``bytearray``.
    """
    payload_bytes = gzip.compress(audio_data)

    audio_only_request = bytearray(
        generate_header(message_type=AUDIO_ONLY_REQUEST, message_type_specific_flags=POS_SEQUENCE))
    audio_only_request.extend(generate_before_payload(sequence=seq))
    audio_only_request.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size
    audio_only_request.extend(payload_bytes)  # payload
    return audio_only_request
|
||
|
||
|
||
# Parse a response frame.
def parse_response(res):
    """Decode one binary frame received from the target server.

    Header layout as read here (two 4-bit fields per byte, high nibble first):

        byte 0: protocol version | header size (in units of 4 bytes)
        byte 1: message type | message-type-specific flags
        byte 2: serialization method | compression
        byte 3: reserved

    Everything after ``header_size * 4`` bytes is treated as payload.

    Args:
        res: Raw frame bytes.

    Returns:
        A dict that always contains ``'is_last_package'`` and, depending on
        the frame:

        - ``'payload_sequence'`` when flag bit 0x01 is set,
        - ``'seq'`` for SERVER_ACK frames,
        - ``'code'`` for SERVER_ERROR_RESPONSE frames,
        - ``'payload_msg'`` and ``'payload_size'`` when the frame carries a
          payload. ``payload_msg`` is the parsed JSON object for JSON
          serialization, and stays ``bytes`` for NO_SERIALIZATION.

    Raises:
        ValueError: If the frame is shorter than the 4-byte header, or the
            payload is not valid JSON/UTF-8.
        OSError: If a payload marked as gzip cannot be decompressed.
    """
    if len(res) < 4:
        raise ValueError(f"response frame too short: {len(res)} bytes, need at least 4")

    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    payload = res[header_size * 4:]

    result = {'is_last_package': False}
    payload_msg = None
    payload_size = 0

    # Flag bit 0x01: a 4-byte sequence number precedes the rest of the payload.
    if message_type_specific_flags & 0x01:
        result['payload_sequence'] = int.from_bytes(payload[:4], "big", signed=True)
        payload = payload[4:]

    # Flag bit 0x02: this is the last package.
    if message_type_specific_flags & 0x02:
        result['is_last_package'] = True

    if message_type == FULL_SERVER_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        result['seq'] = int.from_bytes(payload[:4], "big", signed=True)
        # An ACK only carries a message when there is room for the size field.
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        result['code'] = int.from_bytes(payload[:4], "big", signed=False)
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]

    if payload_msg is None:
        return result

    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)

    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")

    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result
|
||
|
||
|
||
# Forward audio from the front-end client to the target WebSocket server.
async def forward_to_server(websocket: WebSocket, client_ws: websockets.WebSocketClientProtocol, seq: int):
    """Relay audio chunks from the front-end client to the target server.

    Each received chunk is wrapped by ``construct_audioData_request`` and sent
    upstream. The loop ends when the client sends an empty message; in that
    case the upstream connection is left open. On a client disconnect or any
    other error the upstream connection is closed.

    Args:
        websocket: Connection to the front-end client.
        client_ws: Connection to the target server.
        seq: Last sequence number already used; it is incremented before each
            audio frame is sent.
    """
    global wav_file, total_chunk_size, total_chunks
    try:
        chunk = await websocket.receive_bytes()
        while chunk:
            if write_to_file:
                # Debug dump of the incoming audio plus running totals.
                write_pcm_data_to_wav(wav_file, chunk)
                total_chunk_size += len(chunk)
                total_chunks += 1
                print(f"write {len(chunk)} to file,total={total_chunk_size} {total_chunks}")
            seq += 1  # advance the sequence number before building the frame
            frame = construct_audioData_request(chunk, seq)
            await client_ws.send(frame)
            chunk = await websocket.receive_bytes()
    except WebSocketDisconnect:
        logger.info("Client disconnected, closing connection to target server.")
        await client_ws.close()
    except Exception as e:
        logger.error(f"Error in forward_to_server: {e}")
        await client_ws.close()
|
||
|
||
|
||
# Forward results from the target server to the front-end client.
async def forward_to_client(websocket: WebSocket, client_ws: websockets.WebSocketClientProtocol):
    """Relay parsed server responses to the front-end client as JSON text.

    Runs until the upstream connection closes or an error occurs, then closes
    the front-end connection if it is not already disconnected.

    Args:
        websocket: Connection to the front-end client.
        client_ws: Connection to the target server.
    """
    async def close_front_socket():
        # Starlette's WebSocketState has no CLOSING member; referencing it
        # raised AttributeError inside the handlers and skipped the close.
        # DISCONNECTED is the only state in which closing must be skipped.
        if websocket.client_state != WebSocketState.DISCONNECTED:
            await websocket.close()

    try:
        while True:
            response = await client_ws.recv()
            if response:
                data = parse_response(response)
                await websocket.send_text(json.dumps(data))
    except websockets.exceptions.ConnectionClosed:
        logger.warning("Target front client connection closed. %s", websocket.client_state)
        await close_front_socket()
    except Exception as e:
        logger.error("Error in forward_to_client: %s ", e)
        # NOTE(review): the WAV dump is closed only on this generic error
        # path; the endpoint's cleanup closes it in every other case —
        # confirm this split is intended.
        if write_to_file:
            close_wav_file(wav_file)
        await close_front_socket()
|
||
|
||
|
||
async def send_full_request():
    """Connect to the target server, send the full client request and print
    the first response.

    This is a connectivity check: no front-end WebSocket is involved, and all
    errors are logged rather than raised.
    """
    reqid = str(uuid.uuid4())

    # NOTE(review): certificate and hostname verification are disabled, so
    # the API credentials below are exposed to a man-in-the-middle. Confirm
    # this is really required before shipping.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    # NOTE(review): credentials are hard-coded in source; they should come
    # from configuration or the environment.
    header = {
        "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
        "X-Api-Access-Key": "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9",
        "X-Api-App-Key": "7282190702",
        "X-Api-Request-Id": reqid
    }
    seq = 1
    try:
        # connect() is called with extra_headers, which per the original
        # author requires websockets 12.0 (pip install websockets==12.0).
        async with websockets.connect(TARGET_WS_URL, extra_headers=header, ssl=ssl_context) as client_ws:
            await client_ws.send(construct_full_request(reqid, seq))

            print("send full_request to doubao server")
            res = await client_ws.recv()
            result = parse_response(res)
            print("connect response", result)
    except Exception as e:
        # The previous `finally` tried to close a `websocket` variable that
        # does not exist in this function (NameError on every call), and the
        # WebSocketDisconnect handler could never fire because there is no
        # front-end socket here; both were removed.
        logger.warning(f"Error in WebSocket connection: {e}")
|
||
|
||
|
||
@asr_router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Proxy one front-end WebSocket session to the target ASR server.

    Accepts the client connection, opens the upstream connection, sends the
    full client request, then runs two relay tasks (client -> server audio,
    server -> client results) until both finish. The front-end connection is
    always closed on exit.

    Args:
        websocket: Incoming front-end connection.
    """
    global wav_file, write_to_file
    reqid = str(uuid.uuid4())

    # NOTE(review): certificate and hostname verification are disabled, so
    # the API credentials below are exposed to a man-in-the-middle. Confirm
    # this is really required before shipping.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    # NOTE(review): credentials are hard-coded in source; they should come
    # from configuration or the environment.
    header = {
        "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
        "X-Api-Access-Key": "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9",
        "X-Api-App-Key": "7282190702",
        "X-Api-Request-Id": reqid
    }
    seq = 1
    await websocket.accept()

    try:
        async with websockets.connect(TARGET_WS_URL, extra_headers=header, ssl=ssl_context) as client_ws:
            await client_ws.send(construct_full_request(reqid, seq))
            print("send full_request to doubao server")
            res = await client_ws.recv()
            result = parse_response(res)
            print("connect response", result)
            if write_to_file:
                # NOTE(review): the dump path is fixed and the handle is a
                # module global, so concurrent sessions share one file.
                wav_file = create_wav_file('test.wav', 2, 16000, 1)

            print("start two task")
            task1 = asyncio.create_task(forward_to_server(websocket, client_ws, seq))
            task2 = asyncio.create_task(forward_to_client(websocket, client_ws))

            await asyncio.gather(task1, task2)

    except WebSocketDisconnect:
        logger.info("Frontend client disconnected.")
    except Exception as e:
        logger.warning(f"front WebSocket client connection: {e}")
    finally:
        # The two cleanup steps are independent: a failure while closing the
        # WAV dump (or a dump that was never created because the upstream
        # connection failed) must not prevent the client socket from closing.
        if write_to_file and wav_file is not None:
            try:
                close_wav_file(wav_file)
                print("close wav file")
            except Exception as e:
                logger.error(f"Error while closing wav file: {e}")
        try:
            if websocket.client_state != WebSocketState.DISCONNECTED:
                await websocket.close()
        except Exception as e:
            logger.error(f"Error while closing websocket: {e}")
|
||
|
||
# REST API routes
@asr_router.get("/health")
async def health():
    """Report that the ASR service is up."""
    status = {"message": "I am ASR service;I am OK"}
    return status