451 lines
15 KiB
Python
451 lines
15 KiB
Python
from fastapi import WebSocket, APIRouter,WebSocketDisconnect,Request,Body,Query
|
||
from fastapi import FastAPI, UploadFile, File, Form, Header, Depends
|
||
from fastapi.responses import StreamingResponse,JSONResponse
|
||
from fastapi.security import OAuth2PasswordBearer
|
||
from jose import JWTError, jwt
|
||
import logging
|
||
from fastapi import HTTPException
|
||
from Crypto.Cipher import AES
|
||
import base64,uuid,asyncio
|
||
import requests
|
||
from datetime import datetime,timedelta
|
||
from app.database import *
|
||
login_router = APIRouter()
|
||
logger = logging.getLogger("login")
|
||
|
||
# 初始化 OAuth2 方案(必须放在使用它的函数之前)
|
||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") # tokenUrl 对应登录接口路径
|
||
|
||
# 需要配置的参数(从环境变量获取)
|
||
WX_APPID = "wx446813bfb3a6985a" #"wxed388cef83f109a3" # 小程序appid
|
||
WX_SECRET = "a7455fca777ad59ce96cc154d62f795f" #"f687afd2c8fae49b4aed2e4a8dd76e6e" # 小程序密钥
|
||
WX_API_URL = "https://api.weixin.qq.com/sns/jscode2session"
|
||
JWT_SECRET_KEY = "3e5b8d7f1a9c2b6d4e0f1a9c2b6d4e0f1a9c2b6d4e0f1a9c2b6d4e0f1a9c2b6d"
|
||
ALGORITHM = "HS256"
|
||
|
||
# 伪数据库
|
||
fake_db = {
|
||
"users": [],
|
||
"museums": []
|
||
}
|
||
|
||
"""
|
||
1. SECRET_KEY 的作用
|
||
作用 说明
|
||
签名验证 用于签发(sign)和验证(verify)JWT 的合法性,防止令牌被篡改
|
||
安全保障 作为加密盐值(salt),确保令牌无法被伪造或逆向破解
|
||
身份验证 确保令牌是由可信的服务器颁发的,而非第三方伪造的
|
||
"""
|
||
# JWT工具函数
|
||
def create_jwt(user_id: str) -> str:
|
||
payload = {
|
||
"sub": user_id,
|
||
"exp": datetime.utcnow() + timedelta(days=7)
|
||
}
|
||
return jwt.encode(payload, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||
|
||
|
||
def decrypt_data(encrypted_data: str, session_key: str, iv: str) -> dict:
|
||
try:
|
||
|
||
# Base64解码
|
||
session_key_bin = base64.b64decode(session_key + "=") # 补齐可能缺失的padding
|
||
encrypted_data_bin = base64.b64decode(encrypted_data)
|
||
iv_bin = base64.b64decode(iv)
|
||
|
||
# 创建解密器
|
||
cipher = AES.new(session_key_bin, AES.MODE_CBC, iv_bin)
|
||
|
||
# 执行解密
|
||
decrypted = cipher.decrypt(encrypted_data_bin)
|
||
|
||
# 去除PKCS#7填充
|
||
pad = decrypted[-1]
|
||
decrypted = decrypted[:-pad]
|
||
|
||
# 解析JSON
|
||
result = json.loads(decrypted.decode('utf-8'))
|
||
logging.info(f"解密数据: {result}")
|
||
return result
|
||
except json.JSONDecodeError:
|
||
# 特定错误类型识别
|
||
logging.info(f"解密过程失败: {str(e)}")
|
||
raise ValueError("SESSION_KEY_MISMATCH")
|
||
except Exception as e:
|
||
logging.info(f"解密过程失败: {str(e)}")
|
||
if "padding" in str(e).lower():
|
||
raise ValueError("SESSION_KEY_EXPIRED")
|
||
raise
|
||
|
||
async def get_wx_session(code: str):
|
||
"""
|
||
调用微信接口获取session_key和openid
|
||
返回: {"openid": str, "session_key": str}
|
||
"""
|
||
params = {
|
||
"appid": WX_APPID,
|
||
"secret": WX_SECRET,
|
||
"js_code": code,
|
||
"grant_type": "authorization_code"
|
||
}
|
||
|
||
try:
|
||
# 调用微信接口
|
||
response = requests.get(WX_API_URL, params=params, timeout=5)
|
||
response.raise_for_status() # 检查HTTP状态码
|
||
wx_data = response.json()
|
||
except requests.exceptions.RequestException as e:
|
||
raise HTTPException(
|
||
status_code=502,
|
||
detail=f"微信接口请求失败: {str(e)}"
|
||
)
|
||
except ValueError:
|
||
raise HTTPException(
|
||
status_code=502,
|
||
detail="微信接口返回数据解析失败"
|
||
)
|
||
|
||
# 处理微信返回错误
|
||
if "errcode" in wx_data:
|
||
error_map = {
|
||
40029: "无效的code",
|
||
45011: "API调用太频繁",
|
||
-1: "微信系统繁忙"
|
||
}
|
||
detail = error_map.get(
|
||
wx_data["errcode"],
|
||
f"微信接口错误: {wx_data.get('errmsg', '未知错误')}"
|
||
)
|
||
raise HTTPException(
|
||
status_code=401,
|
||
detail=detail
|
||
)
|
||
|
||
if "openid" not in wx_data:
|
||
raise HTTPException(
|
||
status_code=401,
|
||
detail="微信认证失败:缺少openid"
|
||
)
|
||
return wx_data
|
||
|
||
|
||
@login_router.post("/login")
|
||
async def wechat_login(request: Request):
|
||
# 获取原始请求数据
|
||
try:
|
||
data = await request.json()
|
||
except json.JSONDecodeError:
|
||
raise HTTPException(400, "Invalid JSON")
|
||
# 校验必要参数
|
||
required_fields = ["code", "encryptedData", "iv"]
|
||
if not all(k in data for k in required_fields):
|
||
raise HTTPException(400, "Missing required fields")
|
||
code = data.get('code')
|
||
encrypted_data = data['encryptedData']
|
||
iv = data['iv']
|
||
|
||
logging.info(f"wechat login data={data}")
|
||
# 关键修改:增加重试机制
|
||
max_retries = 2
|
||
for attempt in range(max_retries):
|
||
try:
|
||
# 每次尝试都重新获取 session_key
|
||
# 微信小程序登录 code 具有一次性特征(有效期约5分钟)
|
||
# 添加冷启动保护延迟
|
||
await asyncio.sleep(0.5) # 关键延迟
|
||
|
||
wx_data = await get_wx_session(code)
|
||
session_key = wx_data["session_key"]
|
||
openid = wx_data["openid"]
|
||
logging.info(f"get_wx_session return {wx_data}")
|
||
# 解密数据
|
||
try:
|
||
result = decrypt_data(encrypted_data, session_key, iv)
|
||
except ValueError as e:
|
||
# 特定错误处理
|
||
if "SESSION_KEY_" in str(e):
|
||
raise HTTPException(418, "SESSION_KEY_INVALID")
|
||
raise
|
||
|
||
if 'purePhoneNumber' not in result:
|
||
logging.warning("解密数据不包含手机号")
|
||
phone_number = None
|
||
else:
|
||
phone_number = result['purePhoneNumber']
|
||
break # 成功则跳出循环
|
||
|
||
except HTTPException as e:
|
||
if attempt < max_retries - 1:
|
||
# 特定错误时刷新 code
|
||
if "invalid session_key" in str(e).lower():
|
||
logging.warning("Session key 过期,尝试刷新")
|
||
# 触发前端重新获取 code
|
||
raise HTTPException(401, "SESSION_KEY_EXPIRED")
|
||
else:
|
||
logging.error(f"尝试 {attempt + 1} 失败: {str(e)}")
|
||
await asyncio.sleep(1) # 短暂等待后重试
|
||
else:
|
||
logging.error(f"最终解密失败: {str(e)}")
|
||
raise HTTPException(400, "Decrypt failed")
|
||
except Exception as e:
|
||
logging.error(f"解密异常: {str(e)}")
|
||
if attempt == max_retries - 1:
|
||
raise HTTPException(400, "Decrypt failed")
|
||
|
||
logging.info(f"decrypt_data return {phone_number}")
|
||
# ========== 数据库操作开始 ==========
|
||
# 使用数据库查询替代内存查询
|
||
db_users = get_users(openid=openid)
|
||
user = db_users[0] if db_users else None
|
||
|
||
# 用户不存在时创建新用户
|
||
if not user:
|
||
try:
|
||
new_user = {
|
||
"user_id": str(uuid.uuid4()), # 使用UUID生成唯一ID
|
||
"openid": wx_data["openid"],
|
||
"phone": str(phone_number),
|
||
"status": 1, # 默认启用状态
|
||
"balance": 0, # 初始余额设为0
|
||
"is_test_account":0
|
||
# museums字段需要另存关联表,此处暂时保留伪数据
|
||
}
|
||
create_user(new_user) # 调用CRUD创建方法
|
||
user = new_user
|
||
except Exception as e: # 捕获唯一约束等异常
|
||
logging.error(f"User creation failed: {str(e)}")
|
||
raise HTTPException(500, "User registration failed")
|
||
|
||
# 更新最后登录时间
|
||
update_data = {
|
||
"last_login_time": int(datetime.now().timestamp()),
|
||
"token": create_jwt(user["user_id"]) # 生成新token
|
||
}
|
||
updated_user = update_user(user["user_id"], update_data)
|
||
# ========== 数据库操作结束 ==========
|
||
|
||
logging.info(f"login return {user}")
|
||
# 生成token
|
||
return JSONResponse({
|
||
"token": create_jwt(user["user_id"]),
|
||
"user_info": {
|
||
"phone": phone_number,
|
||
"museums": get_museum_avail(user)
|
||
}
|
||
})
|
||
|
||
fake_phone_number = {
|
||
'hxbtest001':'19912345631',
|
||
'hxbtest002':'19912345632',
|
||
'hxbtest003':'19912345633',
|
||
'hxbtest004':'19912345634',
|
||
'hxbtest005':'19912345635',
|
||
}
|
||
@login_router.post("/testAccountLogin")
|
||
async def test_account_login(request: Request):
|
||
# 获取原始请求数据
|
||
try:
|
||
data = await request.json()
|
||
except json.JSONDecodeError:
|
||
raise HTTPException(400, "Invalid JSON")
|
||
# 校验必要参数
|
||
required_fields = ["account"]
|
||
if not all(k in data for k in required_fields):
|
||
raise HTTPException(400, "Missing required fields")
|
||
account = data.get('account')
|
||
phone_number = fake_phone_number.get(account,'19923145671')
|
||
logging.info(f"decrypt_data return {phone_number}")
|
||
# ========== 数据库操作开始 ==========
|
||
# 使用数据库查询替代内存查询
|
||
db_users = get_users(openid=account)
|
||
user = db_users[0] if db_users else None
|
||
|
||
# 用户不存在时创建新用户
|
||
if not user:
|
||
try:
|
||
new_user = {
|
||
"user_id": str(uuid.uuid4()), # 使用UUID生成唯一ID
|
||
"openid": account,
|
||
"phone": phone_number,
|
||
"status": 1, # 默认启用状态
|
||
"balance": 0, # 初始余额设为0
|
||
"is_test_account": 1
|
||
# museums字段需要另存关联表,此处暂时保留伪数据
|
||
}
|
||
create_user(new_user) # 调用CRUD创建方法
|
||
user = new_user
|
||
except Exception as e: # 捕获唯一约束等异常
|
||
logging.error(f"User creation failed: {str(e)}")
|
||
raise HTTPException(500, "User registration failed")
|
||
|
||
# 更新最后登录时间
|
||
update_data = {
|
||
"last_login_time": int(datetime.now().timestamp()),
|
||
"token": create_jwt(user["user_id"]) # 生成新token
|
||
}
|
||
updated_user = update_user(user["user_id"], update_data)
|
||
# ========== 数据库操作结束 ==========
|
||
|
||
logging.info(f"test account login return {user}")
|
||
# 生成token
|
||
return JSONResponse({
|
||
"token": create_jwt(user["user_id"]),
|
||
"user_info": {
|
||
"phone": phone_number,
|
||
"museums": get_museum_avail(user)
|
||
}
|
||
})
|
||
|
||
def get_museum_avail(user):
|
||
museum_list = get_museums(None, None)
|
||
id_list = [museum['id'] for museum in museum_list]
|
||
return id_list
|
||
|
||
from Crypto.Cipher import AES
|
||
import base64
|
||
import json
|
||
import requests
|
||
|
||
|
||
def decrypt_wechat_phone(
|
||
encrypted_data: str,
|
||
code: str,
|
||
appid: str,
|
||
secret: str
|
||
) -> str:
|
||
"""
|
||
微信手机号解密函数(完整本地处理版)
|
||
|
||
参数:
|
||
encrypted_data: 前端传递的加密数据
|
||
code: 前端通过uni.login获取的临时code
|
||
appid: 小程序appid
|
||
secret: 小程序appsecret
|
||
|
||
返回:
|
||
str: 解密后的手机号
|
||
|
||
异常:
|
||
ValueError: 当任何步骤失败时抛出
|
||
"""
|
||
# 步骤1:获取session_key
|
||
params = {
|
||
"appid": WX_APPID,
|
||
"secret": WX_SECRET,
|
||
"js_code": code,
|
||
"grant_type": "authorization_code"
|
||
}
|
||
|
||
try:
|
||
response = requests.get(WX_API_URL, params=params, timeout=5)
|
||
response.raise_for_status()
|
||
wx_data = response.json()
|
||
|
||
if 'session_key' not in wx_data:
|
||
raise ValueError(f"获取session_key失败: {wx_data.get('errmsg', '未知错误')}")
|
||
|
||
session_key = wx_data['session_key']
|
||
iv = encrypted_data.split('=')[1][:24] # 从加密数据中提取iv(根据实际情况调整)
|
||
except Exception as e:
|
||
raise ValueError(f"微信接口请求失败: {str(e)}")
|
||
|
||
# 步骤2:执行AES解密
|
||
try:
|
||
# Base64解码
|
||
session_key_bin = base64.b64decode(session_key + "=") # 补齐可能缺失的padding
|
||
encrypted_data_bin = base64.b64decode(encrypted_data)
|
||
iv_bin = base64.b64decode(iv)
|
||
|
||
# 创建解密器
|
||
cipher = AES.new(session_key_bin, AES.MODE_CBC, iv_bin)
|
||
|
||
# 执行解密
|
||
decrypted = cipher.decrypt(encrypted_data_bin)
|
||
|
||
# 去除PKCS#7填充
|
||
pad = decrypted[-1]
|
||
decrypted = decrypted[:-pad]
|
||
|
||
# 解析JSON
|
||
result = json.loads(decrypted.decode('utf-8'))
|
||
|
||
if 'purePhoneNumber' not in result:
|
||
raise ValueError("解密数据不包含手机号")
|
||
|
||
return result['purePhoneNumber']
|
||
except Exception as e:
|
||
raise ValueError(f"解密过程失败: {str(e)}")
|
||
|
||
|
||
# 在 login_service.py 中添加以下内容
|
||
|
||
async def optional_current_user(token: str = Depends(oauth2_scheme)):
|
||
"""
|
||
可选用户依赖项(不抛出401错误)
|
||
返回: 用户对象 或 None
|
||
"""
|
||
try:
|
||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
|
||
user_id = payload.get("sub")
|
||
return get_user_by_id(user_id)
|
||
except (JWTError, StopIteration):
|
||
return None
|
||
|
||
@login_router.get("/verify")
|
||
async def verify_token(user: dict = Depends(optional_current_user)):
|
||
"""
|
||
Token 验证接口
|
||
返回格式:
|
||
{
|
||
"valid": bool,
|
||
"user": {
|
||
"user_id": str,
|
||
"phone": str,
|
||
"museums": List[int]
|
||
} | null
|
||
}
|
||
"""
|
||
logging.info(f"verify_token user={user}")
|
||
if user:
|
||
return JSONResponse({
|
||
"valid": True,
|
||
"user": {
|
||
"user_id": user["user_id"],
|
||
"phone": user["phone"],
|
||
#"museums": user["museums"]
|
||
}
|
||
})
|
||
else:
|
||
return JSONResponse({
|
||
"valid": False,
|
||
"user": None,
|
||
"detail": "无效的认证凭据"
|
||
})
|
||
|
||
@login_router.get("/get_museum_list")
|
||
async def get_museum_list(current_user = Depends(optional_current_user)):
|
||
#orders = Order.query.filter_by(user_id=current_user.id).all()
|
||
logging.info(f"get_museum_list user={current_user}")
|
||
museums_paid = get_users_museums_by_user_id(current_user.get('user_id'))
|
||
paid_id_list = [museum['museum_id'] for museum in museums_paid]
|
||
logging.info(f"get_museum_list paid={paid_id_list}")
|
||
museum_list = get_museums(None, None)
|
||
for museum in museum_list:
|
||
if museum.get('id') in paid_id_list:
|
||
museum['paid'] = True
|
||
else:
|
||
museum['paid'] = False
|
||
try:
|
||
return JSONResponse(museum_list)
|
||
except Exception as e:
|
||
raise HTTPException(500, str(e))
|
||
|
||
@login_router.get("/get_museum_id_auth")
|
||
async def get_museum_id_auth(current_user = Depends(optional_current_user)):
|
||
try:
|
||
museum_list = get_museums(None, None)
|
||
id_list = [museum['id'] for museum in museum_list]
|
||
logging.info(f"get_museum_id_auth={id_list}")
|
||
return JSONResponse(id_list)
|
||
except Exception as e:
|
||
raise HTTPException(500, str(e)) |