Files
ragflow_python/asr-monitor-test/app/system_admin.py
qcloud 074747f902
Some checks failed
tests / ragflow_tests (push) Has been cancelled
106.51.72.204 上的gitea重新初始化,提交到远程
2025-10-09 16:55:45 +08:00

232 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, Depends, HTTPException, status, Request, Response,Query,Header
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import OAuth2PasswordBearer
import hashlib
import random
import string
import xml.etree.ElementTree as ET
import requests
import time
from datetime import datetime, timedelta, date
from decimal import Decimal
from uuid import UUID
import json
import logging
from app.database import *
from jose import JWTError, jwt
import base64
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.backends import default_backend
import httpx,threading,asyncio
class CustomJSONResponse(JSONResponse):
    """JSONResponse that can serialise a few types the stdlib encoder rejects.

    Conversions applied by the encoder:
    - datetime / date / datetime.time -> ISO 8601 string
    - Decimal -> float
    - UUID -> str
    - any object exposing ``__json__()`` -> the value that method returns

    numpy scalars/arrays are NOT handled (numpy is not imported in this
    module); they fall through to ``json.JSONEncoder.default`` and raise
    ``TypeError`` like any other unsupported type.
    """

    def render(self, content: object) -> bytes:
        """Serialise *content* to compact UTF-8 JSON using the enhanced encoder.

        Non-ASCII text is written as-is (``ensure_ascii=False``) and NaN /
        Infinity are rejected with ``ValueError`` (``allow_nan=False``).
        """
        # The module-level name ``time`` is the ``time`` module, not the
        # datetime.time class, so import the class under an alias here.
        from datetime import time as dt_time

        class EnhancedJSONEncoder(json.JSONEncoder):
            def default(self, obj):
                """Convert the supported special types; defer everything else."""
                # datetime is a subclass of date, so this covers both, plus
                # time-of-day values (e.g. from TIME columns).
                if isinstance(obj, (date, dt_time)):
                    return obj.isoformat()
                if isinstance(obj, Decimal):
                    return float(obj)
                if isinstance(obj, UUID):
                    return str(obj)
                # Project-defined types may opt in by implementing __json__().
                if hasattr(obj, "__json__"):
                    return obj.__json__()
                # Raises TypeError for anything unsupported.
                return super().default(obj)

        return json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
            cls=EnhancedJSONEncoder,
        ).encode("utf-8")
system_admin_router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
logger = logging.getLogger("system_admin")

import os  # noqa: E402 - only needed for the JWT secret override below

