# app/asr_service.py
"""WebSocket relay between a frontend client and an upstream ASR server.

The ``/ws`` endpoint accepts binary audio frames from a frontend client, wraps
each frame in the binary framing implemented below (4-byte header, 4-byte
sequence number, 4-byte payload size, gzip-compressed payload) and sends it to
``TARGET_WS_URL``. Responses from the upstream server are parsed and relayed
back to the frontend as JSON text frames.
"""
import asyncio
import gzip
import json
import logging
import os
import ssl
import uuid
import wave

import websockets
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.routing import APIRouter
from starlette.websockets import WebSocketState

# Router initialisation.
asr_router = APIRouter()

# Logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Protocol constants.
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Message types.
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_REQUEST = 0b0010
FULL_SERVER_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message-type-specific flags.
NO_SEQUENCE = 0b0000  # no sequence check
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011

# Message serialization methods.
NO_SERIALIZATION = 0b0000
JSON = 0b0001

# Message compression methods.
NO_COMPRESSION = 0b0000
GZIP = 0b0001

# URL of the target (upstream) WebSocket server.
TARGET_WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"

reqId = 1
# Module-level debug state: when ``write_to_file`` is true, incoming audio is
# also written to ``wav_file``. These globals are shared by every connection.
wav_file = None
write_to_file = False
total_chunk_size = 0
total_chunks = 0


def create_wav_file(file_name, sample_width, sample_rate, channels):
    """Create (or truncate) a WAV file and set its audio parameters.

    Args:
        file_name: Path of the output WAV file.
        sample_width: Sample width in bytes (2 means 16-bit samples).
        sample_rate: Sample rate in Hz (for example 44100).
        channels: Number of channels (1 = mono, 2 = stereo).

    Returns:
        The opened ``wave`` writer object.
    """
    # 'wb' creates the file when missing and truncates it when present, so no
    # existence check is needed.
    wav_file = wave.open(file_name, 'wb')
    wav_file.setnchannels(channels)
    wav_file.setsampwidth(sample_width)
    wav_file.setframerate(sample_rate)
    return wav_file


def write_pcm_data_to_wav(wav_file, pcm_data):
    """Append raw PCM bytes to an already opened WAV file.

    Args:
        wav_file: WAV writer returned by ``create_wav_file``.
        pcm_data: Bytes of PCM audio data to write.
    """
    wav_file.writeframesraw(pcm_data)


def close_wav_file(wav_file):
    """Close a WAV file.

    Args:
        wav_file: WAV writer returned by ``create_wav_file``.
    """
    wav_file.close()


# Build the request header.
def generate_header(message_type=FULL_CLIENT_REQUEST,
                    message_type_specific_flags=NO_SEQUENCE,
                    serial_method=JSON,
                    compression_type=GZIP,
                    reserved_data=0x00):
    """Build the 4-byte message header.

    Layout (high nibble | low nibble):
        byte 0: protocol version | header size
        byte 1: message type | message-type-specific flags
        byte 2: serialization method | compression method
        byte 3: reserved

    Returns:
        A ``bytearray`` holding the four header bytes.
    """
    header = bytearray()
    header.append((PROTOCOL_VERSION << 4) | DEFAULT_HEADER_SIZE)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    return header


# Build the data that precedes the payload.
def generate_before_payload(sequence: int):
    """Return ``sequence`` encoded as a 4-byte big-endian signed integer."""
    before_payload = bytearray()
    before_payload.extend(sequence.to_bytes(4, 'big', signed=True))  # sequence
    return before_payload


# Build the request parameters.
def construct_request(reqid):
    """Return the JSON-serialisable parameters of the initial request.

    Args:
        reqid: Request id. Accepted for interface compatibility; the returned
            dictionary does not contain it.
    """
    req = {
        "user": {
            "uid": "test",
        },
        "audio": {
            'format': "pcm",
            "sample_rate": 16000,
            "bits": 16,
            "channel": 1,
            "codec": "raw",
        },
        "request": {
            "model_name": "bigmodel",
            "enable_punc": True,
            "enable_ddc": True,
            "show_utterances": False
        }
    }
    return req


