import logging
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Optional, Union

import pymysql
from pymysql import Connection
from pymysql.err import InterfaceError, OperationalError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from config import DATABASE_CONFIG

# NOTE(review): HTTPException is raised by several update/delete helpers below but is
# never imported in this module, so those 404 paths currently fail with NameError.
# The call style (status_code=..., detail=...) matches FastAPI/Starlette's
# HTTPException; import it from the web framework this module is used with — TODO confirm.

# Configure logging (module-level side effect kept from the original).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DB")

# Total number of attempts (including the first one) for retried operations.
_MAX_ATTEMPTS = 3

# Conservative retry policy: only connection-level driver errors are retried.
RETRY_CONFIG = {
    'stop': stop_after_attempt(_MAX_ATTEMPTS),
    'wait': wait_fixed(1),  # fixed 1-second pause between attempts
    'retry': retry_if_exception_type((OperationalError, InterfaceError)),
    'reraise': True,  # re-raise the original exception once attempts are exhausted
    'before_sleep': lambda retry_state: logger.warning(
        f"Retrying ({retry_state.attempt_number}/{_MAX_ATTEMPTS}) due to: "
        f"{retry_state.outcome.exception()}"
    ),
}

# Keywords whose presence (case-insensitive substring match) makes execute_query refuse a query.
_FORBIDDEN_KEYWORDS = ('DROP', 'TRUNCATE', 'GRANT')

# Whitelists of columns the dynamic UPDATE helpers may touch, in SET-clause order.
_MUSEUM_UPDATABLE = ('name', 'brief', 'photo_url', 'longitude', 'latitude',
                     'category', 'address', 'free')
_USERS_MUSEUM_UPDATABLE = ('user_id', 'museum_id')
_USER_UPDATABLE = ('phone', 'email', 'token', 'balance', 'status', 'last_login_time')


@retry(**RETRY_CONFIG)
def _connect() -> Connection:
    """Open a new connection from DATABASE_CONFIG, retrying connection-level errors."""
    try:
        return pymysql.connect(**DATABASE_CONFIG)
    except (OperationalError, InterfaceError) as e:
        logger.error(f"Connection failed: {str(e)}")
        raise


@contextmanager
def get_connection(autocommit: bool = False):
    """Yield a fresh, single-use database connection (connect is retried).

    In manual-commit mode (``autocommit=False``) the transaction is committed when
    the ``with`` body finishes normally and rolled back if it raises. The
    connection is always closed on exit.

    :param autocommit: value passed to ``conn.autocommit()``; when True no
        explicit commit/rollback is issued here.
    """
    conn = None
    try:
        conn = _connect()
        conn.autocommit(autocommit)
        yield conn
        if not autocommit:
            conn.commit()
    except Exception:
        if conn is not None and not autocommit:
            try:
                conn.rollback()
            except Exception:
                # A failed rollback (e.g. the link already dropped) must not
                # replace the exception that triggered it.
                logger.exception("Rollback failed")
        raise
    finally:
        if conn is not None:
            conn.close()


def get_query_type(query: str) -> str:
    """Classify a statement by its leading keyword.

    Returns "SELECT" (SELECT/WITH), "INSERT" (INSERT/REPLACE), "UPDATE",
    "DELETE", or "OTHER" for anything else.
    """
    query = query.strip().upper()
    if query.startswith(("WITH", "SELECT")):
        return "SELECT"
    if query.startswith(("INSERT", "REPLACE")):
        return "INSERT"
    if query.startswith("UPDATE"):
        return "UPDATE"
    if query.startswith("DELETE"):
        return "DELETE"
    return "OTHER"


def process_result(cursor, query: str) -> Union[int, list[dict[str, Any]]]:
    """Turn the outcome of an executed statement into a return value.

    - SELECT/WITH: all fetched rows as a list. Row type depends on the cursor
      class in DATABASE_CONFIG — presumably DictCursor, since callers index rows
      by column name (TODO confirm).
    - INSERT/REPLACE: ``cursor.lastrowid``.
    - anything else: ``cursor.rowcount`` (affected rows).
    """
    query_type = get_query_type(query)
    if query_type == "SELECT":
        # PyMySQL's fetchall() can hand back a tuple (e.g. () for no rows);
        # normalize so callers always get a list.
        return list(cursor.fetchall())
    if query_type == "INSERT":
        return cursor.lastrowid
    return cursor.rowcount


