148 lines
5.0 KiB
Python
148 lines
5.0 KiB
Python
|
|
import array
import asyncio
import base64
import binascii
import datetime
import gzip
import hmac
import io
import json
import logging
import os
import queue
import re
import threading
import time
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import timedelta
from io import BytesIO
from threading import Lock, Thread
from timeit import default_timer as timer
from typing import Optional, Dict, Any

from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi import HTTPException
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query, Depends
from fastapi.responses import StreamingResponse, JSONResponse, Response
from openai import OpenAI
|
|||
|
|
# Router on which the chat endpoints of this module are registered.
chat_router = APIRouter()

# Read the model-service API key from the environment and fail fast at
# import time when it is missing or empty.
openai_api_key = os.environ.get("DASHSCOPE_API_KEY")
if openai_api_key is None or openai_api_key == "":
    raise RuntimeError("DASHSCOPE_API_KEY environment variable not set")
|
|||
|
|
|
|||
|
|
class MillisecondsFormatter(logging.Formatter):
    """Log formatter whose ``%(asctime)s`` field ends in a millisecond suffix."""

    def formatTime(self, record, datefmt=None):
        """Render the record's creation time followed by ``.mmm`` milliseconds.

        Args:
            record: The ``logging.LogRecord`` being formatted; ``created`` and
                ``msecs`` are read from it.
            datefmt: Optional ``time.strftime`` pattern for the part before
                the milliseconds. Defaults to ``"%H:%M:%S"`` when omitted.

        Returns:
            The formatted time, e.g. ``"13:05:09.042"``.
        """
        # self.converter is time.localtime unless a subclass/instance overrides it.
        ct = self.converter(record.created)
        # Honour a caller-supplied datefmt instead of silently ignoring it.
        t = time.strftime(datefmt or "%H:%M:%S", ct)
        # Append the milliseconds, zero-padded to three digits.
        return f"{t}.{int(record.msecs):03d}"
|
|||
|
|
|
|||
|
|
# Configure the global log format.
def configure_logging():
    """Reset the root logger to one console handler with millisecond timestamps.

    Any handlers already attached to the root logger are discarded, the root
    level is set to INFO, and a single ``StreamHandler`` using
    ``MillisecondsFormatter`` is installed.
    """
    root = logging.getLogger()

    # Discard any previously attached handlers and set the threshold.
    root.handlers = []
    root.setLevel(logging.INFO)

    # Console handler rendering "<time> - <level> - <message>".
    handler = logging.StreamHandler()
    handler.setFormatter(
        MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")
    )
    root.addHandler(handler)
|
|||
|
|
|
|||
|
|
# Apply the logging configuration; this runs once, when the module is imported.
configure_logging()
|
|||
|
|
|
|||
|
|
# Simple API key verification (optional).
def verify_api_key(request: Request):
    """Check the ``X-API-KEY`` request header against the configured key.

    The expected key is read from the ``APP_API_KEY`` environment variable on
    every call.

    Args:
        request: Incoming request whose headers carry the key.

    Returns:
        ``True`` when the header matches the expected key.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    # NOTE(review): the "default_key" fallback is kept for backward
    # compatibility, but with APP_API_KEY unset it makes a publicly known
    # value a valid credential -- always set APP_API_KEY in real deployments.
    expected = os.getenv("APP_API_KEY", "default_key")
    provided = request.headers.get("X-API-KEY")

    # Compare in constant time so response timing does not reveal how much of
    # the key matched. Encoding to bytes lets compare_digest accept non-ASCII.
    if provided is None or not hmac.compare_digest(
        provided.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_chat_request(request_data):
    """Validate a decoded request body and extract the chat parameters.

    Args:
        request_data: The JSON-decoded request body.

    Returns:
        A ``(messages, model, temperature, max_tokens)`` tuple.

    Raises:
        HTTPException: 400 when the body is not an object, ``messages`` is
            missing or not a list, a numeric parameter cannot be converted,
            or a message lacks string ``role``/``content`` fields.
    """
    # Required field: the body must be an object holding a "messages" list.
    if not isinstance(request_data, dict) or not isinstance(
        request_data.get("messages"), list
    ):
        raise HTTPException(status_code=400, detail="Missing or invalid 'messages' field")

    # Extract parameters, falling back to defaults.
    messages = request_data["messages"]
    model = request_data.get("model", "qwen-plus-latest")
    try:
        temperature = float(request_data.get("temperature", 0.7))
        max_tokens = int(request_data.get("max_tokens", 500))
    except (TypeError, ValueError) as exc:
        # TypeError covers JSON null / nested values, ValueError bad strings.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Validate the structure of every message.
    for msg in messages:
        if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
            raise HTTPException(status_code=400, detail="Invalid message structure")
        if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
            raise HTTPException(status_code=400, detail="Message content must be strings")

    return messages, model, temperature, max_tokens


@chat_router.post("/completion")
async def chat_completion(request: Request):
    """
    Chat with the large language model.

    Example request body:
    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello, please introduce FastAPI"}
        ],
        "model": "qwen-plus-latest",
        "temperature": 0.7,
        "max_tokens": 500
    }

    Only ``messages`` is required; ``model`` defaults to "qwen-plus-latest",
    ``temperature`` to 0.7 and ``max_tokens`` to 500.

    Returns a dict with the first choice's ``message`` (``role`` and
    ``content``) and its ``finish_reason``.

    Raises HTTPException 400 for malformed JSON or invalid fields, and 500
    when the model returns no choices or an unexpected error occurs.
    """
    try:
        # Parse the request body manually, then validate it.
        request_data = await request.json()
        messages, model, temperature, max_tokens = _parse_chat_request(request_data)

        logging.info("Received chat request: model=%s, messages=%d", model, len(messages))
        client = OpenAI(
            # The key is read from the DASHSCOPE_API_KEY environment variable.
            api_key=os.getenv("DASHSCOPE_API_KEY"),
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )

        def _create_completion():
            # max_tokens is forwarded so the documented request field takes effect.
            return client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )

        # The client call is synchronous; run it in the default executor so
        # the event loop keeps serving other requests while it waits.
        response = await asyncio.get_running_loop().run_in_executor(
            None, _create_completion
        )

        # Handle the response.
        if not response.choices:
            raise HTTPException(status_code=500, detail="No response from AI model")

        choice = response.choices[0]
        message = choice.message

        return {
            "message": {
                "role": message.role,
                "content": message.content
            },
            "finish_reason": choice.finish_reason
        }

    except HTTPException:
        # Deliberate HTTP errors must keep their own status code instead of
        # being rewrapped as a 500 by the generic handler below.
        raise
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from None
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logging.exception("Internal server error")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|