Files
ragflow_python/asr-monitor-test/app/chat_service.py

148 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import array
import asyncio
import base64
import binascii
import datetime
import gzip
import hmac
import io
import json
import logging
import os
import queue
import re
import threading
import time
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import timedelta
from io import BytesIO
from threading import Lock, Thread
from timeit import default_timer as timer
from typing import Optional, Dict, Any

from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi import HTTPException
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query, Depends
from fastapi.responses import StreamingResponse, JSONResponse, Response
from openai import OpenAI

chat_router = APIRouter()

# Read the DashScope API key from the environment and fail fast at import
# time when it is missing, instead of failing on the first request.
openai_api_key = os.getenv("DASHSCOPE_API_KEY")
if not openai_api_key:
    raise RuntimeError("DASHSCOPE_API_KEY environment variable not set")


class MillisecondsFormatter(logging.Formatter):
    """Log formatter whose ``%(asctime)s`` field ends in 3-digit milliseconds."""

    def formatTime(self, record, datefmt=None):
        """Return the record's creation time followed by ``.mmm``.

        Args:
            record: The log record; its ``created`` and ``msecs`` attributes
                are read.
            datefmt: Optional ``time.strftime`` pattern for the part before
                the milliseconds. When omitted, ``"%H:%M:%S"`` is used.

        Returns:
            The formatted time, e.g. ``"13:05:09.042"``.
        """
        # Convert the epoch timestamp to a time tuple with the formatter's
        # converter (``self.converter``).
        time_tuple = self.converter(record.created)
        # Honor an explicit datefmt instead of silently discarding it.
        base = time.strftime(datefmt or "%H:%M:%S", time_tuple)
        # Append the milliseconds, zero-padded to three digits.
        return f"{base}.{int(record.msecs):03d}"


# Configure the process-wide log format.
def configure_logging():
    """Send root-logger output to a single console handler.

    Handlers already attached to the root logger are detached first, then one
    ``StreamHandler`` using ``MillisecondsFormatter`` is attached and the root
    level is set to ``INFO``.
    """
    formatter = MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")

    root_logger = logging.getLogger()
    # Detach through removeHandler() rather than rebinding ``handlers``, so
    # the logging module's own API manages the list. Iterate over a copy
    # because removeHandler() mutates the list.
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)

    # Console handler carrying the millisecond formatter.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)


# Apply the configuration once, when this module is first imported.
configure_logging()


# Simple API key verification (optional).
def verify_api_key(request: Request):
    """Check the ``X-API-KEY`` request header against the configured key.

    The expected key is read from the ``APP_API_KEY`` environment variable on
    each call, falling back to ``"default_key"`` when it is unset.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        ``True`` when the header matches the expected key.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    # NOTE(review): with APP_API_KEY unset, the publicly visible fallback
    # "default_key" is accepted. Consider failing closed in production.
    expected = os.getenv("APP_API_KEY", "default_key")
    provided = request.headers.get("X-API-KEY")
    # Compare in constant time so the check does not reveal how much of the
    # key matched. Encoding to bytes lets compare_digest accept non-ASCII text.
    if provided is None or not hmac.compare_digest(
        provided.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True


def _parse_chat_request(request_data):
    """Validate a decoded request body and extract the chat parameters.

    Args:
        request_data: The JSON-decoded request body.

    Returns:
        A ``(messages, model, temperature, max_tokens)`` tuple. Absent
        optional fields default to ``"qwen-plus-latest"``, ``0.7`` and
        ``500`` respectively.

    Raises:
        HTTPException: 400 when the body is not a JSON object, ``messages``
            is missing, not a list or empty, a numeric option cannot be
            converted, or a message is not an object with string ``role``
            and ``content`` fields.
    """
    if not isinstance(request_data, dict):
        raise HTTPException(status_code=400, detail="Missing or invalid 'messages' field")
    messages = request_data.get("messages")
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="Missing or invalid 'messages' field")

    model = request_data.get("model", "qwen-plus-latest")
    try:
        temperature = float(request_data.get("temperature", 0.7))
        max_tokens = int(request_data.get("max_tokens", 500))
    except (TypeError, ValueError) as exc:
        # TypeError covers JSON null or nested values; both are client errors.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    for msg in messages:
        if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
            raise HTTPException(status_code=400, detail="Invalid message structure")
        if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
            raise HTTPException(status_code=400, detail="Message content must be strings")

    return messages, model, temperature, max_tokens


@chat_router.post("/completion")
async def chat_completion(request: Request):
    """
    Chat with the large language model.

    Example request body:
    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello, please introduce FastAPI"}
        ],
        "model": "qwen-plus-latest",
        "temperature": 0.7,
        "max_tokens": 500
    }

    Only "messages" is required. The response has the form
    {"message": {"role": ..., "content": ...}, "finish_reason": ...}.

    Errors: 400 for malformed JSON or an invalid body, 500 when the model
    returns no choices or any other failure occurs.
    """
    try:
        # Parse the body manually so malformed JSON maps to our own 400.
        request_data = await request.json()
        messages, model, temperature, max_tokens = _parse_chat_request(request_data)
        logging.info("Received chat request: model=%s, messages=%d", model, len(messages))

        client = OpenAI(
            # The key comes from the DASHSCOPE_API_KEY environment variable.
            api_key=os.getenv("DASHSCOPE_API_KEY"),
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )

        # The SDK call is synchronous; run it in the default executor so the
        # event loop keeps serving other requests while waiting on the model.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            ),
        )

        if not response.choices:
            raise HTTPException(status_code=500, detail="No response from AI model")

        choice = response.choices[0]
        message = choice.message
        return {
            "message": {
                "role": message.role,
                "content": message.content,
            },
            "finish_reason": choice.finish_reason,
        }
    except HTTPException:
        # Pass through our own errors; otherwise the generic handler below
        # would rewrite every 400 as a 500.
        raise
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logging.exception("Internal server error")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(exc)}") from exc