@retry(**RETRY_CONFIG)
def execute_query(
    query: str,
    params: tuple | list | dict | None = None,
    *,
    connection: Optional[Connection] = None,
    autocommit: bool = False,
):
    """Validate and execute a single SQL statement.

    The whole call is retried per RETRY_CONFIG on OperationalError/InterfaceError.
    A connection opened here via get_connection() also retries its own connect,
    so connection failures may be attempted up to _MAX_ATTEMPTS * _MAX_ATTEMPTS
    times in total.

    :param query: SQL using driver placeholders (%s or %(name)s).
    :param params: values bound to the placeholders by the driver.
    :param connection: optional caller-owned connection. It is never closed here.
        It is committed only when ``autocommit`` is True and the statement does
        not start with SELECT; otherwise the caller controls the transaction.
    :param autocommit: when no ``connection`` is given, passed to get_connection().
    :return: see process_result().
    :raises ValueError: if ``query`` is blank.
    :raises PermissionError: if ``query`` contains DROP, TRUNCATE or GRANT
        anywhere (case-insensitive substring match).
    """
    if not query.strip():
        raise ValueError("Empty query")
    upper_query = query.upper()
    if any(kw in upper_query for kw in _FORBIDDEN_KEYWORDS):
        raise PermissionError("Dangerous operation detected")

    if connection is not None:
        # Caller-owned connection.
        with connection.cursor() as cursor:
            cursor.execute(query, params)
            if autocommit and not query.strip().upper().startswith("SELECT"):
                connection.commit()
            # The original omitted `query` here, which raised TypeError on every call.
            return process_result(cursor, query)

    # Fresh connection: commit/rollback/close are handled by get_connection().
    with get_connection(autocommit=autocommit) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            return process_result(cursor, query)


def _timestamp_now() -> tuple[int, datetime]:
    """Return the current time as (Unix seconds, local naive datetime of those seconds)."""
    now = int(datetime.now().timestamp())
    return now, datetime.fromtimestamp(now)


def _build_update_statement(table: str, key_column: str, allowed_columns: tuple,
                            data: dict, key_value, empty_error: str):
    """Build (sql, params) for an UPDATE of the whitelisted columns present in *data*.

    update_time/update_date are always refreshed. Column and table names come
    only from the hard-coded whitelists above, never from *data*.

    :raises ValueError: with *empty_error* if *data* has no whitelisted column.
    """
    columns = [column for column in allowed_columns if column in data]
    if not columns:
        raise ValueError(empty_error)
    now, now_dt = _timestamp_now()
    assignments = [f"{column} = %s" for column in columns]
    assignments += ['update_time = %s', 'update_date = %s']
    params = [data[column] for column in columns] + [now, now_dt, key_value]
    sql = f"UPDATE {table} SET {', '.join(assignments)} WHERE {key_column} = %s"
    return sql, tuple(params)


def create_museum(data: dict):
    """Insert a row into mesum_overview and return its new id (``lastrowid``).

    ``data['name']`` and ``data['chat_id']`` are required (KeyError otherwise);
    other fields are optional and ``free`` defaults to 0.
    """
    sql = """
        INSERT INTO mesum_overview
        (name, brief, chat_id, photo_url, longitude, latitude, category,
         create_time, create_date, update_time, update_date, address, free)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    """
    now, now_dt = _timestamp_now()
    params = (
        data['name'], data.get('brief'), data['chat_id'], data.get('photo_url'),
        data.get('longitude'), data.get('latitude'), data.get('category'),
        now, now_dt, now, now_dt,
        data.get('address'), data.get('free', 0),
    )
    return execute_query(sql, params)


def get_museums(search: str = None, free: int = None):
    """List museums, newest first.

    :param search: substring matched against name or brief (LIKE %search%).
    :param free: exact match on the ``free`` column when not None.
    """
    base_sql = "SELECT * FROM mesum_overview WHERE 1=1"
    params = []
    if search:
        base_sql += " AND (name LIKE %s OR brief LIKE %s)"
        params.extend([f"%{search}%", f"%{search}%"])
    if free is not None:
        base_sql += " AND free = %s"
        params.append(free)
    base_sql += " ORDER BY create_time DESC"
    return execute_query(base_sql, tuple(params))


