148 lines
5.0 KiB
Python
148 lines
5.0 KiB
Python
import array
import asyncio
import base64
import binascii
import datetime
import gzip
import io
import json
import logging
import os
import queue
import re
import secrets
import threading
import time
import time
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import timedelta
from io import BytesIO
from threading import Lock, Thread
from timeit import default_timer as timer
from typing import Optional, Dict, Any

from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi import HTTPException
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query, Depends
from fastapi.responses import StreamingResponse, JSONResponse, Response
from openai import OpenAI
|
||
chat_router = APIRouter()

# The DashScope key is mandatory: fail at import time instead of on the first
# request. DashScope is reached through its OpenAI-compatible endpoint, which
# is why the variable carries an "openai" name.
openai_api_key = os.environ.get("DASHSCOPE_API_KEY")
if openai_api_key is None or openai_api_key == "":
    raise RuntimeError("DASHSCOPE_API_KEY environment variable not set")
|
||
|
||
class MillisecondsFormatter(logging.Formatter):
    """Log formatter whose ``%(asctime)s`` ends with a millisecond suffix.

    By default the timestamp looks like ``HH:MM:SS.mmm``. When a ``datefmt``
    is supplied (to the constructor or directly to ``formatTime``), it
    replaces ``"%H:%M:%S"`` as the base format and the ``.mmm`` suffix is
    still appended.
    """

    def formatTime(self, record, datefmt=None):
        """Return the record's creation time formatted with milliseconds.

        Args:
            record: The ``logging.LogRecord`` being formatted.
            datefmt: Optional ``time.strftime`` format for the part before the
                milliseconds; falls back to ``"%H:%M:%S"`` when empty/None.

        Returns:
            The formatted timestamp string, e.g. ``"14:03:07.042"``.
        """
        # Convert the epoch timestamp to a struct_time using the formatter's
        # converter (time.localtime unless overridden).
        ct = self.converter(record.created)
        # datefmt used to be silently ignored; honour it like the base class.
        base = time.strftime(datefmt or "%H:%M:%S", ct)
        # Append zero-padded, 3-digit milliseconds.
        return f"{base}.{int(record.msecs):03d}"
|
||
|
||
# 配置全局日志格式
|
||
def configure_logging(level=logging.INFO):
    """Route root-logger output to the console with millisecond timestamps.

    Any handlers already attached to the root logger are detached and closed
    first, so repeated calls do not duplicate output or leak handler resources.

    Args:
        level: Level applied to the root logger (default ``logging.INFO``).
    """
    formatter = MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")

    root_logger = logging.getLogger()
    # Detach each existing handler through the logger API and close it, so
    # resources such as open log files are released. Merely reassigning
    # ``root_logger.handlers = []`` left them open.
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()

    # StreamHandler() with no argument writes to sys.stderr.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    root_logger.setLevel(level)
    root_logger.addHandler(console_handler)
|
||
|
||
# Call the configuration function once, when this module is imported.
configure_logging()
|
||
|
||
# Simple API key check (optional).
def verify_api_key(request: Request):
    """Check the ``X-API-KEY`` request header against the configured key.

    The expected key is read from the ``APP_API_KEY`` environment variable,
    falling back to ``"default_key"`` when it is unset. Nothing in the visible
    code calls this; it is presumably meant to be used as a FastAPI dependency
    (``Depends(verify_api_key)``) — confirm at call sites.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        True when the header matches the expected key.

    Raises:
        HTTPException: status 401 when the header is missing or does not match.
    """
    expected = os.getenv("APP_API_KEY", "default_key")
    # NOTE(review): the "default_key" fallback means an unconfigured deployment
    # accepts a publicly known key; consider failing closed instead.
    provided = request.headers.get("X-API-KEY")
    if provided is None:
        raise HTTPException(status_code=401, detail="Invalid API key")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes because compare_digest raises TypeError
    # for non-ASCII str input, which would otherwise surface as a 500.
    if not secrets.compare_digest(
        provided.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
|
||
|
||
|
||
def _parse_chat_request(request_data):
    """Validate a decoded ``/completion`` request body and extract its parameters.

    Args:
        request_data: The value produced by ``await request.json()``.

    Returns:
        A ``(messages, model, temperature, max_tokens)`` tuple. ``model``,
        ``temperature`` and ``max_tokens`` default to ``"qwen-plus-latest"``,
        ``0.7`` and ``500``.

    Raises:
        HTTPException: status 400 for any structural or type problem, so that
        client mistakes are reported as client errors rather than as 500s.
    """
    if not isinstance(request_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    messages = request_data.get("messages")
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="Missing or invalid 'messages' field")

    for msg in messages:
        # Check the type first: `"role" in msg` does a substring test on a str
        # and raises TypeError on numbers/None.
        if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
            raise HTTPException(status_code=400, detail="Invalid message structure")
        if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
            raise HTTPException(status_code=400, detail="Message content must be strings")

    model = request_data.get("model", "qwen-plus-latest")
    if not isinstance(model, str):
        raise HTTPException(status_code=400, detail="'model' must be a string")

    try:
        temperature = float(request_data.get("temperature", 0.7))
        max_tokens = int(request_data.get("max_tokens", 500))
    except (TypeError, ValueError) as e:
        # TypeError covers explicit nulls / non-scalar values, which previously
        # escaped as a 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
    if max_tokens <= 0:
        raise HTTPException(status_code=400, detail="'max_tokens' must be a positive integer")

    return messages, model, temperature, max_tokens


def _create_completion(model, messages, temperature, max_tokens):
    """Blocking call to DashScope's OpenAI-compatible chat completions API.

    A client is created for the call and closed afterwards (via ``with``) so
    its HTTP connection pool is not leaked on every request.
    """
    with OpenAI(
        # Without the environment variable, replace with api_key="sk-xxx"
        # (a Bailian/DashScope API key).
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    ) as client:
        return client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            # Previously parsed but never forwarded, so the limit was ignored.
            max_tokens=max_tokens,
        )


@chat_router.post("/completion")
async def chat_completion(request: Request):
    """
    Chat with a large language model (DashScope, OpenAI-compatible API).

    Request body example:

    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello, please introduce FastAPI"}
        ],
        "model": "qwen-plus-latest",
        "temperature": 0.7,
        "max_tokens": 500
    }

    Only "messages" is required; "model", "temperature" and "max_tokens"
    default to "qwen-plus-latest", 0.7 and 500.

    Response: {"message": {"role": ..., "content": ...}, "finish_reason": ...}.
    Errors: 400 for malformed JSON or invalid fields; 500 when the model call
    fails or returns no choices.
    """
    # Parse the request body manually.
    try:
        request_data = await request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e
    except ValueError as e:
        # e.g. UnicodeDecodeError for a body that is not valid UTF-8.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Validation happens outside the generic handler below, so its 400s are no
    # longer rewritten into "500 Internal server error: 400: ...".
    messages, model, temperature, max_tokens = _parse_chat_request(request_data)

    logging.info("Received chat request: model=%s, messages=%d", model, len(messages))

    try:
        # The OpenAI client is synchronous; run it in the default thread pool
        # so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, _create_completion, model, messages, temperature, max_tokens
        )

        if not response.choices:
            logging.error("Model %s returned no choices", model)
            raise HTTPException(status_code=500, detail="No response from AI model")

        choice = response.choices[0]
        message = choice.message
        return {
            "message": {
                "role": message.role,
                "content": message.content,
            },
            "finish_reason": choice.finish_reason,
        }
    except HTTPException:
        # Already carries the intended status/detail; don't wrap it again.
        raise
    except Exception as e:
        logging.exception("Internal server error")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e