2025-05-26 21:38:46 +08:00
|
|
|
|
import pymysql
|
|
|
|
|
|
from pymysql import Connection
|
|
|
|
|
|
from pymysql.err import OperationalError, InterfaceError
|
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
from config import DATABASE_CONFIG
|
|
|
|
|
|
from datetime import datetime,timedelta
|
|
|
|
|
|
import logging
|
2025-07-10 22:04:44 +08:00
|
|
|
|
from typing import Union, List, Dict, Optional
|
2025-05-26 21:38:46 +08:00
|
|
|
|
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
|
2025-07-10 22:04:44 +08:00
|
|
|
|
from dateutil.relativedelta import relativedelta
|
2025-05-26 21:38:46 +08:00
|
|
|
|
|
|
|
|
|
|
# 配置日志
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
logger = logging.getLogger("DB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 重试策略(更保守的配置)
|
|
|
|
|
|
RETRY_CONFIG = {
|
|
|
|
|
|
'stop': stop_after_attempt(3), # 最大尝试2次
|
|
|
|
|
|
'wait': wait_fixed(1), # 固定1秒重试间隔
|
|
|
|
|
|
'retry': retry_if_exception_type((OperationalError, InterfaceError)),
|
|
|
|
|
|
'reraise': True, # 重新抛出原始异常
|
|
|
|
|
|
'before_sleep': lambda retry_state: logger.warning(
|
|
|
|
|
|
f"Retrying ({retry_state.attempt_number}/2) due to: {retry_state.outcome.exception()}"
|
|
|
|
|
|
)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
|
def get_connection(autocommit: bool = False):
|
|
|
|
|
|
"""获取单次数据库连接(带重试机制)"""
|
|
|
|
|
|
|
|
|
|
|
|
@retry(**RETRY_CONFIG)
|
|
|
|
|
|
def connect():
|
|
|
|
|
|
try:
|
|
|
|
|
|
return pymysql.connect(**DATABASE_CONFIG)
|
|
|
|
|
|
except (OperationalError, InterfaceError) as e:
|
|
|
|
|
|
logger.error(f"Connection failed: {str(e)}")
|
|
|
|
|
|
raise
|
|
|
|
|
|
conn = None
|
|
|
|
|
|
try:
|
|
|
|
|
|
conn = connect() # 原有重试逻辑不变
|
|
|
|
|
|
conn.autocommit(autocommit) # 新增关键设置
|
|
|
|
|
|
yield conn
|
|
|
|
|
|
if not autocommit:
|
|
|
|
|
|
conn.commit() # 非自动提交模式下统一提交
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
if conn and not autocommit:
|
|
|
|
|
|
conn.rollback() # 非自动提交模式下回滚
|
|
|
|
|
|
raise
|
|
|
|
|
|
finally:
|
|
|
|
|
|
if conn:
|
|
|
|
|
|
conn.close()
|
|
|
|
|
|
def get_query_type(query: str) -> str:
|
|
|
|
|
|
"""更精准识别查询类型"""
|
|
|
|
|
|
query = query.strip().upper()
|
|
|
|
|
|
if query.startswith(("WITH", "SELECT")):
|
|
|
|
|
|
return "SELECT"
|
|
|
|
|
|
if query.startswith(("INSERT", "REPLACE")):
|
|
|
|
|
|
return "INSERT"
|
|
|
|
|
|
if query.startswith("UPDATE"):
|
|
|
|
|
|
return "UPDATE"
|
|
|
|
|
|
if query.startswith("DELETE"):
|
|
|
|
|
|
return "DELETE"
|
|
|
|
|
|
return "OTHER"
|
|
|
|
|
|
|
|
|
|
|
|
def process_result(cursor, query: str) -> Union[int, list[dict[str, any]]]:
|
|
|
|
|
|
"""统一处理结果"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 自动识别查询类型
|
|
|
|
|
|
query_type =get_query_type(query)
|
|
|
|
|
|
if query_type == "SELECT":
|
|
|
|
|
|
return cursor.fetchall()
|
|
|
|
|
|
elif query_type == "INSERT":
|
|
|
|
|
|
return cursor.lastrowid # 返回插入的ID
|
|
|
|
|
|
elif query_type in ("UPDATE", "DELETE"):
|
|
|
|
|
|
return cursor.rowcount # 返回影响行数
|
|
|
|
|
|
else:
|
|
|
|
|
|
return cursor.rowcount # 其他操作返回影响行数
|
|
|
|
|
|
except IndexError:
|
|
|
|
|
|
raise ValueError("Invalid SQL query format")
|
|
|
|
|
|
|
|
|
|
|
|
@retry(**RETRY_CONFIG)
|
|
|
|
|
|
def execute_query(
|
|
|
|
|
|
query: str,
|
|
|
|
|
|
params: tuple | dict = None,
|
|
|
|
|
|
*,
|
|
|
|
|
|
connection: Optional[Connection] = None, # 明确类型 连接参数
|
|
|
|
|
|
autocommit: bool = False
|
|
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
安全执行查询(适配低频操作)
|
|
|
|
|
|
:param read_only: 标记是否为只读查询(优化事务)
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 输入安全验证
|
|
|
|
|
|
if not query.strip():
|
|
|
|
|
|
raise ValueError("Empty query")
|
|
|
|
|
|
|
|
|
|
|
|
forbidden_keywords = ['DROP', 'TRUNCATE', 'GRANT']
|
|
|
|
|
|
if any(kw in query.upper() for kw in forbidden_keywords):
|
|
|
|
|
|
raise PermissionError("Dangerous operation detected")
|
|
|
|
|
|
|
|
|
|
|
|
# 连接管理逻辑变更
|
|
|
|
|
|
if connection: # 使用外部连接
|
|
|
|
|
|
cursor = connection.cursor()
|
|
|
|
|
|
try:
|
|
|
|
|
|
cursor.execute(query, params)
|
|
|
|
|
|
# 根据查询类型处理结果
|
|
|
|
|
|
if autocommit and not query.strip().upper().startswith("SELECT"):
|
|
|
|
|
|
connection.commit()
|
|
|
|
|
|
return process_result(cursor)
|
|
|
|
|
|
finally:
|
|
|
|
|
|
cursor.close()
|
|
|
|
|
|
else: # 新建连接
|
|
|
|
|
|
with get_connection(autocommit=autocommit) as conn:
|
|
|
|
|
|
with conn.cursor() as cursor:
|
|
|
|
|
|
# logging.info(f"exec sql {query} {params}")
|
|
|
|
|
|
cursor.execute(query, params)
|
|
|
|
|
|
return process_result(cursor,query)
|
|
|
|
|
|
|
|
|
|
|
|
def create_museum(data: dict):
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
INSERT INTO mesum_overview
|
|
|
|
|
|
(name, brief, chat_id, photo_url, longitude, latitude, category,
|
|
|
|
|
|
create_time, create_date, update_time, update_date, address, free)
|
|
|
|
|
|
VALUES
|
|
|
|
|
|
(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
params = (
|
|
|
|
|
|
data['name'],
|
|
|
|
|
|
data.get('brief'),
|
|
|
|
|
|
data['chat_id'],
|
|
|
|
|
|
data.get('photo_url'),
|
|
|
|
|
|
data.get('longitude'),
|
|
|
|
|
|
data.get('latitude'),
|
|
|
|
|
|
data.get('category'),
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now),
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now),
|
|
|
|
|
|
data.get('address'),
|
|
|
|
|
|
data.get('free', 0)
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return execute_query(sql, params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_museums(search: str = None, free: int = None):
|
|
|
|
|
|
base_sql = "SELECT * FROM mesum_overview WHERE 1=1"
|
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
|
|
if search:
|
|
|
|
|
|
base_sql += " AND (name LIKE %s OR brief LIKE %s)"
|
|
|
|
|
|
params.extend([f"%{search}%", f"%{search}%"])
|
|
|
|
|
|
|
|
|
|
|
|
if free is not None:
|
|
|
|
|
|
base_sql += " AND free = %s"
|
|
|
|
|
|
params.append(free)
|
|
|
|
|
|
|
|
|
|
|
|
base_sql += " ORDER BY create_time DESC"
|
|
|
|
|
|
|
|
|
|
|
|
return execute_query(base_sql, tuple(params))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_museum(museum_id: int, data: dict):
|
|
|
|
|
|
"""动态更新博物馆信息(仅更新data中包含的字段)"""
|
|
|
|
|
|
allowed_fields = {
|
|
|
|
|
|
'name': 'name = %s',
|
|
|
|
|
|
'brief': 'brief = %s',
|
|
|
|
|
|
'photo_url': 'photo_url = %s',
|
|
|
|
|
|
'longitude': 'longitude = %s',
|
|
|
|
|
|
'latitude': 'latitude = %s',
|
|
|
|
|
|
'category': 'category = %s',
|
|
|
|
|
|
'address': 'address = %s',
|
|
|
|
|
|
'free': 'free = %s'
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
update_fields = []
|
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
update_date = datetime.fromtimestamp(now)
|
|
|
|
|
|
|
|
|
|
|
|
# 收集动态字段
|
|
|
|
|
|
for field, sql_part in allowed_fields.items():
|
|
|
|
|
|
if field in data:
|
|
|
|
|
|
update_fields.append(sql_part)
|
|
|
|
|
|
params.append(data[field])
|
|
|
|
|
|
|
|
|
|
|
|
# 必须包含至少一个业务更新字段
|
|
|
|
|
|
if not update_fields:
|
|
|
|
|
|
raise ValueError("未提供有效更新字段")
|
|
|
|
|
|
|
|
|
|
|
|
# 添加自动更新时间
|
|
|
|
|
|
update_fields.extend([
|
|
|
|
|
|
'update_time = %s',
|
|
|
|
|
|
'update_date = %s'
|
|
|
|
|
|
])
|
|
|
|
|
|
params.extend([now, update_date])
|
|
|
|
|
|
|
|
|
|
|
|
# 构建动态SQL
|
|
|
|
|
|
set_clause = ", ".join(update_fields)
|
|
|
|
|
|
sql = f"""
|
|
|
|
|
|
UPDATE mesum_overview
|
|
|
|
|
|
SET {set_clause}
|
|
|
|
|
|
WHERE id = %s
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
params.append(museum_id)
|
|
|
|
|
|
|
|
|
|
|
|
# 执行更新
|
|
|
|
|
|
rowcount = execute_query(sql, tuple(params))
|
|
|
|
|
|
|
|
|
|
|
|
if rowcount == 0:
|
|
|
|
|
|
raise HTTPException(status_code=404, detail="博物馆不存在")
|
|
|
|
|
|
|
|
|
|
|
|
return get_museum_by_id(museum_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_museum_by_id(museum_id: int):
|
|
|
|
|
|
sql = "SELECT * FROM mesum_overview WHERE id = %s"
|
|
|
|
|
|
result = execute_query(sql, (museum_id,))
|
|
|
|
|
|
assert isinstance(result, list), "Unexpected return type"
|
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def delete_museum(museum_id: int):
|
|
|
|
|
|
sql = "DELETE FROM mesum_overview WHERE id = %s"
|
|
|
|
|
|
rowcount = execute_query(sql, (museum_id,))
|
|
|
|
|
|
if rowcount == 0:
|
|
|
|
|
|
raise HTTPException(404, "Delete failed")
|
|
|
|
|
|
return {"message": f"Deleted {rowcount} museums"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 创建授权记录
|
|
|
|
|
|
def create_users_museum(data: dict):
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
INSERT INTO rag_flow.users_museum
|
|
|
|
|
|
(user_id, museum_id, create_time, create_date, update_time, update_date)
|
|
|
|
|
|
VALUES
|
|
|
|
|
|
(%s, %s, %s, %s, %s, %s)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
params = (
|
|
|
|
|
|
data['user_id'],
|
|
|
|
|
|
data['museum_id'],
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now),
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now)
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return execute_query(sql, params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 查询授权记录(多条件)
|
|
|
|
|
|
def get_users_museums(user_id: str = None, museum_id: int = None, get_all: bool = False):
|
|
|
|
|
|
base_sql = "SELECT * FROM rag_flow.users_museum WHERE 1=1"
|
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
|
|
if not get_all:
|
|
|
|
|
|
if user_id:
|
|
|
|
|
|
base_sql += " AND user_id = %s"
|
|
|
|
|
|
params.append(user_id)
|
|
|
|
|
|
if museum_id is not None:
|
|
|
|
|
|
base_sql += " AND museum_id = %s"
|
|
|
|
|
|
params.append(museum_id)
|
|
|
|
|
|
|
|
|
|
|
|
base_sql += " ORDER BY create_time DESC"
|
|
|
|
|
|
|
|
|
|
|
|
return execute_query(base_sql, tuple(params))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 按ID获取单条记录
|
|
|
|
|
|
def get_users_museums_by_user_id(user_id: str):
|
|
|
|
|
|
sql = "SELECT * FROM rag_flow.users_museum WHERE user_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (user_id,))
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新授权信息
|
|
|
|
|
|
def update_users_museum(id: int, data: dict):
|
|
|
|
|
|
"""动态更新用户博物馆授权(仅更新data中包含的字段)"""
|
|
|
|
|
|
allowed_fields = {
|
|
|
|
|
|
'user_id': 'user_id = %s',
|
|
|
|
|
|
'museum_id': 'museum_id = %s'
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
update_fields = []
|
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
update_date = datetime.fromtimestamp(now)
|
|
|
|
|
|
|
|
|
|
|
|
# 收集动态字段
|
|
|
|
|
|
for field, sql_part in allowed_fields.items():
|
|
|
|
|
|
if field in data:
|
|
|
|
|
|
update_fields.append(sql_part)
|
|
|
|
|
|
params.append(data[field])
|
|
|
|
|
|
|
|
|
|
|
|
# 必须包含至少一个业务字段
|
|
|
|
|
|
if not update_fields:
|
|
|
|
|
|
raise ValueError("未提供有效更新字段")
|
|
|
|
|
|
|
|
|
|
|
|
# 添加自动更新时间
|
|
|
|
|
|
update_fields.extend([
|
|
|
|
|
|
'update_time = %s',
|
|
|
|
|
|
'update_date = %s'
|
|
|
|
|
|
])
|
|
|
|
|
|
params.extend([now, update_date])
|
|
|
|
|
|
|
|
|
|
|
|
# 构建动态SQL
|
|
|
|
|
|
set_clause = ", ".join(update_fields)
|
|
|
|
|
|
sql = f"""
|
|
|
|
|
|
UPDATE rag_flow.users_museum
|
|
|
|
|
|
SET {set_clause}
|
|
|
|
|
|
WHERE id = %s
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
params.append(id)
|
|
|
|
|
|
|
|
|
|
|
|
# 执行更新
|
|
|
|
|
|
rowcount = execute_query(sql, tuple(params))
|
|
|
|
|
|
|
|
|
|
|
|
if rowcount == 0:
|
|
|
|
|
|
raise HTTPException(status_code=404, detail="授权记录不存在")
|
|
|
|
|
|
|
|
|
|
|
|
return get_users_museum_by_id(id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 删除授权记录
|
|
|
|
|
|
def delete_users_museum(id: int):
|
|
|
|
|
|
sql = "DELETE FROM rag_flow.users_museum WHERE id = %s"
|
|
|
|
|
|
execute_query(sql, (id,))
|
|
|
|
|
|
return {"message": "User museum authorization deleted"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 批量检查授权状态
|
|
|
|
|
|
def check_auth_batch(user_id: str, museum_ids: list):
|
|
|
|
|
|
if not museum_ids:
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
placeholders = ','.join(['%s'] * len(museum_ids))
|
|
|
|
|
|
sql = f"""
|
|
|
|
|
|
SELECT museum_id
|
|
|
|
|
|
FROM rag_flow.users_museum
|
|
|
|
|
|
WHERE user_id = %s
|
|
|
|
|
|
AND museum_id IN ({placeholders})
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
params = [user_id] + museum_ids
|
|
|
|
|
|
result = execute_query(sql, params)
|
|
|
|
|
|
return [item['museum_id'] for item in result]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 创建用户
|
|
|
|
|
|
def create_user(data: dict):
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
INSERT INTO rag_flow.users_info
|
|
|
|
|
|
(user_id, openid, phone, email, token, balance, status,
|
|
|
|
|
|
last_login_time, create_time, create_date, update_time, update_date)
|
|
|
|
|
|
VALUES
|
|
|
|
|
|
(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) # 12个占位符
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
params = (
|
|
|
|
|
|
data['user_id'],
|
|
|
|
|
|
data.get('openid'),
|
|
|
|
|
|
data.get('phone'),
|
|
|
|
|
|
data.get('email'),
|
|
|
|
|
|
data.get('token'),
|
|
|
|
|
|
data.get('balance', 0), # 默认余额0
|
|
|
|
|
|
data.get('status', 1), # 默认状态1正常
|
|
|
|
|
|
data.get('last_login_time'),
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now),
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now)
|
|
|
|
|
|
)
|
|
|
|
|
|
#logging.info(f"create user {data} {sql} {params}")
|
|
|
|
|
|
return execute_query(sql, params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 查询用户(多条件)
|
2025-07-10 22:04:44 +08:00
|
|
|
|
def get_users(status: int = None, email: str = None, phone: str = None,openid: str = None):
|
2025-05-26 21:38:46 +08:00
|
|
|
|
base_sql = "SELECT * FROM rag_flow.users_info WHERE 1=1"
|
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
|
|
if status is not None:
|
|
|
|
|
|
base_sql += " AND status = %s"
|
|
|
|
|
|
params.append(status)
|
|
|
|
|
|
|
|
|
|
|
|
if email:
|
|
|
|
|
|
base_sql += " AND email = %s"
|
|
|
|
|
|
params.append(email)
|
|
|
|
|
|
|
|
|
|
|
|
if phone:
|
|
|
|
|
|
base_sql += " AND phone = %s"
|
|
|
|
|
|
params.append(phone)
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
if openid:
|
|
|
|
|
|
base_sql += " AND openid = %s"
|
|
|
|
|
|
params.append(openid)
|
2025-05-26 21:38:46 +08:00
|
|
|
|
base_sql += " ORDER BY create_time DESC"
|
|
|
|
|
|
|
|
|
|
|
|
return execute_query(base_sql, tuple(params))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 按用户ID获取用户
|
|
|
|
|
|
def get_user_by_id(user_id: str):
|
|
|
|
|
|
sql = "SELECT * FROM rag_flow.users_info WHERE user_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (user_id,))
|
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新用户信息
|
|
|
|
|
|
def update_user(user_id: str, data: dict):
|
|
|
|
|
|
"""动态更新用户信息(仅更新data中包含的字段)"""
|
|
|
|
|
|
# 允许更新的字段白名单
|
|
|
|
|
|
allowed_fields = {
|
|
|
|
|
|
'phone': 'phone = %s',
|
|
|
|
|
|
'email': 'email = %s',
|
|
|
|
|
|
'token': 'token = %s',
|
|
|
|
|
|
'balance': 'balance = %s',
|
|
|
|
|
|
'status': 'status = %s',
|
|
|
|
|
|
'last_login_time': 'last_login_time = %s'
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 过滤有效更新字段
|
|
|
|
|
|
update_fields = []
|
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
update_date = datetime.fromtimestamp(now)
|
|
|
|
|
|
|
|
|
|
|
|
# 收集动态字段
|
|
|
|
|
|
for field, sql_part in allowed_fields.items():
|
|
|
|
|
|
if field in data:
|
|
|
|
|
|
update_fields.append(sql_part)
|
|
|
|
|
|
params.append(data[field])
|
|
|
|
|
|
|
|
|
|
|
|
# 必须包含至少一个更新字段(除自动更新的时间字段)
|
|
|
|
|
|
if not update_fields:
|
|
|
|
|
|
raise ValueError("没有提供有效更新字段")
|
|
|
|
|
|
|
|
|
|
|
|
# 添加自动更新的时间字段
|
|
|
|
|
|
update_fields.extend([
|
|
|
|
|
|
'update_time = %s',
|
|
|
|
|
|
'update_date = %s'
|
|
|
|
|
|
])
|
|
|
|
|
|
params.extend([now, update_date])
|
|
|
|
|
|
|
|
|
|
|
|
# 构建动态SQL
|
|
|
|
|
|
set_clause = ", ".join(update_fields)
|
|
|
|
|
|
sql = f"""
|
|
|
|
|
|
UPDATE rag_flow.users_info
|
|
|
|
|
|
SET {set_clause}
|
|
|
|
|
|
WHERE user_id = %s
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
# 添加用户ID作为最后参数
|
|
|
|
|
|
params.append(user_id)
|
|
|
|
|
|
|
|
|
|
|
|
# 执行更新
|
|
|
|
|
|
rowcount = execute_query(sql, tuple(params))
|
|
|
|
|
|
|
|
|
|
|
|
if rowcount == 0:
|
|
|
|
|
|
raise HTTPException(status_code=404, detail="用户不存在")
|
|
|
|
|
|
|
|
|
|
|
|
return get_user_by_id(user_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 删除用户
|
|
|
|
|
|
def delete_user(user_id: str):
|
|
|
|
|
|
sql = "DELETE FROM rag_flow.users_info WHERE user_id = %s"
|
|
|
|
|
|
execute_query(sql, (user_id,))
|
|
|
|
|
|
return {"message": "User deleted"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查openid是否存在
|
|
|
|
|
|
def check_openid_exists(openid: str):
|
|
|
|
|
|
sql = "SELECT 1 FROM rag_flow.users_info WHERE openid = %s"
|
|
|
|
|
|
result = execute_query(sql, (openid,))
|
|
|
|
|
|
return bool(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新用户余额
|
|
|
|
|
|
def update_balance(user_id: str, amount: int):
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
UPDATE rag_flow.users_info
|
|
|
|
|
|
SET balance = balance + %s,
|
|
|
|
|
|
update_time = %s,
|
|
|
|
|
|
update_date = %s
|
|
|
|
|
|
WHERE user_id = %s
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
params = (
|
|
|
|
|
|
amount,
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now),
|
|
|
|
|
|
user_id
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
execute_query(sql, params)
|
|
|
|
|
|
return get_user_by_id(user_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新登录信息
|
|
|
|
|
|
def update_login_info(user_id: str, token: str):
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
UPDATE rag_flow.users_info
|
|
|
|
|
|
SET token = %s,
|
|
|
|
|
|
last_login_time = %s,
|
|
|
|
|
|
update_time = %s,
|
|
|
|
|
|
update_date = %s
|
|
|
|
|
|
WHERE user_id = %s
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
now = int(datetime.now().timestamp())
|
|
|
|
|
|
params = (
|
|
|
|
|
|
token,
|
|
|
|
|
|
now,
|
|
|
|
|
|
now,
|
|
|
|
|
|
datetime.fromtimestamp(now),
|
|
|
|
|
|
user_id
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
execute_query(sql, params)
|
2025-07-10 22:04:44 +08:00
|
|
|
|
return get_user_by_id(user_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# database.py
|
|
|
|
|
|
#------------------------------------------------------------
|
|
|
|
|
|
def get_museum_subscriptions_by_museum(museum_id: int) -> list:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取指定博物馆的所有有效订阅套餐
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 查询指定博物馆的所有可用订阅套餐
|
|
|
|
|
|
- 返回结果包含博物馆订阅信息和关联的模板信息
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- museum_id: 博物馆ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 包含订阅信息的字典列表
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 只返回 is_active=1 的有效订阅
|
|
|
|
|
|
- 通过 JOIN 关联 subscription_templates 表获取模板信息
|
|
|
|
|
|
- 结果按有效期类型排序,便于前端展示
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
ms.id,
|
|
|
|
|
|
ms.museum_id,
|
|
|
|
|
|
ms.template_id,
|
|
|
|
|
|
ms.price,
|
|
|
|
|
|
ms.sub_id,
|
|
|
|
|
|
ms.is_active,
|
|
|
|
|
|
ms.created_date,
|
|
|
|
|
|
ms.updated_date,
|
|
|
|
|
|
st.name AS template_name,
|
|
|
|
|
|
st.description AS template_description,
|
|
|
|
|
|
st.validity_type
|
|
|
|
|
|
FROM museum_subscriptions ms
|
|
|
|
|
|
JOIN subscription_templates st ON ms.template_id = st.id
|
|
|
|
|
|
WHERE ms.museum_id = %s AND ms.is_active = 1
|
|
|
|
|
|
ORDER BY st.validity_type
|
|
|
|
|
|
"""
|
|
|
|
|
|
return execute_query(sql, (museum_id,))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_museum_subscription_by_id(subscription_id: str) -> dict:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据ID获取博物馆订阅套餐的详细信息
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 通过订阅ID获取完整的订阅信息
|
|
|
|
|
|
- 包含关联的模板信息
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- subscription_id: 博物馆订阅ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 订阅信息的字典,如果不存在则返回None
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 使用内连接获取模板信息
|
|
|
|
|
|
- 确保返回完整的订阅+模板数据
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
ms.*,
|
|
|
|
|
|
st.name AS template_name,
|
|
|
|
|
|
st.description AS template_description,
|
|
|
|
|
|
st.validity_type
|
|
|
|
|
|
FROM museum_subscriptions ms
|
|
|
|
|
|
JOIN subscription_templates st ON ms.template_id = st.id
|
|
|
|
|
|
WHERE ms.sub_id = %s
|
|
|
|
|
|
"""
|
|
|
|
|
|
result = execute_query(sql, (subscription_id,))
|
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_order(order_data: dict) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
创建新的订阅订单
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 在 subscription_orders 表中插入新订单记录
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- order_data: 包含订单数据的字典,字段包括:
|
|
|
|
|
|
order_id: 订单号 (必需)
|
|
|
|
|
|
user_id: 用户ID (必需)
|
|
|
|
|
|
museum_subscription_id: 博物馆订阅ID (必需)
|
|
|
|
|
|
amount: 订单金额 (默认0.00)
|
|
|
|
|
|
status: 订单状态 (默认'created')
|
|
|
|
|
|
transaction_id: 支付交易号 (可选)
|
|
|
|
|
|
create_date: 创建时间 (默认当前时间)
|
|
|
|
|
|
pay_time: 支付时间 (可选)
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 执行结果的行数
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 为可选字段提供默认值
|
|
|
|
|
|
- 使用参数化查询防止SQL注入
|
|
|
|
|
|
- 处理所有必需的订单字段
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
INSERT INTO subscription_orders (
|
|
|
|
|
|
order_id,
|
|
|
|
|
|
user_id,
|
|
|
|
|
|
museum_subscription_id,
|
|
|
|
|
|
amount,
|
|
|
|
|
|
status,
|
|
|
|
|
|
transaction_id,
|
|
|
|
|
|
create_date,
|
|
|
|
|
|
pay_time
|
|
|
|
|
|
) VALUES (
|
|
|
|
|
|
%s, %s, %s, %s, %s, %s, %s, %s
|
|
|
|
|
|
)
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 设置默认值
|
|
|
|
|
|
params = (
|
|
|
|
|
|
order_data.get("order_id"),
|
|
|
|
|
|
order_data.get("user_id"),
|
|
|
|
|
|
order_data.get("museum_subscription_id"),
|
|
|
|
|
|
order_data.get("amount", 0.00),
|
|
|
|
|
|
order_data.get("status", "created"),
|
|
|
|
|
|
order_data.get("transaction_id"),
|
|
|
|
|
|
order_data.get("create_date", datetime.now()),
|
|
|
|
|
|
order_data.get("pay_time")
|
|
|
|
|
|
)
|
|
|
|
|
|
return execute_query(sql, params, autocommit=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_order(order_id: str, update_data: dict) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
更新订单信息
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 动态更新订单的字段
|
|
|
|
|
|
- 只更新提供的字段
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- order_id: 要更新的订单ID
|
|
|
|
|
|
- update_data: 包含更新字段的字典,可选字段包括:
|
|
|
|
|
|
status: 订单状态
|
|
|
|
|
|
transaction_id: 支付交易号
|
|
|
|
|
|
amount: 订单金额
|
|
|
|
|
|
pay_time: 支付时间
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 受影响的行数
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 只允许更新预定义的字段
|
|
|
|
|
|
- 防止更新不允许的字段
|
|
|
|
|
|
- 使用参数化查询确保安全
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 定义允许更新的字段及其SQL部分
|
|
|
|
|
|
allowed_fields = {
|
|
|
|
|
|
'status': 'status = %s',
|
|
|
|
|
|
'transaction_id': 'transaction_id = %s',
|
|
|
|
|
|
'amount': 'amount = %s',
|
|
|
|
|
|
'pay_time': 'pay_time = %s'
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
update_fields = []
|
|
|
|
|
|
params = []
|
|
|
|
|
|
|
|
|
|
|
|
# 收集要更新的字段
|
|
|
|
|
|
for field, sql_part in allowed_fields.items():
|
|
|
|
|
|
if field in update_data:
|
|
|
|
|
|
update_fields.append(sql_part)
|
|
|
|
|
|
params.append(update_data[field])
|
|
|
|
|
|
|
|
|
|
|
|
# 如果没有提供有效更新字段,直接返回
|
|
|
|
|
|
if not update_fields:
|
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
|
|
# 构建动态SQL
|
|
|
|
|
|
set_clause = ", ".join(update_fields)
|
|
|
|
|
|
sql = f"UPDATE subscription_orders SET {set_clause} WHERE order_id = %s"
|
|
|
|
|
|
params.append(order_id)
|
|
|
|
|
|
|
|
|
|
|
|
return execute_query(sql, tuple(params), autocommit=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Union, List, Dict, Optional
|
|
|
|
|
|
|
|
|
|
|
|
def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -> Union[Dict, List[Dict], None]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据订单ID或用户ID查询订单信息
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 支持通过 order_id 或 user_id 查询订单信息
|
|
|
|
|
|
- 当传入 order_id 时,返回单个订单(字典)
|
|
|
|
|
|
- 当传入 user_id 时,返回该用户的所有订单(列表)
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- order_id: 订单号(字符串)
|
|
|
|
|
|
- user_id: 用户ID(字符串)
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 如果传入 order_id: 返回单个订单的字典(如果存在)
|
|
|
|
|
|
- 如果传入 user_id: 返回该用户的所有订单列表(可能为空)
|
|
|
|
|
|
- 如果两个参数都未传入,返回 None
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 使用参数化查询防止 SQL 注入
|
|
|
|
|
|
- 当同时传入 order_id 和 user_id 时,优先使用 order_id
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not order_id and not user_id:
|
|
|
|
|
|
return None # 两个参数都未传入,直接返回 None
|
|
|
|
|
|
sql_wo_condition= """
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
o.order_id,
|
|
|
|
|
|
o.user_id,
|
|
|
|
|
|
u.phone,
|
|
|
|
|
|
u.openid,
|
|
|
|
|
|
o.museum_subscription_id AS subscription_id,
|
|
|
|
|
|
ms.museum_id,
|
|
|
|
|
|
ms.template_id,
|
|
|
|
|
|
t.name AS template_name,
|
|
|
|
|
|
t.description AS template_desc,
|
|
|
|
|
|
t.validity_type,
|
|
|
|
|
|
ms.price AS subscription_price,
|
|
|
|
|
|
o.amount AS order_amount,
|
|
|
|
|
|
o.status AS order_status,
|
|
|
|
|
|
o.transaction_id,
|
|
|
|
|
|
o.create_date AS order_create_time,
|
|
|
|
|
|
o.pay_time,
|
|
|
|
|
|
us.start_date,
|
|
|
|
|
|
us.end_date,
|
|
|
|
|
|
us.is_active AS subscription_active,
|
|
|
|
|
|
mo.name AS museum_name
|
|
|
|
|
|
FROM
|
|
|
|
|
|
rag_flow.subscription_orders o
|
|
|
|
|
|
LEFT JOIN rag_flow.users_info u ON o.user_id = u.user_id
|
|
|
|
|
|
LEFT JOIN rag_flow.museum_subscriptions ms ON o.museum_subscription_id = ms.sub_id
|
|
|
|
|
|
LEFT JOIN rag_flow.subscription_templates t ON ms.template_id = t.id
|
|
|
|
|
|
LEFT JOIN rag_flow.user_subscriptions us ON o.order_id = us.order_id
|
|
|
|
|
|
LEFT JOIN rag_flow.mesum_overview mo ON ms.museum_id = mo.id
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 优先使用 order_id 查询
|
|
|
|
|
|
if order_id and not combined:
|
|
|
|
|
|
sql = "SELECT * FROM subscription_orders WHERE order_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (order_id,))
|
|
|
|
|
|
return result[0] if result and len(result) > 0 else None # 返回单个订单
|
|
|
|
|
|
|
|
|
|
|
|
# 如果 order_id 不存在,使用 user_id 查询
|
|
|
|
|
|
if user_id and not combined:
|
|
|
|
|
|
sql = "SELECT * FROM subscription_orders WHERE user_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (user_id,))
|
|
|
|
|
|
return result if result else [] # 返回所有订单(列表)
|
|
|
|
|
|
if user_id and combined:
|
|
|
|
|
|
sql = sql_wo_condition + f"\n WHERE o.user_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (user_id,))
|
|
|
|
|
|
return result if result else [] # 返回所有订单(列表)
|
|
|
|
|
|
if order_id and combined:
|
|
|
|
|
|
sql = sql_wo_condition + f"\n WHERE o.order_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (order_id,))
|
|
|
|
|
|
return result[0] if result and len(result) > 0 else None # 返回单个订单
|
|
|
|
|
|
|
|
|
|
|
|
def create_user_subscription(data: dict) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
创建用户订阅记录
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 在 user_subscriptions 表中插入新记录
|
|
|
|
|
|
- 表示用户购买并激活了一个订阅
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- data: 包含订阅数据的字典,字段包括:
|
|
|
|
|
|
user_id: 用户ID (必需)
|
|
|
|
|
|
museum_subscription_id: 博物馆订阅ID (必需)
|
|
|
|
|
|
order_id: 关联的订单ID (必需)
|
|
|
|
|
|
start_date: 开始时间 (默认当前时间)
|
|
|
|
|
|
end_date: 结束时间 (必需)
|
|
|
|
|
|
is_active: 是否激活 (默认1)
|
|
|
|
|
|
create_date: 创建时间 (默认当前时间)
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 执行结果的行数
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 为可选字段提供默认值
|
|
|
|
|
|
- 确保所有必需字段都有值
|
|
|
|
|
|
- 处理时间字段的默认值
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
INSERT INTO user_subscriptions (
|
|
|
|
|
|
user_id,
|
|
|
|
|
|
museum_subscription_id,
|
|
|
|
|
|
order_id,
|
|
|
|
|
|
start_date,
|
|
|
|
|
|
end_date,
|
|
|
|
|
|
is_active,
|
|
|
|
|
|
create_date
|
|
|
|
|
|
) VALUES (
|
|
|
|
|
|
%s, %s, %s, %s, %s, %s, %s
|
|
|
|
|
|
)
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 设置默认值
|
|
|
|
|
|
params = (
|
|
|
|
|
|
data.get("user_id"),
|
|
|
|
|
|
data.get("museum_subscription_id"),
|
|
|
|
|
|
data.get("order_id"),
|
|
|
|
|
|
data.get("start_date", datetime.now()),
|
|
|
|
|
|
data.get("end_date"),
|
|
|
|
|
|
data.get("is_active", 1),
|
|
|
|
|
|
data.get("create_date", datetime.now())
|
|
|
|
|
|
)
|
|
|
|
|
|
return execute_query(sql, params, autocommit=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def deactivate_previous_subscriptions(user_id: str, museum_subscription_id: str) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
禁用用户在同一博物馆的旧订阅
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 将用户在同一博物馆的所有激活订阅设为非激活状态
|
|
|
|
|
|
- 确保同一博物馆只有一个激活订阅
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- user_id: 用户ID
|
|
|
|
|
|
- museum_id: 博物馆ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 受影响的行数
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 通过JOIN关联博物馆订阅表
|
|
|
|
|
|
- 只更新同一博物馆的订阅
|
|
|
|
|
|
- 保持历史订阅记录,只修改激活状态
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
UPDATE user_subscriptions us
|
|
|
|
|
|
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.sub_id
|
|
|
|
|
|
SET us.is_active = 0
|
|
|
|
|
|
WHERE us.user_id = %s AND us.museum_subscription_id = %s AND us.is_active = 1
|
|
|
|
|
|
"""
|
|
|
|
|
|
return execute_query(sql, (user_id, museum_subscription_id), autocommit=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_user_by_id(user_id: str) -> dict:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据用户ID获取用户信息
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 通过用户ID查询用户基本信息
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- user_id: 用户ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 用户信息的字典,如果不存在则返回None
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 直接查询用户表的所有字段
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = "SELECT * FROM users_info WHERE user_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (user_id,))
|
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_subscription_template_by_id(template_id: int) -> dict:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据模板ID获取订阅模板信息
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 通过模板ID查询订阅模板详情
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- template_id: 模板ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 模板信息的字典,如果不存在则返回None
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = "SELECT * FROM subscription_templates WHERE id = %s"
|
|
|
|
|
|
result = execute_query(sql, (template_id,))
|
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_active_user_subscription(user_id: str, museum_id: int) -> dict:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取用户在指定博物馆的有效订阅
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 查询用户在特定博物馆的当前有效订阅
|
|
|
|
|
|
- 有效订阅定义为: 已激活且未过期
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- user_id: 用户ID
|
|
|
|
|
|
- museum_id: 博物馆ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 订阅信息的字典,如果不存在则返回None
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 通过JOIN关联博物馆订阅表
|
|
|
|
|
|
- 检查 is_active=1 和 end_date > NOW()
|
|
|
|
|
|
- 返回最新到期的订阅
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
SELECT us.*
|
|
|
|
|
|
FROM user_subscriptions us
|
|
|
|
|
|
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.id
|
|
|
|
|
|
WHERE us.user_id = %s
|
|
|
|
|
|
AND ms.museum_id = %s
|
|
|
|
|
|
AND us.is_active = 1
|
|
|
|
|
|
AND us.end_date > NOW()
|
|
|
|
|
|
ORDER BY us.end_date DESC
|
|
|
|
|
|
LIMIT 1
|
|
|
|
|
|
"""
|
|
|
|
|
|
result = execute_query(sql, (user_id, museum_id))
|
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_user_subscription_history(user_id: str) -> list:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取用户的订阅历史记录
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 查询用户的所有订阅记录(包括历史和当前)
|
|
|
|
|
|
- 返回完整的订阅详情,包含博物馆和模板信息
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- user_id: 用户ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 包含订阅历史记录的字典列表
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 通过多层JOIN关联所有相关表
|
|
|
|
|
|
- 包含博物馆名称、模板信息等
|
|
|
|
|
|
- 按开始时间倒序排列,最新订阅在前
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
us.*,
|
|
|
|
|
|
ms.price,
|
|
|
|
|
|
ms.sub_id,
|
|
|
|
|
|
st.name AS template_name,
|
|
|
|
|
|
st.description AS template_description,
|
|
|
|
|
|
st.validity_type,
|
|
|
|
|
|
mo.name AS museum_name
|
|
|
|
|
|
FROM user_subscriptions us
|
|
|
|
|
|
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.id
|
|
|
|
|
|
JOIN subscription_templates st ON ms.template_id = st.id
|
|
|
|
|
|
JOIN mesum_overview mo ON ms.museum_id = mo.id
|
|
|
|
|
|
WHERE us.user_id = %s
|
|
|
|
|
|
ORDER BY us.start_date DESC
|
|
|
|
|
|
"""
|
|
|
|
|
|
return execute_query(sql, (user_id,))
|
|
|
|
|
|
|
|
|
|
|
|
def get_user_subscription_by_order(order_id: str) -> dict:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据订单ID获取用户订阅信息
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 通过订单ID查询关联的用户订阅
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- order_id: 订单ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 订阅信息的字典,如果不存在则返回None
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 用于支付回调后验证订阅是否已创建
|
|
|
|
|
|
"""
|
|
|
|
|
|
sql = "SELECT * FROM user_subscriptions WHERE order_id = %s"
|
|
|
|
|
|
result = execute_query(sql, (order_id,))
|
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def activate_free_subscription(user_id: str, museum_id: int) -> str:
|
|
|
|
|
|
"""
|
|
|
|
|
|
激活免费订阅
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 为用户在指定博物馆激活免费订阅
|
|
|
|
|
|
- 创建订单记录和订阅记录
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- user_id: 用户ID
|
|
|
|
|
|
- museum_id: 博物馆ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 创建的订单ID
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 查找博物馆的免费订阅
|
|
|
|
|
|
- 创建订单记录 (状态为activated)
|
|
|
|
|
|
- 创建订阅记录 (有效期为7天)
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 1. 获取博物馆的免费订阅
|
|
|
|
|
|
sql = """
|
|
|
|
|
|
SELECT ms.id
|
|
|
|
|
|
FROM museum_subscriptions ms
|
|
|
|
|
|
JOIN subscription_templates st ON ms.template_id = st.id
|
|
|
|
|
|
WHERE ms.museum_id = %s
|
|
|
|
|
|
AND st.validity_type = 'free'
|
|
|
|
|
|
AND ms.is_active = 1
|
|
|
|
|
|
LIMIT 1
|
|
|
|
|
|
"""
|
|
|
|
|
|
result = execute_query(sql, (museum_id,))
|
|
|
|
|
|
if not result:
|
|
|
|
|
|
raise ValueError("该博物馆没有可用的免费订阅")
|
|
|
|
|
|
|
|
|
|
|
|
subscription_id = result[0]["id"]
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 创建免费订单
|
|
|
|
|
|
order_id = f"FREE_{int(time.time())}"
|
|
|
|
|
|
create_order({
|
|
|
|
|
|
"order_id": order_id,
|
|
|
|
|
|
"user_id": user_id,
|
|
|
|
|
|
"museum_subscription_id": subscription_id,
|
|
|
|
|
|
"amount": 0,
|
|
|
|
|
|
"status": "activated",
|
|
|
|
|
|
"create_date": datetime.now()
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 创建用户订阅记录 (免费订阅有效期为7天)
|
|
|
|
|
|
start_date = datetime.now()
|
|
|
|
|
|
end_date = start_date + timedelta(days=7)
|
|
|
|
|
|
|
|
|
|
|
|
create_user_subscription({
|
|
|
|
|
|
"user_id": user_id,
|
|
|
|
|
|
"museum_subscription_id": subscription_id,
|
|
|
|
|
|
"order_id": order_id,
|
|
|
|
|
|
"start_date": start_date,
|
|
|
|
|
|
"end_date": end_date,
|
|
|
|
|
|
"is_active": True
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
# 4. 禁用同一博物馆的旧订阅
|
|
|
|
|
|
deactivate_previous_subscriptions(user_id, subscription_id)
|
|
|
|
|
|
|
|
|
|
|
|
return order_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def activate_user_subscription(
|
|
|
|
|
|
user_id: str,
|
|
|
|
|
|
museum_subscription_id: str,
|
|
|
|
|
|
order_id: str
|
|
|
|
|
|
) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
激活用户订阅服务
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
- user_id: 用户ID
|
|
|
|
|
|
- museum_subscription_id: 博物馆订阅套餐ID
|
|
|
|
|
|
- order_id: 关联的订单ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 激活成功返回True,失败返回False
|
|
|
|
|
|
|
|
|
|
|
|
主要逻辑:
|
|
|
|
|
|
1. 获取博物馆订阅信息
|
|
|
|
|
|
2. 禁用同一博物馆的旧订阅
|
|
|
|
|
|
3. 计算订阅有效期
|
|
|
|
|
|
4. 创建用户订阅记录
|
|
|
|
|
|
5. 处理重复激活和并发请求
|
|
|
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 1. 获取博物馆订阅信息
|
|
|
|
|
|
museum_sub = get_museum_subscription_by_id(museum_subscription_id)
|
|
|
|
|
|
if not museum_sub:
|
|
|
|
|
|
logger.error(f"博物馆订阅不存在: {museum_subscription_id}")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 获取关联的订阅模板
|
|
|
|
|
|
template = get_subscription_template_by_id(museum_sub["template_id"])
|
|
|
|
|
|
if not template:
|
|
|
|
|
|
logger.error(f"订阅模板不存在: {museum_sub['template_id']}")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 禁用同一博物馆的旧订阅
|
|
|
|
|
|
deactivated_count = deactivate_previous_subscriptions(
|
|
|
|
|
|
user_id=user_id,
|
|
|
|
|
|
museum_subscription_id=museum_subscription_id
|
|
|
|
|
|
)
|
|
|
|
|
|
logger.info(f"已禁用{deactivated_count}个同一博物馆的旧订阅")
|
|
|
|
|
|
|
|
|
|
|
|
# 4. 计算订阅有效期
|
|
|
|
|
|
start_date = datetime.now()
|
|
|
|
|
|
end_date = calculate_subscription_expiry(start_date,template["validity_type"])
|
|
|
|
|
|
|
|
|
|
|
|
# 5. 检查是否已激活过(防止重复激活)
|
|
|
|
|
|
existing_sub = get_user_subscription_by_order(order_id)
|
|
|
|
|
|
if existing_sub:
|
|
|
|
|
|
logger.warning(f"订阅已激活过,跳过重复激活. Order: {order_id}")
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
# 6. 创建用户订阅记录
|
|
|
|
|
|
subscription_data = {
|
|
|
|
|
|
"user_id": user_id,
|
|
|
|
|
|
"museum_subscription_id": museum_subscription_id,
|
|
|
|
|
|
"order_id": order_id,
|
|
|
|
|
|
"start_date": start_date,
|
|
|
|
|
|
"end_date": end_date,
|
|
|
|
|
|
"is_active": 1
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
create_user_subscription(subscription_data)
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"订阅激活成功. 用户: {user_id}, 套餐: {museum_subscription_id}, "
|
|
|
|
|
|
f"有效期: {start_date} 至 {end_date}")
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.exception(f"激活订阅失败: {str(e)}")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
def check_user_subscription(user_id: str, museum_id: int) -> dict:
|
|
|
|
|
|
"""
|
|
|
|
|
|
检查用户是否拥有指定博物馆的有效订阅
|
|
|
|
|
|
|
|
|
|
|
|
功能说明:
|
|
|
|
|
|
- 检查用户是否有指定博物馆的未过期激活订阅
|
|
|
|
|
|
|
|
|
|
|
|
参数说明:
|
|
|
|
|
|
- user_id: 用户ID
|
|
|
|
|
|
- museum_id: 博物馆ID
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 订阅信息字典,如果没有则返回None
|
|
|
|
|
|
|
|
|
|
|
|
重要逻辑:
|
|
|
|
|
|
- 优先检查当前有效的订阅
|
|
|
|
|
|
- 如果没有,检查免费订阅是否可用
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 1. 检查当前有效订阅
|
|
|
|
|
|
active_sub = get_active_user_subscription(user_id, museum_id)
|
|
|
|
|
|
if active_sub:
|
|
|
|
|
|
return active_sub
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 检查是否有免费订阅可用
|
|
|
|
|
|
# (这里可以扩展更多逻辑,如试用期检查等)
|
|
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_subscription_expiry(start_date: datetime, validity_type: str) -> datetime:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据有效期类型计算到期日期
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
- start_date: 订阅开始日期
|
|
|
|
|
|
- validity_type: 有效期类型 (free, 1month, 1year, permanent)
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
- 到期日期
|
|
|
|
|
|
"""
|
|
|
|
|
|
if validity_type == "free":
|
|
|
|
|
|
# 免费套餐通常有较短有效期(例如7天)
|
|
|
|
|
|
return start_date + timedelta(days=7)
|
|
|
|
|
|
elif validity_type == "1month":
|
|
|
|
|
|
# 下个月的同一天(自动处理月末情况)
|
|
|
|
|
|
return start_date + relativedelta(months=1)
|
|
|
|
|
|
elif validity_type == "1year":
|
|
|
|
|
|
# 下一年的同一天
|
|
|
|
|
|
return start_date + relativedelta(years=1)
|
|
|
|
|
|
elif validity_type == "permanent":
|
|
|
|
|
|
# 永久有效设置为100年后
|
|
|
|
|
|
return start_date + relativedelta(years=100)
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 未知类型默认30天
|
|
|
|
|
|
logger.warning(f"未知有效期类型: {validity_type}, 使用默认30天")
|
|
|
|
|
|
return start_date + timedelta(days=30)
|