def construct_full_request(reqid, seq):
    """Build the initial full client request frame.

    The frame is: header (POS_SEQUENCE flag) + 4-byte sequence + 4-byte payload
    size + gzip-compressed JSON parameters.

    Args:
        reqid: Request id, forwarded to ``construct_request``.
        seq: Sequence number written into the frame.

    Returns:
        The complete frame as a ``bytearray``.
    """
    request_params = construct_request(reqid)
    payload_bytes = str.encode(json.dumps(request_params))
    payload_bytes = gzip.compress(payload_bytes)
    full_client_request = bytearray(generate_header(message_type_specific_flags=POS_SEQUENCE))
    full_client_request.extend(generate_before_payload(sequence=seq))
    full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size
    full_client_request.extend(payload_bytes)  # payload
    return full_client_request


# Build an audio-only request.
def construct_audioData_request(audio_data, seq):
    """Build an audio-only request frame for one chunk of audio.

    The frame is: header (AUDIO_ONLY_REQUEST, POS_SEQUENCE) + 4-byte sequence +
    4-byte payload size + gzip-compressed audio bytes.

    Args:
        audio_data: Raw audio bytes to send.
        seq: Sequence number written into the frame.

    Returns:
        The complete frame as a ``bytearray``.
    """
    payload_bytes = gzip.compress(audio_data)
    audio_only_request = bytearray(
        generate_header(message_type=AUDIO_ONLY_REQUEST, message_type_specific_flags=POS_SEQUENCE))
    audio_only_request.extend(generate_before_payload(sequence=seq))
    audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size
    audio_only_request.extend(payload_bytes)  # payload
    return audio_only_request


# Parse a response.
def parse_response(res):
    """Parse a binary frame received from the upstream server.

    Args:
        res: The raw frame bytes.

    Returns:
        A dictionary that always contains ``is_last_package``. Depending on the
        flags and message type it may also contain ``payload_sequence``,
        ``seq``, ``code``, and - when the frame carries a payload -
        ``payload_msg`` and ``payload_size``.
    """
    # Byte 0 low nibble: header size in 4-byte units.
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    payload = res[header_size * 4:]
    result = {'is_last_package': False}
    payload_msg = None
    payload_size = 0
    if message_type_specific_flags & 0x01:
        # A sequence number precedes the rest of the payload.
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['payload_sequence'] = seq
        payload = payload[4:]
    if message_type_specific_flags & 0x02:
        result['is_last_package'] = True

    if message_type == FULL_SERVER_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]

    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result


def _build_ssl_context():
    """Return the TLS context used for the upstream connection.

    NOTE(security): hostname checking and certificate verification are
    disabled, exactly as in the original code. This leaves the upstream
    connection open to man-in-the-middle attacks and should be reviewed.
    """
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    return ssl_context


def _build_upstream_headers(reqid):
    """Return the HTTP headers sent when connecting to the upstream server.

    NOTE(security): the access key and app key are hard-coded credentials;
    they should be rotated and loaded from configuration instead.
    """
    return {
        "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
        "X-Api-Access-Key": "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9",
        "X-Api-App-Key": "7282190702",
        "X-Api-Request-Id": reqid
    }


async def _close_frontend(websocket: WebSocket):
    """Close the frontend socket unless either side has already closed it.

    ``WebSocketState`` has no ``CLOSING`` member, so only ``DISCONNECTED`` is
    checked. The application-side state is checked as well so that a second
    close is not sent after this service already closed the socket.
    """
    if (websocket.client_state != WebSocketState.DISCONNECTED
            and websocket.application_state != WebSocketState.DISCONNECTED):
        await websocket.close()


# Forward frontend audio to the target WebSocket server.
async def forward_to_server(websocket: WebSocket, client_ws: websockets.WebSocketClientProtocol, seq: int):
    """Relay binary audio frames from the frontend to the upstream server.

    Runs until an empty frame is received, the frontend disconnects, or an
    error occurs. In the latter two cases the upstream connection is closed.

    Args:
        websocket: The accepted frontend WebSocket.
        client_ws: The open upstream connection.
        seq: Sequence number of the last frame already sent upstream; it is
            incremented before each audio frame is sent.
    """
    global wav_file, total_chunk_size, total_chunks
    try:
        while True:
            audio_data = await websocket.receive_bytes()
            if not audio_data:
                break
            if write_to_file:
                write_pcm_data_to_wav(wav_file, audio_data)
                total_chunk_size += len(audio_data)
                total_chunks += 1
                print(f"write {len(audio_data)} to file,total={total_chunk_size} {total_chunks}")
            seq += 1  # advance the sequence number before sending
            await client_ws.send(construct_audioData_request(audio_data, seq))
    except WebSocketDisconnect:
        logger.info("Client disconnected, closing connection to target server.")
        await client_ws.close()
    except Exception as e:
        logger.error(f"Error in forward_to_server: {e}")
        await client_ws.close()