def update_museum(museum_id: int, data: dict):
    """Update only the whitelisted museum fields present in *data*; return the updated row.

    :raises ValueError: if *data* contains no updatable field.
    :raises HTTPException: 404 if the museum does not exist (see NOTE near imports).
    """
    sql, params = _build_update_statement(
        "mesum_overview", "id", _MUSEUM_UPDATABLE, data, museum_id,
        "未提供有效更新字段",
    )
    execute_query(sql, params)
    # Check existence rather than rowcount: MySQL counts only rows actually
    # changed, so an identical update within the same second would report 0.
    museum = get_museum_by_id(museum_id)
    if museum is None:
        raise HTTPException(status_code=404, detail="博物馆不存在")
    return museum


def get_museum_by_id(museum_id: int):
    """Return the mesum_overview row with this id, or None if absent."""
    sql = "SELECT * FROM mesum_overview WHERE id = %s"
    result = execute_query(sql, (museum_id,))
    # Accept tuples too: the driver may return () for an empty result set.
    if not isinstance(result, (list, tuple)):
        raise TypeError("Unexpected return type")
    return result[0] if result else None


def delete_museum(museum_id: int):
    """Delete a museum by id.

    :raises HTTPException: 404 if nothing was deleted (see NOTE near imports).
    """
    sql = "DELETE FROM mesum_overview WHERE id = %s"
    rowcount = execute_query(sql, (museum_id,))
    if rowcount == 0:
        raise HTTPException(404, "Delete failed")
    return {"message": f"Deleted {rowcount} museums"}


def create_users_museum(data: dict):
    """Create a user-museum authorization record and return its new id.

    ``data['user_id']`` and ``data['museum_id']`` are required.
    """
    sql = """
        INSERT INTO rag_flow.users_museum
        (user_id, museum_id, create_time, create_date, update_time, update_date)
        VALUES (%s, %s, %s, %s, %s, %s)
    """
    now, now_dt = _timestamp_now()
    params = (data['user_id'], data['museum_id'], now, now_dt, now, now_dt)
    return execute_query(sql, params)


def get_users_museums(user_id: str = None, museum_id: int = None, get_all: bool = False):
    """List authorization records, newest first.

    Filters by *user_id* / *museum_id* when given, unless *get_all* is True.
    """
    base_sql = "SELECT * FROM rag_flow.users_museum WHERE 1=1"
    params = []
    if not get_all:
        if user_id:
            base_sql += " AND user_id = %s"
            params.append(user_id)
        if museum_id is not None:
            base_sql += " AND museum_id = %s"
            params.append(museum_id)
    base_sql += " ORDER BY create_time DESC"
    return execute_query(base_sql, tuple(params))


def get_users_museums_by_user_id(user_id: str):
    """Return all authorization records for *user_id* (a list, possibly empty)."""
    sql = "SELECT * FROM rag_flow.users_museum WHERE user_id = %s"
    return execute_query(sql, (user_id,))


def get_users_museum_by_id(id: int):
    """Return the authorization record with this primary-key id, or None if absent."""
    sql = "SELECT * FROM rag_flow.users_museum WHERE id = %s"
    result = execute_query(sql, (id,))
    return result[0] if result else None


def update_users_museum(id: int, data: dict):
    """Update only the whitelisted authorization fields present in *data*; return the record.

    :raises ValueError: if *data* contains no updatable field.
    :raises HTTPException: 404 if the record does not exist (see NOTE near imports).
    """
    sql, params = _build_update_statement(
        "rag_flow.users_museum", "id", _USERS_MUSEUM_UPDATABLE, data, id,
        "未提供有效更新字段",
    )
    execute_query(sql, params)
    # Existence check instead of rowcount (see update_museum for why).
    record = get_users_museum_by_id(id)
    if record is None:
        raise HTTPException(status_code=404, detail="授权记录不存在")
    return record


def delete_users_museum(id: int):
    """Delete an authorization record by id (no error if it does not exist)."""
    sql = "DELETE FROM rag_flow.users_museum WHERE id = %s"
    execute_query(sql, (id,))
    return {"message": "User museum authorization deleted"}


