import pymysql from pymysql import Connection from pymysql.err import OperationalError, InterfaceError from contextlib import contextmanager from config import DATABASE_CONFIG from datetime import datetime,timedelta import logging from typing import Union, List, Dict, Optional from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type from dateutil.relativedelta import relativedelta # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger("DB") # 重试策略(更保守的配置) RETRY_CONFIG = { 'stop': stop_after_attempt(3), # 最大尝试2次 'wait': wait_fixed(1), # 固定1秒重试间隔 'retry': retry_if_exception_type((OperationalError, InterfaceError)), 'reraise': True, # 重新抛出原始异常 'before_sleep': lambda retry_state: logger.warning( f"Retrying ({retry_state.attempt_number}/2) due to: {retry_state.outcome.exception()}" ) } @contextmanager def get_connection(autocommit: bool = False): """获取单次数据库连接(带重试机制)""" @retry(**RETRY_CONFIG) def connect(): try: return pymysql.connect(**DATABASE_CONFIG) except (OperationalError, InterfaceError) as e: logger.error(f"Connection failed: {str(e)}") raise conn = None try: conn = connect() # 原有重试逻辑不变 conn.autocommit(autocommit) # 新增关键设置 yield conn if not autocommit: conn.commit() # 非自动提交模式下统一提交 except Exception as e: if conn and not autocommit: conn.rollback() # 非自动提交模式下回滚 raise finally: if conn: conn.close() def get_query_type(query: str) -> str: """更精准识别查询类型""" query = query.strip().upper() if query.startswith(("WITH", "SELECT")): return "SELECT" if query.startswith(("INSERT", "REPLACE")): return "INSERT" if query.startswith("UPDATE"): return "UPDATE" if query.startswith("DELETE"): return "DELETE" return "OTHER" def process_result(cursor, query: str) -> Union[int, list[dict[str, any]]]: """统一处理结果""" try: # 自动识别查询类型 query_type =get_query_type(query) if query_type == "SELECT": return cursor.fetchall() elif query_type == "INSERT": return cursor.lastrowid # 返回插入的ID elif query_type in ("UPDATE", "DELETE"): return cursor.rowcount # 返回影响行数 else: return cursor.rowcount # 其他操作返回影响行数 except IndexError: raise ValueError("Invalid SQL query format") @retry(**RETRY_CONFIG) def execute_query( query: str, params: tuple | dict = None, *, connection: Optional[Connection] = None, # 明确类型 连接参数 autocommit: bool = False ): """ 安全执行查询(适配低频操作) :param read_only: 标记是否为只读查询(优化事务) """ # 输入安全验证 if not query.strip(): raise ValueError("Empty query") forbidden_keywords = ['DROP', 'TRUNCATE', 'GRANT'] if any(kw in query.upper() for kw in forbidden_keywords): raise PermissionError("Dangerous operation detected") # 连接管理逻辑变更 if connection: # 使用外部连接 cursor = connection.cursor() try: cursor.execute(query, params) # 根据查询类型处理结果 if autocommit and not query.strip().upper().startswith("SELECT"): connection.commit() return process_result(cursor) finally: cursor.close() else: # 新建连接 with get_connection(autocommit=autocommit) as conn: with conn.cursor() as cursor: # logging.info(f"exec sql {query} {params}") cursor.execute(query, params) return process_result(cursor,query) def create_museum(data: dict): sql = """ INSERT INTO mesum_overview (name, brief, chat_id, photo_url, longitude, latitude, category, create_time, create_date, update_time, update_date, address, free) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """ now = int(datetime.now().timestamp()) params = ( data['name'], data.get('brief'), data['chat_id'], data.get('photo_url'), data.get('longitude'), data.get('latitude'), data.get('category'), now, datetime.fromtimestamp(now), now, datetime.fromtimestamp(now), data.get('address'), data.get('free', 0) ) return execute_query(sql, params) def get_museums(search: str = None, free: int = None): base_sql = "SELECT * FROM mesum_overview WHERE 1=1" params = [] if search: base_sql += " AND (name LIKE %s OR brief LIKE %s)" params.extend([f"%{search}%", f"%{search}%"]) if free is not None: base_sql += " AND free = %s" params.append(free) base_sql += " ORDER BY create_time DESC" return execute_query(base_sql, tuple(params)) def update_museum(museum_id: int, data: dict): """动态更新博物馆信息(仅更新data中包含的字段)""" allowed_fields = { 'name': 'name = %s', 'brief': 'brief = %s', 'photo_url': 'photo_url = %s', 'longitude': 'longitude = %s', 'latitude': 'latitude = %s', 'category': 'category = %s', 'address': 'address = %s', 'free': 'free = %s' } update_fields = [] params = [] now = int(datetime.now().timestamp()) update_date = datetime.fromtimestamp(now) # 收集动态字段 for field, sql_part in allowed_fields.items(): if field in data: update_fields.append(sql_part) params.append(data[field]) # 必须包含至少一个业务更新字段 if not update_fields: raise ValueError("未提供有效更新字段") # 添加自动更新时间 update_fields.extend([ 'update_time = %s', 'update_date = %s' ]) params.extend([now, update_date]) # 构建动态SQL set_clause = ", ".join(update_fields) sql = f""" UPDATE mesum_overview SET {set_clause} WHERE id = %s """ params.append(museum_id) # 执行更新 rowcount = execute_query(sql, tuple(params)) if rowcount == 0: raise HTTPException(status_code=404, detail="博物馆不存在") return get_museum_by_id(museum_id) def get_museum_by_id(museum_id: int): sql = "SELECT * FROM mesum_overview WHERE id = %s" result = execute_query(sql, (museum_id,)) assert isinstance(result, list), "Unexpected return type" return result[0] if result else None def delete_museum(museum_id: int): sql = "DELETE FROM mesum_overview WHERE id = %s" rowcount = execute_query(sql, (museum_id,)) if rowcount == 0: raise HTTPException(404, "Delete failed") return {"message": f"Deleted {rowcount} museums"} # 创建授权记录 def create_users_museum(data: dict): sql = """ INSERT INTO rag_flow.users_museum (user_id, museum_id, create_time, create_date, update_time, update_date) VALUES (%s, %s, %s, %s, %s, %s) """ now = int(datetime.now().timestamp()) params = ( data['user_id'], data['museum_id'], now, datetime.fromtimestamp(now), now, datetime.fromtimestamp(now) ) return execute_query(sql, params) # 查询授权记录(多条件) def get_users_museums(user_id: str = None, museum_id: int = None, get_all: bool = False): base_sql = "SELECT * FROM rag_flow.users_museum WHERE 1=1" params = [] if not get_all: if user_id: base_sql += " AND user_id = %s" params.append(user_id) if museum_id is not None: base_sql += " AND museum_id = %s" params.append(museum_id) base_sql += " ORDER BY create_time DESC" return execute_query(base_sql, tuple(params)) # 按ID获取单条记录 def get_users_museums_by_user_id(user_id: str): sql = "SELECT * FROM rag_flow.users_museum WHERE user_id = %s" result = execute_query(sql, (user_id,)) return result # 更新授权信息 def update_users_museum(id: int, data: dict): """动态更新用户博物馆授权(仅更新data中包含的字段)""" allowed_fields = { 'user_id': 'user_id = %s', 'museum_id': 'museum_id = %s' } update_fields = [] params = [] now = int(datetime.now().timestamp()) update_date = datetime.fromtimestamp(now) # 收集动态字段 for field, sql_part in allowed_fields.items(): if field in data: update_fields.append(sql_part) params.append(data[field]) # 必须包含至少一个业务字段 if not update_fields: raise ValueError("未提供有效更新字段") # 添加自动更新时间 update_fields.extend([ 'update_time = %s', 'update_date = %s' ]) params.extend([now, update_date]) # 构建动态SQL set_clause = ", ".join(update_fields) sql = f""" UPDATE rag_flow.users_museum SET {set_clause} WHERE id = %s """ params.append(id) # 执行更新 rowcount = execute_query(sql, tuple(params)) if rowcount == 0: raise HTTPException(status_code=404, detail="授权记录不存在") return get_users_museum_by_id(id) # 删除授权记录 def delete_users_museum(id: int): sql = "DELETE FROM rag_flow.users_museum WHERE id = %s" execute_query(sql, (id,)) return {"message": "User museum authorization deleted"} # 批量检查授权状态 def check_auth_batch(user_id: str, museum_ids: list): if not museum_ids: return [] placeholders = ','.join(['%s'] * len(museum_ids)) sql = f""" SELECT museum_id FROM rag_flow.users_museum WHERE user_id = %s AND museum_id IN ({placeholders}) """ params = [user_id] + museum_ids result = execute_query(sql, params) return [item['museum_id'] for item in result] from datetime import datetime # 创建用户 def create_user(data: dict): sql = """ INSERT INTO rag_flow.users_info (user_id, openid, phone, email, token, balance, status, last_login_time, create_time, create_date, update_time, update_date) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) # 12个占位符 """ now = int(datetime.now().timestamp()) params = ( data['user_id'], data.get('openid'), data.get('phone'), data.get('email'), data.get('token'), data.get('balance', 0), # 默认余额0 data.get('status', 1), # 默认状态1正常 data.get('last_login_time'), now, datetime.fromtimestamp(now), now, datetime.fromtimestamp(now) ) #logging.info(f"create user {data} {sql} {params}") return execute_query(sql, params) # 查询用户(多条件) def get_users(status: int = None, email: str = None, phone: str = None,openid: str = None): base_sql = "SELECT * FROM rag_flow.users_info WHERE 1=1" params = [] if status is not None: base_sql += " AND status = %s" params.append(status) if email: base_sql += " AND email = %s" params.append(email) if phone: base_sql += " AND phone = %s" params.append(phone) if openid: base_sql += " AND openid = %s" params.append(openid) base_sql += " ORDER BY create_time DESC" return execute_query(base_sql, tuple(params)) # 按用户ID获取用户 def get_user_by_id(user_id: str): sql = "SELECT * FROM rag_flow.users_info WHERE user_id = %s" result = execute_query(sql, (user_id,)) return result[0] if result else None # 更新用户信息 def update_user(user_id: str, data: dict): """动态更新用户信息(仅更新data中包含的字段)""" # 允许更新的字段白名单 allowed_fields = { 'phone': 'phone = %s', 'email': 'email = %s', 'token': 'token = %s', 'balance': 'balance = %s', 'status': 'status = %s', 'last_login_time': 'last_login_time = %s' } # 过滤有效更新字段 update_fields = [] params = [] now = int(datetime.now().timestamp()) update_date = datetime.fromtimestamp(now) # 收集动态字段 for field, sql_part in allowed_fields.items(): if field in data: update_fields.append(sql_part) params.append(data[field]) # 必须包含至少一个更新字段(除自动更新的时间字段) if not update_fields: raise ValueError("没有提供有效更新字段") # 添加自动更新的时间字段 update_fields.extend([ 'update_time = %s', 'update_date = %s' ]) params.extend([now, update_date]) # 构建动态SQL set_clause = ", ".join(update_fields) sql = f""" UPDATE rag_flow.users_info SET {set_clause} WHERE user_id = %s """ # 添加用户ID作为最后参数 params.append(user_id) # 执行更新 rowcount = execute_query(sql, tuple(params)) if rowcount == 0: raise HTTPException(status_code=404, detail="用户不存在") return get_user_by_id(user_id) # 删除用户 def delete_user(user_id: str): sql = "DELETE FROM rag_flow.users_info WHERE user_id = %s" execute_query(sql, (user_id,)) return {"message": "User deleted"} # 检查openid是否存在 def check_openid_exists(openid: str): sql = "SELECT 1 FROM rag_flow.users_info WHERE openid = %s" result = execute_query(sql, (openid,)) return bool(result) # 更新用户余额 def update_balance(user_id: str, amount: int): sql = """ UPDATE rag_flow.users_info SET balance = balance + %s, update_time = %s, update_date = %s WHERE user_id = %s """ now = int(datetime.now().timestamp()) params = ( amount, now, datetime.fromtimestamp(now), user_id ) execute_query(sql, params) return get_user_by_id(user_id) # 更新登录信息 def update_login_info(user_id: str, token: str): sql = """ UPDATE rag_flow.users_info SET token = %s, last_login_time = %s, update_time = %s, update_date = %s WHERE user_id = %s """ now = int(datetime.now().timestamp()) params = ( token, now, now, datetime.fromtimestamp(now), user_id ) execute_query(sql, params) return get_user_by_id(user_id) # database.py #------------------------------------------------------------ def get_museum_subscriptions_by_museum(museum_id: int) -> list: """ 获取指定博物馆的所有有效订阅套餐 功能说明: - 查询指定博物馆的所有可用订阅套餐 - 返回结果包含博物馆订阅信息和关联的模板信息 参数说明: - museum_id: 博物馆ID 返回: - 包含订阅信息的字典列表 重要逻辑: - 只返回 is_active=1 的有效订阅 - 通过 JOIN 关联 subscription_templates 表获取模板信息 - 结果按有效期类型排序,便于前端展示 """ sql = """ SELECT ms.id, ms.museum_id, ms.template_id, ms.price, ms.sub_id, ms.is_active, ms.created_date, ms.updated_date, st.name AS template_name, st.description AS template_description, st.validity_type FROM museum_subscriptions ms JOIN subscription_templates st ON ms.template_id = st.id WHERE ms.museum_id = %s AND ms.is_active = 1 ORDER BY st.validity_type """ return execute_query(sql, (museum_id,)) def get_museum_subscription_by_id(subscription_id: str) -> dict: """ 根据ID获取博物馆订阅套餐的详细信息 功能说明: - 通过订阅ID获取完整的订阅信息 - 包含关联的模板信息 参数说明: - subscription_id: 博物馆订阅ID 返回: - 订阅信息的字典,如果不存在则返回None 重要逻辑: - 使用内连接获取模板信息 - 确保返回完整的订阅+模板数据 """ sql = """ SELECT ms.*, st.name AS template_name, st.description AS template_description, st.validity_type FROM museum_subscriptions ms JOIN subscription_templates st ON ms.template_id = st.id WHERE ms.sub_id = %s """ result = execute_query(sql, (subscription_id,)) return result[0] if result else None def create_order(order_data: dict) -> int: """ 创建新的订阅订单 功能说明: - 在 subscription_orders 表中插入新订单记录 参数说明: - order_data: 包含订单数据的字典,字段包括: order_id: 订单号 (必需) user_id: 用户ID (必需) museum_subscription_id: 博物馆订阅ID (必需) amount: 订单金额 (默认0.00) status: 订单状态 (默认'created') transaction_id: 支付交易号 (可选) create_date: 创建时间 (默认当前时间) pay_time: 支付时间 (可选) 返回: - 执行结果的行数 重要逻辑: - 为可选字段提供默认值 - 使用参数化查询防止SQL注入 - 处理所有必需的订单字段 """ sql = """ INSERT INTO subscription_orders ( order_id, user_id, museum_subscription_id, amount, status, transaction_id, create_date, pay_time ) VALUES ( %s, %s, %s, %s, %s, %s, %s, %s ) """ # 设置默认值 params = ( order_data.get("order_id"), order_data.get("user_id"), order_data.get("museum_subscription_id"), order_data.get("amount", 0.00), order_data.get("status", "created"), order_data.get("transaction_id"), order_data.get("create_date", datetime.now()), order_data.get("pay_time") ) return execute_query(sql, params, autocommit=True) def update_order(order_id: str, update_data: dict) -> int: """ 更新订单信息 功能说明: - 动态更新订单的字段 - 只更新提供的字段 参数说明: - order_id: 要更新的订单ID - update_data: 包含更新字段的字典,可选字段包括: status: 订单状态 transaction_id: 支付交易号 amount: 订单金额 pay_time: 支付时间 返回: - 受影响的行数 重要逻辑: - 只允许更新预定义的字段 - 防止更新不允许的字段 - 使用参数化查询确保安全 """ # 定义允许更新的字段及其SQL部分 allowed_fields = { 'status': 'status = %s', 'transaction_id': 'transaction_id = %s', 'amount': 'amount = %s', 'pay_time': 'pay_time = %s' } update_fields = [] params = [] # 收集要更新的字段 for field, sql_part in allowed_fields.items(): if field in update_data: update_fields.append(sql_part) params.append(update_data[field]) # 如果没有提供有效更新字段,直接返回 if not update_fields: return 0 # 构建动态SQL set_clause = ", ".join(update_fields) sql = f"UPDATE subscription_orders SET {set_clause} WHERE order_id = %s" params.append(order_id) return execute_query(sql, tuple(params), autocommit=True) from typing import Union, List, Dict, Optional def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -> Union[Dict, List[Dict], None]: """ 根据订单ID或用户ID查询订单信息 功能说明: - 支持通过 order_id 或 user_id 查询订单信息 - 当传入 order_id 时,返回单个订单(字典) - 当传入 user_id 时,返回该用户的所有订单(列表) 参数说明: - order_id: 订单号(字符串) - user_id: 用户ID(字符串) 返回: - 如果传入 order_id: 返回单个订单的字典(如果存在) - 如果传入 user_id: 返回该用户的所有订单列表(可能为空) - 如果两个参数都未传入,返回 None 重要逻辑: - 使用参数化查询防止 SQL 注入 - 当同时传入 order_id 和 user_id 时,优先使用 order_id """ if not order_id and not user_id: return None # 两个参数都未传入,直接返回 None sql_wo_condition= """ SELECT o.order_id, o.user_id, u.phone, u.openid, o.museum_subscription_id AS subscription_id, ms.museum_id, ms.template_id, t.name AS template_name, t.description AS template_desc, t.validity_type, ms.price AS subscription_price, o.amount AS order_amount, o.status AS order_status, o.transaction_id, o.create_date AS order_create_time, o.pay_time, us.start_date, us.end_date, us.is_active AS subscription_active, mo.name AS museum_name FROM rag_flow.subscription_orders o LEFT JOIN rag_flow.users_info u ON o.user_id = u.user_id LEFT JOIN rag_flow.museum_subscriptions ms ON o.museum_subscription_id = ms.sub_id LEFT JOIN rag_flow.subscription_templates t ON ms.template_id = t.id LEFT JOIN rag_flow.user_subscriptions us ON o.order_id = us.order_id LEFT JOIN rag_flow.mesum_overview mo ON ms.museum_id = mo.id """ # 优先使用 order_id 查询 if order_id and not combined: sql = "SELECT * FROM subscription_orders WHERE order_id = %s" result = execute_query(sql, (order_id,)) return result[0] if result and len(result) > 0 else None # 返回单个订单 # 如果 order_id 不存在,使用 user_id 查询 if user_id and not combined: sql = "SELECT * FROM subscription_orders WHERE user_id = %s" result = execute_query(sql, (user_id,)) return result if result else [] # 返回所有订单(列表) if user_id and combined: sql = sql_wo_condition + f"\n WHERE o.user_id = %s" result = execute_query(sql, (user_id,)) return result if result else [] # 返回所有订单(列表) if order_id and combined: sql = sql_wo_condition + f"\n WHERE o.order_id = %s" result = execute_query(sql, (order_id,)) return result[0] if result and len(result) > 0 else None # 返回单个订单 def create_user_subscription(data: dict) -> int: """ 创建用户订阅记录 功能说明: - 在 user_subscriptions 表中插入新记录 - 表示用户购买并激活了一个订阅 参数说明: - data: 包含订阅数据的字典,字段包括: user_id: 用户ID (必需) museum_subscription_id: 博物馆订阅ID (必需) order_id: 关联的订单ID (必需) start_date: 开始时间 (默认当前时间) end_date: 结束时间 (必需) is_active: 是否激活 (默认1) create_date: 创建时间 (默认当前时间) 返回: - 执行结果的行数 重要逻辑: - 为可选字段提供默认值 - 确保所有必需字段都有值 - 处理时间字段的默认值 """ sql = """ INSERT INTO user_subscriptions ( user_id, museum_subscription_id, order_id, start_date, end_date, is_active, create_date ) VALUES ( %s, %s, %s, %s, %s, %s, %s ) """ # 设置默认值 params = ( data.get("user_id"), data.get("museum_subscription_id"), data.get("order_id"), data.get("start_date", datetime.now()), data.get("end_date"), data.get("is_active", 1), data.get("create_date", datetime.now()) ) return execute_query(sql, params, autocommit=True) def deactivate_previous_subscriptions(user_id: str, museum_subscription_id: str) -> int: """ 禁用用户在同一博物馆的旧订阅 功能说明: - 将用户在同一博物馆的所有激活订阅设为非激活状态 - 确保同一博物馆只有一个激活订阅 参数说明: - user_id: 用户ID - museum_id: 博物馆ID 返回: - 受影响的行数 重要逻辑: - 通过JOIN关联博物馆订阅表 - 只更新同一博物馆的订阅 - 保持历史订阅记录,只修改激活状态 """ sql = """ UPDATE user_subscriptions us JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.sub_id SET us.is_active = 0 WHERE us.user_id = %s AND us.museum_subscription_id = %s AND us.is_active = 1 """ return execute_query(sql, (user_id, museum_subscription_id), autocommit=True) def get_user_by_id(user_id: str) -> dict: """ 根据用户ID获取用户信息 功能说明: - 通过用户ID查询用户基本信息 参数说明: - user_id: 用户ID 返回: - 用户信息的字典,如果不存在则返回None 重要逻辑: - 直接查询用户表的所有字段 """ sql = "SELECT * FROM users_info WHERE user_id = %s" result = execute_query(sql, (user_id,)) return result[0] if result else None def get_subscription_template_by_id(template_id: int) -> dict: """ 根据模板ID获取订阅模板信息 功能说明: - 通过模板ID查询订阅模板详情 参数说明: - template_id: 模板ID 返回: - 模板信息的字典,如果不存在则返回None """ sql = "SELECT * FROM subscription_templates WHERE id = %s" result = execute_query(sql, (template_id,)) return result[0] if result else None def get_active_user_subscription(user_id: str, museum_id: int) -> dict: """ 获取用户在指定博物馆的有效订阅 功能说明: - 查询用户在特定博物馆的当前有效订阅 - 有效订阅定义为: 已激活且未过期 参数说明: - user_id: 用户ID - museum_id: 博物馆ID 返回: - 订阅信息的字典,如果不存在则返回None 重要逻辑: - 通过JOIN关联博物馆订阅表 - 检查 is_active=1 和 end_date > NOW() - 返回最新到期的订阅 """ sql = """ SELECT us.* FROM user_subscriptions us JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.id WHERE us.user_id = %s AND ms.museum_id = %s AND us.is_active = 1 AND us.end_date > NOW() ORDER BY us.end_date DESC LIMIT 1 """ result = execute_query(sql, (user_id, museum_id)) return result[0] if result else None def get_user_subscription_history(user_id: str) -> list: """ 获取用户的订阅历史记录 功能说明: - 查询用户的所有订阅记录(包括历史和当前) - 返回完整的订阅详情,包含博物馆和模板信息 参数说明: - user_id: 用户ID 返回: - 包含订阅历史记录的字典列表 重要逻辑: - 通过多层JOIN关联所有相关表 - 包含博物馆名称、模板信息等 - 按开始时间倒序排列,最新订阅在前 """ sql = """ SELECT us.*, ms.price, ms.sub_id, st.name AS template_name, st.description AS template_description, st.validity_type, mo.name AS museum_name FROM user_subscriptions us JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.id JOIN subscription_templates st ON ms.template_id = st.id JOIN mesum_overview mo ON ms.museum_id = mo.id WHERE us.user_id = %s ORDER BY us.start_date DESC """ return execute_query(sql, (user_id,)) def get_user_subscription_by_order(order_id: str) -> dict: """ 根据订单ID获取用户订阅信息 功能说明: - 通过订单ID查询关联的用户订阅 参数说明: - order_id: 订单ID 返回: - 订阅信息的字典,如果不存在则返回None 重要逻辑: - 用于支付回调后验证订阅是否已创建 """ sql = "SELECT * FROM user_subscriptions WHERE order_id = %s" result = execute_query(sql, (order_id,)) return result[0] if result else None def activate_free_subscription(user_id: str, museum_id: int) -> str: """ 激活免费订阅 功能说明: - 为用户在指定博物馆激活免费订阅 - 创建订单记录和订阅记录 参数说明: - user_id: 用户ID - museum_id: 博物馆ID 返回: - 创建的订单ID 重要逻辑: - 查找博物馆的免费订阅 - 创建订单记录 (状态为activated) - 创建订阅记录 (有效期为7天) """ # 1. 获取博物馆的免费订阅 sql = """ SELECT ms.id FROM museum_subscriptions ms JOIN subscription_templates st ON ms.template_id = st.id WHERE ms.museum_id = %s AND st.validity_type = 'free' AND ms.is_active = 1 LIMIT 1 """ result = execute_query(sql, (museum_id,)) if not result: raise ValueError("该博物馆没有可用的免费订阅") subscription_id = result[0]["id"] # 2. 创建免费订单 order_id = f"FREE_{int(time.time())}" create_order({ "order_id": order_id, "user_id": user_id, "museum_subscription_id": subscription_id, "amount": 0, "status": "activated", "create_date": datetime.now() }) # 3. 创建用户订阅记录 (免费订阅有效期为7天) start_date = datetime.now() end_date = start_date + timedelta(days=7) create_user_subscription({ "user_id": user_id, "museum_subscription_id": subscription_id, "order_id": order_id, "start_date": start_date, "end_date": end_date, "is_active": True }) # 4. 禁用同一博物馆的旧订阅 deactivate_previous_subscriptions(user_id, subscription_id) return order_id def activate_user_subscription( user_id: str, museum_subscription_id: str, order_id: str ) -> bool: """ 激活用户订阅服务 参数: - user_id: 用户ID - museum_subscription_id: 博物馆订阅套餐ID - order_id: 关联的订单ID 返回: - 激活成功返回True,失败返回False 主要逻辑: 1. 获取博物馆订阅信息 2. 禁用同一博物馆的旧订阅 3. 计算订阅有效期 4. 创建用户订阅记录 5. 处理重复激活和并发请求 """ try: # 1. 获取博物馆订阅信息 museum_sub = get_museum_subscription_by_id(museum_subscription_id) if not museum_sub: logger.error(f"博物馆订阅不存在: {museum_subscription_id}") return False # 2. 获取关联的订阅模板 template = get_subscription_template_by_id(museum_sub["template_id"]) if not template: logger.error(f"订阅模板不存在: {museum_sub['template_id']}") return False # 3. 禁用同一博物馆的旧订阅 deactivated_count = deactivate_previous_subscriptions( user_id=user_id, museum_subscription_id=museum_subscription_id ) logger.info(f"已禁用{deactivated_count}个同一博物馆的旧订阅") # 4. 计算订阅有效期 start_date = datetime.now() end_date = calculate_subscription_expiry(start_date,template["validity_type"]) # 5. 检查是否已激活过(防止重复激活) existing_sub = get_user_subscription_by_order(order_id) if existing_sub: logger.warning(f"订阅已激活过,跳过重复激活. Order: {order_id}") return True # 6. 创建用户订阅记录 subscription_data = { "user_id": user_id, "museum_subscription_id": museum_subscription_id, "order_id": order_id, "start_date": start_date, "end_date": end_date, "is_active": 1 } create_user_subscription(subscription_data) logger.info(f"订阅激活成功. 用户: {user_id}, 套餐: {museum_subscription_id}, " f"有效期: {start_date} 至 {end_date}") return True except Exception as e: logger.exception(f"激活订阅失败: {str(e)}") return False def check_user_subscription(user_id: str, museum_id: int) -> dict: """ 检查用户是否拥有指定博物馆的有效订阅 功能说明: - 检查用户是否有指定博物馆的未过期激活订阅 参数说明: - user_id: 用户ID - museum_id: 博物馆ID 返回: - 订阅信息字典,如果没有则返回None 重要逻辑: - 优先检查当前有效的订阅 - 如果没有,检查免费订阅是否可用 """ # 1. 检查当前有效订阅 active_sub = get_active_user_subscription(user_id, museum_id) if active_sub: return active_sub # 2. 检查是否有免费订阅可用 # (这里可以扩展更多逻辑,如试用期检查等) return None def calculate_subscription_expiry(start_date: datetime, validity_type: str) -> datetime: """ 根据有效期类型计算到期日期 参数: - start_date: 订阅开始日期 - validity_type: 有效期类型 (free, 1month, 1year, permanent) 返回: - 到期日期 """ if validity_type == "free": # 免费套餐通常有较短有效期(例如7天) return start_date + timedelta(days=7) elif validity_type == "1month": # 下个月的同一天(自动处理月末情况) return start_date + relativedelta(months=1) elif validity_type == "1year": # 下一年的同一天 return start_date + relativedelta(years=1) elif validity_type == "permanent": # 永久有效设置为100年后 return start_date + relativedelta(years=100) else: # 未知类型默认30天 logger.warning(f"未知有效期类型: {validity_type}, 使用默认30天") return start_date + timedelta(days=30)