# Forward upstream responses to the frontend client.
async def forward_to_client(websocket: WebSocket, client_ws: websockets.WebSocketClientProtocol):
    """Relay parsed upstream responses to the frontend as JSON text frames.

    Runs until the upstream connection closes or an error occurs; in both
    cases the frontend socket is closed if it is still open.

    Args:
        websocket: The accepted frontend WebSocket.
        client_ws: The open upstream connection.
    """
    try:
        while True:
            response = await client_ws.recv()
            if response:
                data = parse_response(response)
                await websocket.send_text(json.dumps(data))
    except websockets.exceptions.ConnectionClosed:
        logger.warning(f"Target front client connection closed. {websocket.client_state}")
        await _close_frontend(websocket)
    except Exception as e:
        logger.error(f"Error in forward_to_client: {e} ")
        # Close the debug WAV file, if one was actually opened.
        if write_to_file and wav_file is not None:
            close_wav_file(wav_file)
        await _close_frontend(websocket)


async def send_full_request():
    """Connect to the upstream server, send one full request, print the reply.

    No frontend socket is involved, so there is nothing to close besides the
    upstream connection, which the ``async with`` block handles.
    """
    reqid = str(uuid.uuid4())
    seq = 1
    try:
        # Connecting with ``extra_headers`` requires websockets 12.0:
        # pip install websockets==12.0
        async with websockets.connect(TARGET_WS_URL,
                                      extra_headers=_build_upstream_headers(reqid),
                                      ssl=_build_ssl_context()) as client_ws:
            await client_ws.send(construct_full_request(reqid, seq))
            print("send full_request to doubao server")
            res = await client_ws.recv()
            result = parse_response(res)
            print("connect response", result)
    except WebSocketDisconnect:
        logger.info("Frontend client disconnected.")
    except Exception as e:
        logger.warning(f"Error in WebSocket connection: {e}")


@asr_router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one frontend WebSocket session to the upstream ASR server.

    Accepts the frontend connection, opens the upstream connection, sends the
    initial full request, then runs the two forwarding tasks until both end.
    The debug WAV file (if enabled) and the frontend socket are closed on exit.
    """
    global wav_file, write_to_file
    reqid = str(uuid.uuid4())
    seq = 1
    await websocket.accept()

    try:
        async with websockets.connect(TARGET_WS_URL,
                                      extra_headers=_build_upstream_headers(reqid),
                                      ssl=_build_ssl_context()) as client_ws:
            await client_ws.send(construct_full_request(reqid, seq))
            print("send full_request to doubao server")
            res = await client_ws.recv()
            result = parse_response(res)
            print("connect response", result)
            if write_to_file:
                wav_file = create_wav_file('test.wav', 2, 16000, 1)
            print("start two task")
            task1 = asyncio.create_task(forward_to_server(websocket, client_ws, seq))
            task2 = asyncio.create_task(forward_to_client(websocket, client_ws))
            await asyncio.gather(task1, task2)
    except WebSocketDisconnect:
        logger.info("Frontend client disconnected.")
    except Exception as e:
        logger.warning(f"front WebSocket client connection: {e}")
    finally:
        try:
            # ``wav_file`` stays None when the upstream connection failed
            # before the file was created.
            if write_to_file and wav_file is not None:
                close_wav_file(wav_file)
                print("close wav file")
            await _close_frontend(websocket)
        except Exception as e:
            logger.error(f"Error while closing websocket: {e}")


# REST API route.
@asr_router.get("/health")
async def health():
    """Return a static health-check payload."""
    return {"message": "I am ASR service;I am OK"}