# Secret used to sign and verify the HS256 JWTs issued by this module.
# SECURITY: a signing key committed to source control lets anyone with
# repository access mint valid tokens. Set the JWT_SECRET_KEY environment
# variable in every deployment; the literal is kept only as a fallback so
# existing tokens and unconfigured environments keep working. An empty
# variable is treated as unset so it can never produce an empty key.
JWT_SECRET_KEY = (
    os.environ.get("JWT_SECRET_KEY")
    or "3e5b8d7f1a9c2b6d4e0f1a9c2b6d4e0f1a9c2b6d4e0f1a9c2b6d4e0f1a9c2b6d"
)
ALGORITHM = "HS256"
# JWT helpers
def create_jwt(user_id: str, expires_delta: timedelta = timedelta(days=7)) -> str:
    """Issue an HS256-signed JWT whose subject is *user_id*.

    Args:
        user_id: identifier stored in the ``sub`` claim. It is coerced to
            ``str`` because python-jose rejects a non-string ``sub`` when the
            token is decoded, which would make tokens issued for numeric ids
            (e.g. a phone number sent as a JSON number) unusable.
        expires_delta: token lifetime measured from now; defaults to 7 days,
            the previously hard-coded value.

    Returns:
        The encoded token string.
    """
    from datetime import timezone

    payload = {
        "sub": str(user_id),
        # Timezone-aware "now": datetime.utcnow() is deprecated since 3.12.
        "exp": datetime.now(timezone.utc) + expires_delta,
    }
    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the user referenced by a bearer JWT.

    Returns whatever ``get_user_by_id`` returns for the token's ``sub`` claim,
    or ``None`` when the token fails verification (bad signature, expired,
    malformed), carries no ``sub``, or the lookup raises ``StopIteration``
    (presumably "no such user" - verify against ``get_user_by_id``).

    Note: ``oauth2_scheme`` is built with FastAPI's default ``auto_error=True``,
    so a request without any bearer token is rejected with 401 before this
    dependency runs; only *invalid* tokens are mapped to ``None`` here.
    """
    try:
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    user_id = payload.get("sub")
    if user_id is None:
        # A token without a subject identifies nobody; don't query with None.
        return None
    try:
        return get_user_by_id(user_id)
    except StopIteration:
        return None
@system_admin_router.get("/get_subscriptions")
async def get_museum_subscriptions(
    museum_id: Optional[int] = Query(None, description="博物馆ID不提供则返回所有博物馆"),
    page: int = Query(1, ge=1, description="页码从1开始"),
    page_size: int = Query(50, ge=1, le=100, description="每页记录数最大100"),
    # NOTE(review): authentication is disabled - the dependency below is
    # commented out, so any caller can list subscriptions.
    # current_user: dict = Depends(get_current_user)
):
    """List museum subscriptions one page at a time.

    FastAPI validates and converts the query parameters declared above before
    this body runs. The paginated lookup's result is read with ``.get`` for
    its ``data`` and ``pagination`` entries, which are copied into the
    response envelope.
    """
    paged = get_all_users_subscriptions_paginated(museum_id, page, page_size)
    envelope = {"code": 0, "status": "success"}
    envelope["data"] = paged.get('data')
    envelope["pagination"] = paged.get("pagination")
    return CustomJSONResponse(envelope)
def generate_json_response(code: str, message: str) -> JSONResponse:
    """Wrap *code* and *message* in a JSON body that is always sent with HTTP 200."""
    body = dict(code=code, message=message)
    return JSONResponse(body, status_code=200)
def _password_matches(stored, supplied) -> bool:
    """Compare the stored and supplied passwords in constant time."""
    import hmac

    # Anything other than two strings is a mismatch. This also stops a NULL
    # stored password from matching a JSON null in the request (None == None).
    if not isinstance(stored, str) or not isinstance(supplied, str):
        return False
    # compare_digest only accepts ASCII-only str, so compare UTF-8 bytes.
    return hmac.compare_digest(stored.encode("utf-8"), supplied.encode("utf-8"))


def _login_failure_response(msg: str) -> JSONResponse:
    """Build the login error envelope; it deliberately carries no usable token."""
    return JSONResponse({
        "code": 0,
        "status": "error",
        "data": {
            "status": "error",
            "msg": msg,
            # A valid JWT used to be returned here, which let anyone obtain a
            # token for any user name without knowing the password.
            "token": "",
            "user_info": {},
            "menu_authed": [],
            "museum_authed": [],
        },
    })


def _museum_field(item, name):
    """Read *name* from a museum record given either as a dict or as an object."""
    if isinstance(item, dict):
        return item.get(name)
    return getattr(item, name, None)


def _museum_authed_list(museum_authed_str):
    """Return ``[{"id": int, "name": ...}]`` for the museums the account may access."""
    # An empty/NULL authorisation string means "all museums".
    if museum_authed_str:
        records = get_museums(id_list=museum_authed_str)
    else:
        records = get_museums()
    logger.debug("museums for login: %s", records)
    # NOTE(review): the two get_museums() calls were previously read with
    # attribute access and dict access respectively; both shapes are accepted
    # here so a shape mismatch can no longer silently yield an empty list.
    authed = []
    for item in records or []:
        museum_id = _museum_field(item, "id")
        museum_name = _museum_field(item, "name")
        if museum_id is not None and museum_name is not None:
            authed.append({"id": int(museum_id), "name": museum_name})
    return authed


def _menu_authed_list(menu_authed_str):
    """Split the comma-separated menu permission string into its entries."""
    # A NULL column or blank entries (e.g. "" or "a,,b") yield no permissions.
    entries = (menu_authed_str or "").split(",")
    return [entry.strip() for entry in entries if entry.strip()]


@system_admin_router.post("/login")
async def test_account_login(request: Request):
    """Authenticate an admin account by phone number and password.

    Expects a JSON object body with ``user`` (looked up as the account phone)
    and ``password``. Credential failures still answer HTTP 200 with
    ``code`` 0; the outcome is reported in ``status`` / ``data.status``. On
    success ``data`` carries a JWT plus the authorised menus and museums.

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
            or lacks a required field.
    """
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
        raise HTTPException(400, "Invalid JSON") from None

    required_fields = ["user", "password"]
    # A JSON list/string would otherwise pass the ``in`` test and crash on .get().
    if not isinstance(data, dict) or not all(k in data for k in required_fields):
        raise HTTPException(400, "Missing required fields")
    user = data.get('user')
    password = data.get('password')

    account_info = get_admin_account_info(phone=user)
    if not account_info:
        # NOTE(review): distinct "no account" / "wrong password" messages allow
        # account enumeration; kept as-is for client compatibility.
        return _login_failure_response("账户不存在")
    account_info = account_info[0]

    # NOTE(review): the stored value is compared directly, so passwords appear
    # to be stored in plaintext - consider a salted hash.
    if not _password_matches(account_info['password'], password):
        return _login_failure_response("密码不正确")

    museum_authed_str = account_info.get("museum_authed", None)
    menu_authed_str = account_info.get("menu_authed", "")
    # Never log the account row or the password: both contain the credential.
    logger.info("system account login: user=%s museum_authed=%s", user, museum_authed_str)

    return JSONResponse({
        "code": 0,
        "status": "success",
        "data": {
            "status": "success",
            # str(): the JWT "sub" claim must be a string to be decodable.
            "token": create_jwt(str(user)),
            "user_info": {},
            "menu_authed": _menu_authed_list(menu_authed_str),
            "museum_authed": _museum_authed_list(museum_authed_str),
        },
    })