Files
ragflow_python/asr-monitor-test/app/database.py

1360 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pymysql
from pymysql import Connection
from pymysql.err import OperationalError, InterfaceError
from contextlib import contextmanager
from app.config import DATABASE_CONFIG
from datetime import datetime,timedelta
import logging
from zoneinfo import ZoneInfo # Python 3.9+ 内置
from typing import Union, List, Dict, Optional
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
from dateutil.relativedelta import relativedelta
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DB")
# 重试策略(更保守的配置)
RETRY_CONFIG = {
'stop': stop_after_attempt(3), # 最大尝试2次
'wait': wait_fixed(1), # 固定1秒重试间隔
'retry': retry_if_exception_type((OperationalError, InterfaceError)),
'reraise': True, # 重新抛出原始异常
'before_sleep': lambda retry_state: logger.warning(
f"Retrying ({retry_state.attempt_number}/2) due to: {retry_state.outcome.exception()}"
)
}
@contextmanager
def get_connection(autocommit: bool = False):
"""获取单次数据库连接(带重试机制)"""
@retry(**RETRY_CONFIG)
def connect():
try:
return pymysql.connect(**DATABASE_CONFIG)
except (OperationalError, InterfaceError) as e:
logger.error(f"Connection failed: {str(e)}")
raise
conn = None
try:
conn = connect() # 原有重试逻辑不变
conn.autocommit(autocommit) # 新增关键设置
yield conn
if not autocommit:
conn.commit() # 非自动提交模式下统一提交
except Exception as e:
if conn and not autocommit:
conn.rollback() # 非自动提交模式下回滚
raise
finally:
if conn:
conn.close()
def get_query_type(query: str) -> str:
"""更精准识别查询类型"""
query = query.strip().upper()
if query.startswith(("WITH", "SELECT")):
return "SELECT"
if query.startswith(("INSERT", "REPLACE")):
return "INSERT"
if query.startswith("UPDATE"):
return "UPDATE"
if query.startswith("DELETE"):
return "DELETE"
return "OTHER"
def process_result(cursor, query: str) -> Union[int, list[dict[str, any]]]:
"""统一处理结果"""
try:
# 自动识别查询类型
query_type =get_query_type(query)
if query_type == "SELECT":
return cursor.fetchall()
elif query_type == "INSERT":
return cursor.lastrowid # 返回插入的ID
elif query_type in ("UPDATE", "DELETE"):
return cursor.rowcount # 返回影响行数
else:
return cursor.rowcount # 其他操作返回影响行数
except IndexError:
raise ValueError("Invalid SQL query format")
@retry(**RETRY_CONFIG)
def execute_query(
query: str,
params: tuple | dict = None,
*,
connection: Optional[Connection] = None, # 明确类型 连接参数
autocommit: bool = False
):
"""
安全执行查询(适配低频操作)
:param read_only: 标记是否为只读查询(优化事务)
"""
# 输入安全验证
if not query.strip():
raise ValueError("Empty query")
forbidden_keywords = ['DROP', 'TRUNCATE', 'GRANT']
if any(kw in query.upper() for kw in forbidden_keywords):
raise PermissionError("Dangerous operation detected")
# 连接管理逻辑变更
if connection: # 使用外部连接
cursor = connection.cursor()
try:
cursor.execute(query, params)
# 根据查询类型处理结果
if autocommit and not query.strip().upper().startswith("SELECT"):
connection.commit()
return process_result(cursor)
finally:
cursor.close()
else: # 新建连接
with get_connection(autocommit=autocommit) as conn:
with conn.cursor() as cursor:
# logging.info(f"exec sql {query} {params}")
cursor.execute(query, params)
return process_result(cursor,query)
def create_museum(data: dict):
sql = """
INSERT INTO mesum_overview
(name, brief, chat_id, photo_url, longitude, latitude, category,
create_time, create_date, update_time, update_date, address, free)
VALUES
(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
now = int(datetime.now().timestamp())
params = (
data['name'],
data.get('brief'),
data['chat_id'],
data.get('photo_url'),
data.get('longitude'),
data.get('latitude'),
data.get('category'),
now,
datetime.fromtimestamp(now),
now,
datetime.fromtimestamp(now),
data.get('address'),
data.get('free', 0)
)
return execute_query(sql, params)
def get_museums(search: str = None, free: int = None):
base_sql = "SELECT * FROM mesum_overview WHERE 1=1"
params = []
if search:
base_sql += " AND (name LIKE %s OR brief LIKE %s)"
params.extend([f"%{search}%", f"%{search}%"])
if free is not None:
base_sql += " AND free = %s"
params.append(free)
base_sql += " ORDER BY create_time DESC"
return execute_query(base_sql, tuple(params))
def update_museum(museum_id: int, data: dict):
"""动态更新博物馆信息仅更新data中包含的字段"""
allowed_fields = {
'name': 'name = %s',
'brief': 'brief = %s',
'photo_url': 'photo_url = %s',
'longitude': 'longitude = %s',
'latitude': 'latitude = %s',
'category': 'category = %s',
'address': 'address = %s',
'free': 'free = %s'
}
update_fields = []
params = []
now = int(datetime.now().timestamp())
update_date = datetime.fromtimestamp(now)
# 收集动态字段
for field, sql_part in allowed_fields.items():
if field in data:
update_fields.append(sql_part)
params.append(data[field])
# 必须包含至少一个业务更新字段
if not update_fields:
raise ValueError("未提供有效更新字段")
# 添加自动更新时间
update_fields.extend([
'update_time = %s',
'update_date = %s'
])
params.extend([now, update_date])
# 构建动态SQL
set_clause = ", ".join(update_fields)
sql = f"""
UPDATE mesum_overview
SET {set_clause}
WHERE id = %s
"""
params.append(museum_id)
# 执行更新
rowcount = execute_query(sql, tuple(params))
if rowcount == 0:
raise HTTPException(status_code=404, detail="博物馆不存在")
return get_museum_by_id(museum_id)
def get_museum_by_id(museum_id: int):
sql = "SELECT * FROM mesum_overview WHERE id = %s"
result = execute_query(sql, (museum_id,))
assert isinstance(result, list), "Unexpected return type"
return result[0] if result else None
def delete_museum(museum_id: int):
sql = "DELETE FROM mesum_overview WHERE id = %s"
rowcount = execute_query(sql, (museum_id,))
if rowcount == 0:
raise HTTPException(404, "Delete failed")
return {"message": f"Deleted {rowcount} museums"}
# 创建授权记录
def create_users_museum(data: dict):
sql = """
INSERT INTO rag_flow.users_museum
(user_id, museum_id, create_time, create_date, update_time, update_date)
VALUES
(%s, %s, %s, %s, %s, %s)
"""
now = int(datetime.now().timestamp())
params = (
data['user_id'],
data['museum_id'],
now,
datetime.fromtimestamp(now),
now,
datetime.fromtimestamp(now)
)
return execute_query(sql, params)
# 查询授权记录(多条件)
def get_users_museums(user_id: str = None, museum_id: int = None, get_all: bool = False):
base_sql = "SELECT * FROM rag_flow.users_museum WHERE 1=1"
params = []
if not get_all:
if user_id:
base_sql += " AND user_id = %s"
params.append(user_id)
if museum_id is not None:
base_sql += " AND museum_id = %s"
params.append(museum_id)
base_sql += " ORDER BY create_time DESC"
return execute_query(base_sql, tuple(params))
# 按ID获取单条记录
def get_users_museums_by_user_id(user_id: str):
sql = "SELECT * FROM rag_flow.users_museum WHERE user_id = %s"
result = execute_query(sql, (user_id,))
return result
# 更新授权信息
def update_users_museum(id: int, data: dict):
"""动态更新用户博物馆授权仅更新data中包含的字段"""
allowed_fields = {
'user_id': 'user_id = %s',
'museum_id': 'museum_id = %s'
}
update_fields = []
params = []
now = int(datetime.now().timestamp())
update_date = datetime.fromtimestamp(now)
# 收集动态字段
for field, sql_part in allowed_fields.items():
if field in data:
update_fields.append(sql_part)
params.append(data[field])
# 必须包含至少一个业务字段
if not update_fields:
raise ValueError("未提供有效更新字段")
# 添加自动更新时间
update_fields.extend([
'update_time = %s',
'update_date = %s'
])
params.extend([now, update_date])
# 构建动态SQL
set_clause = ", ".join(update_fields)
sql = f"""
UPDATE rag_flow.users_museum
SET {set_clause}
WHERE id = %s
"""
params.append(id)
# 执行更新
rowcount = execute_query(sql, tuple(params))
if rowcount == 0:
raise HTTPException(status_code=404, detail="授权记录不存在")
return get_users_museum_by_id(id)
# 删除授权记录
def delete_users_museum(id: int):
sql = "DELETE FROM rag_flow.users_museum WHERE id = %s"
execute_query(sql, (id,))
return {"message": "User museum authorization deleted"}
# 批量检查授权状态
def check_auth_batch(user_id: str, museum_ids: list):
if not museum_ids:
return []
placeholders = ','.join(['%s'] * len(museum_ids))
sql = f"""
SELECT museum_id
FROM rag_flow.users_museum
WHERE user_id = %s
AND museum_id IN ({placeholders})
"""
params = [user_id] + museum_ids
result = execute_query(sql, params)
return [item['museum_id'] for item in result]
from datetime import datetime
# 创建用户
def create_user(data: dict):
sql = """
INSERT INTO rag_flow.users_info
(user_id, openid, phone, email, token, balance, status,
last_login_time, create_time, create_date, update_time, update_date)
VALUES
(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) # 12个占位符
"""
now = int(datetime.now().timestamp())
params = (
data['user_id'],
data.get('openid'),
data.get('phone'),
data.get('email'),
data.get('token'),
data.get('balance', 0), # 默认余额0
data.get('status', 1), # 默认状态1正常
data.get('last_login_time'),
now,
datetime.fromtimestamp(now),
now,
datetime.fromtimestamp(now)
)
#logging.info(f"create user {data} {sql} {params}")
return execute_query(sql, params)
# 查询用户(多条件)
def get_users(status: int = None, email: str = None, phone: str = None,openid: str = None):
base_sql = "SELECT * FROM rag_flow.users_info WHERE 1=1"
params = []
if status is not None:
base_sql += " AND status = %s"
params.append(status)
if email:
base_sql += " AND email = %s"
params.append(email)
if phone:
base_sql += " AND phone = %s"
params.append(phone)
if openid:
base_sql += " AND openid = %s"
params.append(openid)
base_sql += " ORDER BY create_time DESC"
return execute_query(base_sql, tuple(params))
# 按用户ID获取用户
def get_user_by_id(user_id: str):
"""
根据用户ID获取用户信息
功能说明:
- 通过用户ID查询用户基本信息
参数说明:
- user_id: 用户ID
返回:
- 用户信息的字典如果不存在则返回None
重要逻辑:
- 直接查询用户表的所有字段
"""
sql = "SELECT * FROM rag_flow.users_info WHERE user_id = %s"
result = execute_query(sql, (user_id,))
return result[0] if result else None
# 更新用户信息
def update_user(user_id: str, data: dict):
"""动态更新用户信息仅更新data中包含的字段"""
# 允许更新的字段白名单
allowed_fields = {
'phone': 'phone = %s',
'email': 'email = %s',
'token': 'token = %s',
'balance': 'balance = %s',
'status': 'status = %s',
'last_login_time': 'last_login_time = %s'
}
# 过滤有效更新字段
update_fields = []
params = []
now = int(datetime.now().timestamp())
update_date = datetime.fromtimestamp(now)
# 收集动态字段
for field, sql_part in allowed_fields.items():
if field in data:
update_fields.append(sql_part)
params.append(data[field])
# 必须包含至少一个更新字段(除自动更新的时间字段)
if not update_fields:
raise ValueError("没有提供有效更新字段")
# 添加自动更新的时间字段
update_fields.extend([
'update_time = %s',
'update_date = %s'
])
params.extend([now, update_date])
# 构建动态SQL
set_clause = ", ".join(update_fields)
sql = f"""
UPDATE rag_flow.users_info
SET {set_clause}
WHERE user_id = %s
"""
# 添加用户ID作为最后参数
params.append(user_id)
# 执行更新
rowcount = execute_query(sql, tuple(params))
if rowcount == 0:
raise HTTPException(status_code=404, detail="用户不存在")
return get_user_by_id(user_id)
# 删除用户
def delete_user(user_id: str):
sql = "DELETE FROM rag_flow.users_info WHERE user_id = %s"
execute_query(sql, (user_id,))
return {"message": "User deleted"}
# 检查openid是否存在
def check_openid_exists(openid: str):
sql = "SELECT 1 FROM rag_flow.users_info WHERE openid = %s"
result = execute_query(sql, (openid,))
return bool(result)
# 更新用户余额
def update_balance(user_id: str, amount: int):
sql = """
UPDATE rag_flow.users_info
SET balance = balance + %s,
update_time = %s,
update_date = %s
WHERE user_id = %s
"""
now = int(datetime.now().timestamp())
params = (
amount,
now,
datetime.fromtimestamp(now),
user_id
)
execute_query(sql, params)
return get_user_by_id(user_id)
# 更新登录信息
def update_login_info(user_id: str, token: str):
sql = """
UPDATE rag_flow.users_info
SET token = %s,
last_login_time = %s,
update_time = %s,
update_date = %s
WHERE user_id = %s
"""
now = int(datetime.now().timestamp())
params = (
token,
now,
now,
datetime.fromtimestamp(now),
user_id
)
execute_query(sql, params)
return get_user_by_id(user_id)
# database.py
#------------------------------------------------------------
def get_museum_subscriptions_by_museum(museum_id: int) -> list:
"""
获取指定博物馆的所有有效订阅套餐
功能说明:
- 查询指定博物馆的所有可用订阅套餐
- 返回结果包含博物馆订阅信息和关联的模板信息
参数说明:
- museum_id: 博物馆ID
返回:
- 包含订阅信息的字典列表
重要逻辑:
- 只返回 is_active=1 的有效订阅
- 通过 JOIN 关联 subscription_templates 表获取模板信息
- 结果按有效期类型排序,便于前端展示
"""
sql = """
SELECT
ms.id,
ms.museum_id,
ms.template_id,
ms.price,
ms.sub_id,
ms.is_active,
ms.created_date,
ms.updated_date,
st.name AS template_name,
st.description AS template_description,
st.validity_type
FROM museum_subscriptions ms
JOIN subscription_templates st ON ms.template_id = st.id
WHERE ms.museum_id = %s AND ms.is_active = 1
ORDER BY st.validity_type
"""
return execute_query(sql, (museum_id,))
def get_museum_subscription_by_id(subscription_id: str) -> dict:
"""
根据ID获取博物馆订阅套餐的详细信息
功能说明:
- 通过订阅ID获取完整的订阅信息
- 包含关联的模板信息
参数说明:
- subscription_id: 博物馆订阅ID
返回:
- 订阅信息的字典如果不存在则返回None
重要逻辑:
- 使用内连接获取模板信息
- 确保返回完整的订阅+模板数据
"""
sql = """
SELECT
ms.*,
st.name AS template_name,
st.description AS template_description,
st.validity_type
FROM museum_subscriptions ms
JOIN subscription_templates st ON ms.template_id = st.id
WHERE ms.sub_id = %s
"""
result = execute_query(sql, (subscription_id,))
return result[0] if result else None
def create_order(order_data: dict) -> int:
"""
创建新的订阅订单
功能说明:
- 在 subscription_orders 表中插入新订单记录
参数说明:
- order_data: 包含订单数据的字典,字段包括:
order_id: 订单号 (必需)
user_id: 用户ID (必需)
museum_subscription_id: 博物馆订阅ID (必需)
amount: 订单金额 (默认0.00)
status: 订单状态 (默认'created')
transaction_id: 支付交易号 (可选)
create_date: 创建时间 (默认当前时间)
pay_time: 支付时间 (可选)
返回:
- 执行结果的行数
重要逻辑:
- 为可选字段提供默认值
- 使用参数化查询防止SQL注入
- 处理所有必需的订单字段
"""
sql = """
INSERT INTO subscription_orders (
order_id,
user_id,
museum_subscription_id,
amount,
status,
transaction_id,
create_date,
pay_time
) VALUES (
%s, %s, %s, %s, %s, %s, %s, %s
)
"""
# 设置默认值
params = (
order_data.get("order_id"),
order_data.get("user_id"),
order_data.get("museum_subscription_id"),
order_data.get("amount", 0.00),
order_data.get("status", "created"),
order_data.get("transaction_id"),
order_data.get("create_date", datetime.now()),
order_data.get("pay_time")
)
return execute_query(sql, params, autocommit=True)
def update_order(order_id: str, update_data: dict) -> int:
"""
更新订单信息
功能说明:
- 动态更新订单的字段
- 只更新提供的字段
参数说明:
- order_id: 要更新的订单ID
- update_data: 包含更新字段的字典,可选字段包括:
status: 订单状态
transaction_id: 支付交易号
amount: 订单金额
pay_time: 支付时间
返回:
- 受影响的行数
重要逻辑:
- 只允许更新预定义的字段
- 防止更新不允许的字段
- 使用参数化查询确保安全
"""
# 定义允许更新的字段及其SQL部分
allowed_fields = {
'status': 'status = %s',
'transaction_id': 'transaction_id = %s',
'amount': 'amount = %s',
'pay_time': 'pay_time = %s'
}
update_fields = []
params = []
# 收集要更新的字段
for field, sql_part in allowed_fields.items():
if field in update_data:
update_fields.append(sql_part)
params.append(update_data[field])
# 如果没有提供有效更新字段,直接返回
if not update_fields:
return 0
# 构建动态SQL
set_clause = ", ".join(update_fields)
sql = f"UPDATE subscription_orders SET {set_clause} WHERE order_id = %s"
params.append(order_id)
return execute_query(sql, tuple(params), autocommit=True)
from typing import Union, List, Dict, Optional
def get_order_by_id(order_id: str = None, user_id: str = None,combined:bool = False,museum_id:int = None) -> Union[Dict, List[Dict], None]:
"""
根据订单ID或用户ID查询订单信息
功能说明:
- 支持通过 order_id 或 user_id 查询订单信息
- 当传入 order_id 时,返回单个订单(字典)
- 当传入 user_id 时,返回该用户的所有订单(列表)
参数说明:
- order_id: 订单号(字符串)
- user_id: 用户ID字符串
返回:
- 如果传入 order_id: 返回单个订单的字典(如果存在)
- 如果传入 user_id: 返回该用户的所有订单列表(可能为空)
- 如果两个参数都未传入,返回 None
重要逻辑:
- 使用参数化查询防止 SQL 注入
- 当同时传入 order_id 和 user_id 时,优先使用 order_id
"""
# 无任何参数时返回 None
if not any([order_id, user_id, museum_id]):
return None
# ========== 简单查询模式 ==========
if not combined:
# 优先使用 order_id 查询
if order_id:
sql = "SELECT * FROM subscription_orders WHERE order_id = %s"
result = execute_query(sql, (order_id,))
return result[0] if result and len(result) > 0 else None
# 使用 user_id 查询
if user_id:
sql = "SELECT * FROM subscription_orders WHERE user_id = %s"
result = execute_query(sql, (user_id,))
return result if result else []
# ========== 复杂查询模式 ==========
base_sql= """
SELECT
o.order_id,
o.user_id,
u.phone,
u.openid,
o.museum_subscription_id AS subscription_id,
ms.museum_id,
ms.template_id,
t.name AS template_name,
t.description AS template_desc,
t.validity_type,
ms.price AS subscription_price,
o.amount AS order_amount,
o.status AS order_status,
o.transaction_id,
o.create_date AS order_create_time,
o.pay_time,
us.start_date,
us.end_date,
us.is_active AS subscription_active,
mo.name AS museum_name
FROM
rag_flow.subscription_orders o
LEFT JOIN rag_flow.users_info u ON o.user_id = u.user_id
LEFT JOIN rag_flow.museum_subscriptions ms ON o.museum_subscription_id = ms.sub_id
LEFT JOIN rag_flow.subscription_templates t ON ms.template_id = t.id
LEFT JOIN rag_flow.user_subscriptions us ON o.order_id = us.order_id
LEFT JOIN rag_flow.mesum_overview mo ON ms.museum_id = mo.id
"""
# 构建查询条件和参数
conditions = []
params = []
# 添加条件(按优先级)
if order_id:
conditions.append("o.order_id = %s")
params.append(order_id)
elif user_id:
conditions.append("o.user_id = %s")
params.append(user_id)
# 新增博物馆ID条件可与其他条件组合
if museum_id:
conditions.append("ms.museum_id = %s")
params.append(museum_id)
# 构建完整SQL
if conditions:
where_clause = " WHERE " + " AND ".join(conditions)
sql = base_sql + where_clause
else:
sql = base_sql
# 执行查询
result = execute_query(sql, tuple(params))
# 处理返回结果
if not result:
return [] if user_id or museum_id else None
# 当有order_id时返回单个对象否则返回列表
return result[0] if order_id and not museum_id else result
def create_user_subscription(data: dict) -> int:
"""
创建用户订阅记录
功能说明:
- 在 user_subscriptions 表中插入新记录
- 表示用户购买并激活了一个订阅
参数说明:
- data: 包含订阅数据的字典,字段包括:
user_id: 用户ID (必需)
museum_subscription_id: 博物馆订阅ID (必需)
order_id: 关联的订单ID (必需)
start_date: 开始时间 (默认当前时间)
end_date: 结束时间 (必需)
is_active: 是否激活 (默认1)
create_date: 创建时间 (默认当前时间)
返回:
- 执行结果的行数
重要逻辑:
- 为可选字段提供默认值
- 确保所有必需字段都有值
- 处理时间字段的默认值
"""
sql = """
INSERT INTO user_subscriptions (
user_id,
museum_subscription_id,
order_id,
start_date,
end_date,
is_active,
create_date
) VALUES (
%s, %s, %s, %s, %s, %s, %s
)
"""
# 设置默认值
params = (
data.get("user_id"),
data.get("museum_subscription_id"),
data.get("order_id"),
data.get("start_date", datetime.now()),
data.get("end_date"),
data.get("is_active", 1),
data.get("create_date", datetime.now())
)
return execute_query(sql, params, autocommit=True)
def deactivate_previous_subscriptions(user_id: str, museum_subscription_id: str) -> int:
"""
禁用用户在同一博物馆的旧订阅
功能说明:
- 将用户在同一博物馆的所有激活订阅设为非激活状态
- 确保同一博物馆只有一个激活订阅
参数说明:
- user_id: 用户ID
- museum_id: 博物馆ID
返回:
- 受影响的行数
重要逻辑:
- 通过JOIN关联博物馆订阅表
- 只更新同一博物馆的订阅
- 保持历史订阅记录,只修改激活状态
"""
sql = """
UPDATE user_subscriptions us
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.sub_id
SET us.is_active = 0
WHERE us.user_id = %s AND us.museum_subscription_id = %s AND us.is_active = 1
"""
return execute_query(sql, (user_id, museum_subscription_id), autocommit=True)
def get_subscription_template_by_id(template_id: int) -> dict:
"""
根据模板ID获取订阅模板信息
功能说明:
- 通过模板ID查询订阅模板详情
参数说明:
- template_id: 模板ID
返回:
- 模板信息的字典如果不存在则返回None
"""
sql = "SELECT * FROM subscription_templates WHERE id = %s"
result = execute_query(sql, (template_id,))
return result[0] if result else None
def get_active_user_subscription(user_id: str, museum_id: int) -> dict:
"""
获取用户在指定博物馆的有效订阅
功能说明:
- 查询用户在特定博物馆的当前有效订阅
- 有效订阅定义为: 已激活且未过期
参数说明:
- user_id: 用户ID
- museum_id: 博物馆ID
返回:
- 订阅信息的字典如果不存在则返回None
重要逻辑:
- 通过JOIN关联博物馆订阅表
- 检查 is_active=1 和 end_date > NOW()
- 返回最新到期的订阅
"""
sql = """
SELECT us.*
FROM user_subscriptions us
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.id
WHERE us.user_id = %s
AND ms.museum_id = %s
AND us.is_active = 1
AND us.end_date > NOW()
ORDER BY us.end_date DESC
LIMIT 1
"""
result = execute_query(sql, (user_id, museum_id))
return result[0] if result else None
def get_user_subscription_history(user_id: str) -> list:
"""
获取用户的订阅历史记录
功能说明:
- 查询用户的所有订阅记录(包括历史和当前)
- 返回完整的订阅详情,包含博物馆和模板信息
参数说明:
- user_id: 用户ID
返回:
- 包含订阅历史记录的字典列表
重要逻辑:
- 通过多层JOIN关联所有相关表
- 包含博物馆名称、模板信息等
- 按开始时间倒序排列,最新订阅在前
"""
sql = """
SELECT
us.*,
ms.price,
ms.sub_id,
st.name AS template_name,
st.description AS template_description,
st.validity_type,
mo.name AS museum_name
FROM user_subscriptions us
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.id
JOIN subscription_templates st ON ms.template_id = st.id
JOIN mesum_overview mo ON ms.museum_id = mo.id
WHERE us.user_id = %s
ORDER BY us.start_date DESC
"""
return execute_query(sql, (user_id,))
def get_user_subscription_by_order(order_id: str) -> dict:
"""
根据订单ID获取用户订阅信息
功能说明:
- 通过订单ID查询关联的用户订阅
参数说明:
- order_id: 订单ID
返回:
- 订阅信息的字典如果不存在则返回None
重要逻辑:
- 用于支付回调后验证订阅是否已创建
"""
sql = "SELECT * FROM user_subscriptions WHERE order_id = %s"
result = execute_query(sql, (order_id,))
return result[0] if result else None
def activate_free_subscription(user_id: str, museum_id: int) -> str:
"""
激活免费订阅
功能说明:
- 为用户在指定博物馆激活免费订阅
- 创建订单记录和订阅记录
参数说明:
- user_id: 用户ID
- museum_id: 博物馆ID
返回:
- 创建的订单ID
重要逻辑:
- 查找博物馆的免费订阅
- 创建订单记录 (状态为activated)
- 创建订阅记录 (有效期为7天)
"""
# 1. 获取博物馆的免费订阅
sql = """
SELECT ms.id
FROM museum_subscriptions ms
JOIN subscription_templates st ON ms.template_id = st.id
WHERE ms.museum_id = %s
AND st.validity_type = 'free'
AND ms.is_active = 1
LIMIT 1
"""
result = execute_query(sql, (museum_id,))
if not result:
raise ValueError("该博物馆没有可用的免费订阅")
subscription_id = result[0]["id"]
# 2. 创建免费订单
order_id = f"FREE_{int(time.time())}"
create_order({
"order_id": order_id,
"user_id": user_id,
"museum_subscription_id": subscription_id,
"amount": 0,
"status": "activated",
"create_date": datetime.now()
})
# 3. 创建用户订阅记录 (免费订阅有效期为7天)
start_date = datetime.now()
end_date = start_date + timedelta(days=7)
create_user_subscription({
"user_id": user_id,
"museum_subscription_id": subscription_id,
"order_id": order_id,
"start_date": start_date,
"end_date": end_date,
"is_active": True
})
# 4. 禁用同一博物馆的旧订阅
deactivate_previous_subscriptions(user_id, subscription_id)
return order_id
def activate_user_subscription(
user_id: str,
museum_subscription_id: str,
order_id: str
) -> bool:
"""
激活用户订阅服务
参数:
- user_id: 用户ID
- museum_subscription_id: 博物馆订阅套餐ID
- order_id: 关联的订单ID
返回:
- 激活成功返回True失败返回False
主要逻辑:
1. 获取博物馆订阅信息
2. 禁用同一博物馆的旧订阅
3. 计算订阅有效期
4. 创建用户订阅记录
5. 处理重复激活和并发请求
"""
try:
# 1. 获取博物馆订阅信息
museum_sub = get_museum_subscription_by_id(museum_subscription_id)
if not museum_sub:
logger.error(f"博物馆订阅不存在: {museum_subscription_id}")
return False
# 2. 获取关联的订阅模板
template = get_subscription_template_by_id(museum_sub["template_id"])
if not template:
logger.error(f"订阅模板不存在: {museum_sub['template_id']}")
return False
# 3. 禁用同一博物馆的旧订阅
deactivated_count = deactivate_previous_subscriptions(
user_id=user_id,
museum_subscription_id=museum_subscription_id
)
logger.info(f"已禁用{deactivated_count}个同一博物馆的旧订阅")
# 4. 计算订阅有效期
start_date = datetime.now()
end_date = calculate_subscription_expiry(start_date,template["validity_type"])
# 5. 检查是否已激活过(防止重复激活)
existing_sub = get_user_subscription_by_order(order_id)
if existing_sub:
logger.warning(f"订阅已激活过,跳过重复激活. Order: {order_id}")
return True
# 6. 创建用户订阅记录
subscription_data = {
"user_id": user_id,
"museum_subscription_id": museum_subscription_id,
"order_id": order_id,
"start_date": start_date,
"end_date": end_date,
"is_active": 1
}
create_user_subscription(subscription_data)
logger.info(f"订阅激活成功. 用户: {user_id}, 套餐: {museum_subscription_id}, "
f"有效期: {start_date}{end_date}")
return True
except Exception as e:
logger.exception(f"激活订阅失败: {str(e)}")
return False
def check_user_subscription(user_id: str, museum_id: int) -> dict:
"""
检查用户是否拥有指定博物馆的有效订阅
功能说明:
- 检查用户是否有指定博物馆的未过期激活订阅
参数说明:
- user_id: 用户ID
- museum_id: 博物馆ID
返回:
- 订阅信息字典如果没有则返回None
重要逻辑:
- 优先检查当前有效的订阅
- 如果没有,检查免费订阅是否可用
"""
# 1. 检查当前有效订阅
active_sub = get_active_user_subscription(user_id, museum_id)
if active_sub:
return active_sub
# 2. 检查是否有免费订阅可用
# (这里可以扩展更多逻辑,如试用期检查等)
return None
def is_museum_free_period(museum_id: int) -> bool:
"""
检查博物馆当前是否处于免费时段
参数:
- museum_id: 博物馆ID
返回:
- True: 当前是免费时段
- False: 当前不是免费时段
"""
# 查询博物馆的免费时段配置
sql = """
SELECT t.validity_type, t.valid_time_range, t.valid_week_days
FROM museum_subscriptions ms
JOIN subscription_templates t ON ms.template_id = t.id
WHERE ms.museum_id = %s
AND t.validity_type = 'free_interval'
AND t.is_active = 1
AND ms.is_active = 1
LIMIT 1
"""
result = execute_query(sql, (museum_id,))
if not result:
return False
subscription = result[0]
return is_subscription_valid(subscription)
def get_user_valid_subscription(user_id: str, museum_id: int) -> bool:
"""
检查用户是否有有效的博物馆订阅
参数:
- user_id: 用户ID
- museum_id: 博物馆ID
返回:
- True: 用户有有效订阅
- False: 用户无有效订阅
"""
# 查询用户的有效订阅
sql = """
SELECT
t.validity_type,
t.valid_time_range,
t.valid_week_days,
us.start_date,
us.end_date
FROM user_subscriptions us
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.sub_id
JOIN subscription_templates t ON ms.template_id = t.id
WHERE us.user_id = %s
AND ms.museum_id = %s
AND us.is_active = 1
AND ms.is_active = 1
AND t.is_active = 1
AND us.start_date <= NOW()
AND us.end_date >= NOW()
"""
subscriptions = execute_query(sql, (user_id, museum_id))
# 检查每个订阅是否在当前时间有效
for sub in subscriptions:
if is_subscription_valid(sub):
return True
return False
def calculate_subscription_expiry(start_date: datetime, validity_type: str) -> datetime:
"""
根据有效期类型计算到期日期
参数:
- start_date: 订阅开始日期
- validity_type: 有效期类型 (free, 1month, 1year, permanent)
返回:
- 到期日期
"""
if validity_type == "free":
# 免费套餐通常有较短有效期例如7天
return start_date + timedelta(days=7)
elif validity_type == "1month":
# 下个月的同一天(自动处理月末情况)
return start_date + relativedelta(months=1)
elif validity_type == "1year":
# 下一年的同一天
return start_date + relativedelta(years=1)
elif validity_type == "permanent":
# 永久有效设置为100年后
return start_date + relativedelta(years=100)
else:
# 未知类型默认30天
logger.warning(f"未知有效期类型: {validity_type}, 使用默认30天")
return start_date + timedelta(days=30)
def is_subscription_valid(subscription: dict) -> bool:
"""
检查订阅在当前时间是否有效
参数:
subscription: 包含订阅信息的字典,包含以下字段:
- validity_type: 有效期类型
- valid_time_range: 有效时间段 (格式: "08:00-20:00")
- valid_week_days: 有效星期 (格式: "1,3,5")
- start_date: 订阅开始日期 (datetime 对象)
- end_date: 订阅结束日期 (datetime 对象)
"""
# 设置时区(根据服务器实际时区调整)
tz = ZoneInfo('Asia/Shanghai')
now = datetime.now(tz)
# 1. 检查永久免费订阅
if subscription['validity_type'] == 'free':
return True
# 2. 检查时间间隔类型订阅
if subscription['validity_type'] == 'free_interval':
# 时间间隔类型不需要检查有效期范围
pass
else:
# 3. 检查有效期是否在范围内
if subscription['validity_type'] in ['1month', '1year', 'permanent']:
start_date = subscription['start_date'].astimezone(tz)
end_date = subscription['end_date'].astimezone(tz)
if not (start_date <= now <= end_date):
return False
# 4. 检查星期限制
if subscription.get('valid_week_days'):
week_day = now.isoweekday() # 1=周一, 7=周日
valid_days = [int(d) for d in str(subscription['valid_week_days']).split(',')]
if week_day not in valid_days:
return False
# 5. 检查时间范围限制
if subscription.get('valid_time_range'):
try:
start_str, end_str = subscription['valid_time_range'].split('-')
start_time = datetime.strptime(start_str, '%H:%M').time()
end_time = datetime.strptime(end_str, '%H:%M').time()
current_time = now.time()
# 处理跨夜时段
if end_time < start_time:
if not (current_time >= start_time or current_time <= end_time):
return False
else:
if not (start_time <= current_time <= end_time):
return False
except (ValueError, AttributeError):
# 时间格式无效,跳过时间检查
pass
return True