def check_auth_batch(user_id: str, museum_ids: list):
    """Return the subset of *museum_ids* that *user_id* is authorized for.

    Accepts any iterable of ids; returns [] for an empty one without querying.
    """
    ids = list(museum_ids)
    if not ids:
        return []
    placeholders = ','.join(['%s'] * len(ids))
    sql = f"""
        SELECT museum_id FROM rag_flow.users_museum
        WHERE user_id = %s AND museum_id IN ({placeholders})
    """
    result = execute_query(sql, (user_id, *ids))
    return [item['museum_id'] for item in result]


def create_user(data: dict):
    """Create a users_info row and return its new id.

    ``data['user_id']`` is required; ``balance`` defaults to 0 and ``status`` to 1.
    """
    # 12 columns, 12 placeholders.
    sql = """
        INSERT INTO rag_flow.users_info
        (user_id, openid, phone, email, token, balance, status, last_login_time,
         create_time, create_date, update_time, update_date)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    """
    now, now_dt = _timestamp_now()
    params = (
        data['user_id'], data.get('openid'), data.get('phone'), data.get('email'),
        data.get('token'),
        data.get('balance', 0),  # default balance 0
        data.get('status', 1),   # default status 1
        data.get('last_login_time'),
        now, now_dt, now, now_dt,
    )
    return execute_query(sql, params)


def get_users(status: int = None, email: str = None, phone: str = None):
    """List users matching the given filters, newest first."""
    base_sql = "SELECT * FROM rag_flow.users_info WHERE 1=1"
    params = []
    if status is not None:
        base_sql += " AND status = %s"
        params.append(status)
    if email:
        base_sql += " AND email = %s"
        params.append(email)
    if phone:
        base_sql += " AND phone = %s"
        params.append(phone)
    base_sql += " ORDER BY create_time DESC"
    return execute_query(base_sql, tuple(params))


def get_user_by_id(user_id: str):
    """Return the users_info row for *user_id*, or None if absent."""
    sql = "SELECT * FROM rag_flow.users_info WHERE user_id = %s"
    result = execute_query(sql, (user_id,))
    return result[0] if result else None


def update_user(user_id: str, data: dict):
    """Update only the whitelisted user fields present in *data*; return the updated user.

    :raises ValueError: if *data* contains no updatable field.
    :raises HTTPException: 404 if the user does not exist (see NOTE near imports).
    """
    sql, params = _build_update_statement(
        "rag_flow.users_info", "user_id", _USER_UPDATABLE, data, user_id,
        "没有提供有效更新字段",
    )
    execute_query(sql, params)
    # Existence check instead of rowcount (see update_museum for why).
    user = get_user_by_id(user_id)
    if user is None:
        raise HTTPException(status_code=404, detail="用户不存在")
    return user


def delete_user(user_id: str):
    """Delete a user by user_id (no error if it does not exist)."""
    sql = "DELETE FROM rag_flow.users_info WHERE user_id = %s"
    execute_query(sql, (user_id,))
    return {"message": "User deleted"}


def check_openid_exists(openid: str):
    """Return True if any user has this openid."""
    sql = "SELECT 1 FROM rag_flow.users_info WHERE openid = %s"
    result = execute_query(sql, (openid,))
    return bool(result)


def update_balance(user_id: str, amount: int):
    """Add *amount* (may be negative) to the user's balance; return the user or None."""
    sql = """
        UPDATE rag_flow.users_info
        SET balance = balance + %s, update_time = %s, update_date = %s
        WHERE user_id = %s
    """
    now, now_dt = _timestamp_now()
    execute_query(sql, (amount, now, now_dt, user_id))
    return get_user_by_id(user_id)


def update_login_info(user_id: str, token: str):
    """Store a new token and set last_login_time to now; return the user or None."""
    sql = """
        UPDATE rag_flow.users_info
        SET token = %s, last_login_time = %s, update_time = %s, update_date = %s
        WHERE user_id = %s
    """
    now, now_dt = _timestamp_now()
    execute_query(sql, (token, now, now, now_dt, user_id))
    return get_user_by_id(user_id)