# Source-viewer metadata captured with this listing: 534 lines, 15 KiB, Python
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Optional, Union

import pymysql
from pymysql import Connection
from pymysql.err import InterfaceError, OperationalError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from config import DATABASE_CONFIG

# Logging: configure the root logger at INFO level at import time, and expose
# the module-level logger (named "DB") used by the helpers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="DB")

# Retry policy shared by connection setup and query execution (conservative:
# few attempts with a fixed pause). Only transient driver errors
# (OperationalError / InterfaceError) are retried; once the attempts are
# exhausted the original exception is re-raised.
MAX_RETRY_ATTEMPTS = 3  # total attempts, including the first call

RETRY_CONFIG = {
    'stop': stop_after_attempt(MAX_RETRY_ATTEMPTS),
    'wait': wait_fixed(1),  # fixed 1-second pause between attempts
    'retry': retry_if_exception_type((OperationalError, InterfaceError)),
    'reraise': True,  # raise the original exception, not tenacity.RetryError
    # Log before each pause; the denominator is derived from the same constant
    # as the stop condition so the message cannot drift out of sync again.
    'before_sleep': lambda retry_state: logger.warning(
        f"Retrying ({retry_state.attempt_number}/{MAX_RETRY_ATTEMPTS}) "
        f"due to: {retry_state.outcome.exception()}"
    ),
}

@contextmanager
def get_connection(autocommit: bool = False):
    """Open one database connection for the duration of a ``with`` block.

    The connection is created from ``DATABASE_CONFIG`` (connect attempts are
    retried per ``RETRY_CONFIG``), and its autocommit mode is set to
    ``autocommit``.

    With ``autocommit=False`` the block behaves as one transaction. It is
    committed when the block exits normally and rolled back when it raises.
    The connection is closed on exit in every case.

    :param autocommit: autocommit mode applied to the new connection.
    :yields: the open pymysql connection.
    :raises: whatever connecting, committing, or the ``with`` body raised.
    """

    @retry(**RETRY_CONFIG)
    def connect():
        try:
            return pymysql.connect(**DATABASE_CONFIG)
        except (OperationalError, InterfaceError) as e:
            # Log every failed attempt; RETRY_CONFIG decides whether to retry.
            logger.error(f"Connection failed: {str(e)}")
            raise

    conn = None
    try:
        conn = connect()
        conn.autocommit(autocommit)
        yield conn
        if not autocommit:
            conn.commit()  # single commit point for non-autocommit mode
    except Exception:
        if conn is not None and not autocommit:
            try:
                conn.rollback()
            except Exception:
                # The connection may already be broken. A failed rollback must
                # not replace the exception that triggered it.
                logger.exception("Rollback failed")
        raise
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                # Cleanup failures are logged rather than allowed to mask the
                # block's real outcome.
                logger.exception("Closing connection failed")

def get_query_type(query: str) -> str:
    """Classify a SQL statement by its leading keyword.

    Surrounding whitespace and letter case are ignored.

    :return: "SELECT" for WITH/SELECT, "INSERT" for INSERT/REPLACE,
        "UPDATE", "DELETE", or "OTHER" for anything else.
    """
    head = query.strip().upper()
    # Checked in order; the first matching prefix group wins.
    prefix_table = (
        (("WITH", "SELECT"), "SELECT"),
        (("INSERT", "REPLACE"), "INSERT"),
        (("UPDATE",), "UPDATE"),
        (("DELETE",), "DELETE"),
    )
    for prefixes, kind in prefix_table:
        if head.startswith(prefixes):
            return kind
    return "OTHER"

def process_result(cursor, query: str) -> Union[int, list[dict[str, Any]]]:
    """Convert an executed cursor into the value ``execute_query`` returns.

    - SELECT / WITH: every fetched row, always as a ``list``.
    - INSERT / REPLACE: ``cursor.lastrowid``.
    - anything else (UPDATE, DELETE, other): ``cursor.rowcount``.

    Rows are whatever the cursor produces. Callers in this module index them
    by column name, so presumably DATABASE_CONFIG selects a dict cursor
    (TODO confirm).

    :param cursor: a cursor on which ``query`` has just been executed.
    :param query: the executed SQL text, used only to classify the statement.
    """
    query_type = get_query_type(query)
    if query_type == "SELECT":
        # pymysql's fetchall() can return a tuple (e.g. ``()`` for an empty
        # result) instead of a list. Normalize so callers always get a list.
        return list(cursor.fetchall())
    if query_type == "INSERT":
        return cursor.lastrowid  # id generated by the insert
    return cursor.rowcount  # affected-row count for UPDATE/DELETE/other

@retry(**RETRY_CONFIG)
def execute_query(
    query: str,
    params: tuple | list | dict | None = None,
    *,
    connection: Optional[Connection] = None,  # optional caller-owned connection
    autocommit: bool = False
):
    """Validate and run one SQL statement, returning ``process_result``'s value.

    Intended for low-frequency operations. Unless ``connection`` is supplied,
    every call opens and closes its own connection via ``get_connection``.

    :param query: SQL text using pymysql placeholders (``%s`` / ``%(name)s``).
    :param params: values bound by the driver; ``None`` for no parameters.
    :param connection: caller-owned connection to run on. It is not closed
        here, and it is committed only when ``autocommit`` is true and the
        statement is not classified as a SELECT.
    :param autocommit: without ``connection``, forwarded to ``get_connection``
        (False means commit on success and roll back on error). With
        ``connection``, it requests a commit after a non-SELECT statement.
    :return: list of rows for SELECT/WITH, ``lastrowid`` for INSERT/REPLACE,
        ``rowcount`` otherwise.
    :raises ValueError: if ``query`` is blank.
    :raises PermissionError: if ``query`` contains DROP, TRUNCATE or GRANT.
        This is a plain substring test, so an identifier such as ``dropoff``
        is rejected too.
    """
    # Input sanity checks
    if not query.strip():
        raise ValueError("Empty query")

    forbidden_keywords = ['DROP', 'TRUNCATE', 'GRANT']
    if any(kw in query.upper() for kw in forbidden_keywords):
        raise PermissionError("Dangerous operation detected")

    # NOTE(review): the whole call is retried on OperationalError /
    # InterfaceError (and connecting is retried again inside get_connection).
    # A write whose commit failed may therefore be executed again. Confirm
    # this is acceptable for non-idempotent INSERTs.
    if connection is not None:
        with connection.cursor() as cursor:
            cursor.execute(query, params)
            if autocommit and get_query_type(query) != "SELECT":
                connection.commit()
            # Pass the query through: process_result needs it to classify.
            return process_result(cursor, query)

    with get_connection(autocommit=autocommit) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            return process_result(cursor, query)

def create_museum(data: dict):
    """Insert one museum row into ``mesum_overview``.

    ``data['name']`` and ``data['chat_id']`` are required (KeyError otherwise).
    brief, photo_url, longitude, latitude, category and address default to
    None, and ``free`` defaults to 0. The create and update timestamps are both
    set to "now", stored as an integer epoch (``*_time``) and as a datetime
    (``*_date``).

    :return: ``execute_query``'s result for the INSERT.
    """
    sql = """
        INSERT INTO mesum_overview
        (name, brief, chat_id, photo_url, longitude, latitude, category,
        create_time, create_date, update_time, update_date, address, free)
        VALUES
        (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    """

    epoch = int(datetime.now().timestamp())
    moment = datetime.fromtimestamp(epoch)

    values = [data['name'], data.get('brief'), data['chat_id']]
    values += [data.get(key) for key in ('photo_url', 'longitude', 'latitude', 'category')]
    values += [epoch, moment, epoch, moment, data.get('address'), data.get('free', 0)]

    return execute_query(sql, tuple(values))

def get_museums(search: str = None, free: int = None):
    """List museums, newest ``create_time`` first, with optional filters.

    :param search: when truthy, keep rows whose name or brief contains it
        (SQL LIKE, so ``%`` / ``_`` inside the term act as wildcards).
    :param free: when not None, keep rows whose ``free`` column equals it.
    :return: rows from ``execute_query``.
    """
    parts = ["SELECT * FROM mesum_overview WHERE 1=1"]
    args = []

    if search:
        pattern = f"%{search}%"
        parts.append(" AND (name LIKE %s OR brief LIKE %s)")
        args += [pattern, pattern]

    if free is not None:
        parts.append(" AND free = %s")
        args.append(free)

    parts.append(" ORDER BY create_time DESC")
    return execute_query("".join(parts), tuple(args))

def update_museum(museum_id: int, data: dict):
    """Partially update a museum, touching only the fields present in ``data``.

    Recognized keys are name, brief, photo_url, longitude, latitude, category,
    address and free; other keys are ignored. update_time (epoch seconds) and
    update_date (datetime) are always refreshed.

    :param museum_id: primary key of the ``mesum_overview`` row.
    :param data: new values keyed by column name.
    :return: the row after the update, via ``get_museum_by_id``.
    :raises ValueError: if ``data`` contains none of the recognized keys.
    :raises HTTPException: 404 if no museum with ``museum_id`` exists.
    """
    allowed_fields = {
        'name': 'name = %s',
        'brief': 'brief = %s',
        'photo_url': 'photo_url = %s',
        'longitude': 'longitude = %s',
        'latitude': 'latitude = %s',
        'category': 'category = %s',
        'address': 'address = %s',
        'free': 'free = %s'
    }

    update_fields = []
    params = []

    now = int(datetime.now().timestamp())
    update_date = datetime.fromtimestamp(now)

    # Collect only the whitelisted fields the caller supplied.
    for field, sql_part in allowed_fields.items():
        if field in data:
            update_fields.append(sql_part)
            params.append(data[field])

    # At least one business field is required (timestamps alone don't count).
    if not update_fields:
        raise ValueError("未提供有效更新字段")

    update_fields.extend(['update_time = %s', 'update_date = %s'])
    params.extend([now, update_date])

    # Column names come only from the whitelist above, so interpolating the SET
    # clause is safe; the values are still bound by the driver.
    set_clause = ", ".join(update_fields)
    sql = f"""
        UPDATE mesum_overview
        SET {set_clause}
        WHERE id = %s
    """
    params.append(museum_id)

    rowcount = execute_query(sql, tuple(params))

    if rowcount == 0:
        # MySQL reports rows actually *changed* unless the client enables
        # CLIENT.FOUND_ROWS. So 0 can also mean the row exists but already held
        # these values, e.g. the same update repeated within one second.
        # Only a missing row is a 404.
        exists = execute_query(
            "SELECT 1 FROM mesum_overview WHERE id = %s", (museum_id,)
        )
        if not exists:
            # NOTE(review): HTTPException is not imported in this module, so
            # this line raises NameError as written. Presumably
            # fastapi.HTTPException; add the import.
            raise HTTPException(status_code=404, detail="博物馆不存在")

    return get_museum_by_id(museum_id)

def get_museum_by_id(museum_id: int):
    """Return the ``mesum_overview`` row whose id is ``museum_id``, or None.

    :raises TypeError: if ``execute_query`` returns something other than a
        row sequence (it should not for a SELECT).
    """
    sql = "SELECT * FROM mesum_overview WHERE id = %s"
    rows = execute_query(sql, (museum_id,))
    # Use an explicit check rather than ``assert``, which ``python -O`` strips.
    # Accept a tuple too: the driver can return ``()`` for an empty result,
    # and a missing museum must yield None, not an error.
    if not isinstance(rows, (list, tuple)):
        raise TypeError("Unexpected return type")
    return rows[0] if rows else None

def delete_museum(museum_id: int):
    """Delete the museum whose id is ``museum_id``.

    :return: ``{"message": "Deleted <n> museums"}`` where n is the deleted-row count.
    :raises HTTPException: 404 ("Delete failed") when no row was deleted.
    """
    deleted = execute_query("DELETE FROM mesum_overview WHERE id = %s", (museum_id,))
    if deleted == 0:
        # NOTE(review): HTTPException is not imported in this module, so this
        # line raises NameError as written. Presumably fastapi.HTTPException.
        raise HTTPException(404, "Delete failed")
    return {"message": f"Deleted {deleted} museums"}

# Create a user-to-museum authorization record
def create_users_museum(data: dict):
    """Insert a row into ``rag_flow.users_museum``.

    ``data['user_id']`` and ``data['museum_id']`` are required (KeyError
    otherwise). The create and update timestamps are set to "now", as an
    integer epoch and as a datetime.

    :return: ``execute_query``'s result for the INSERT.
    """
    sql = """
        INSERT INTO rag_flow.users_museum
        (user_id, museum_id, create_time, create_date, update_time, update_date)
        VALUES
        (%s, %s, %s, %s, %s, %s)
    """

    epoch = int(datetime.now().timestamp())
    moment = datetime.fromtimestamp(epoch)
    row = (data['user_id'], data['museum_id'], epoch, moment, epoch, moment)

    return execute_query(sql, row)

# Query authorization records (multiple optional filters)
def get_users_museums(user_id: str = None, museum_id: int = None, get_all: bool = False):
    """List ``rag_flow.users_museum`` rows, newest ``create_time`` first.

    :param user_id: when truthy (and ``get_all`` is False), filter on user_id.
    :param museum_id: when not None (and ``get_all`` is False), filter on museum_id.
    :param get_all: when True, ignore both filters and return every row.
    :return: rows from ``execute_query``.
    """
    conditions = []
    args = []

    if not get_all:
        if user_id:
            conditions.append(" AND user_id = %s")
            args.append(user_id)
        if museum_id is not None:
            conditions.append(" AND museum_id = %s")
            args.append(museum_id)

    sql = (
        "SELECT * FROM rag_flow.users_museum WHERE 1=1"
        + "".join(conditions)
        + " ORDER BY create_time DESC"
    )
    return execute_query(sql, tuple(args))

# Fetch every authorization record belonging to one user
def get_users_museums_by_user_id(user_id: str):
    """Return all ``rag_flow.users_museum`` rows for ``user_id``.

    There may be several rows, or none.
    """
    return execute_query(
        "SELECT * FROM rag_flow.users_museum WHERE user_id = %s",
        (user_id,),
    )

# Fetch one authorization record by primary key
def get_users_museum_by_id(record_id: int):
    """Return the ``rag_flow.users_museum`` row whose id is ``record_id``, or None.

    ``update_users_museum`` returns its result through this lookup.
    """
    rows = execute_query(
        "SELECT * FROM rag_flow.users_museum WHERE id = %s", (record_id,)
    )
    return rows[0] if rows else None


# Update an authorization record
def update_users_museum(id: int, data: dict):
    """Partially update a user-to-museum authorization record.

    Only ``user_id`` and ``museum_id`` from ``data`` are applied; other keys are
    ignored. update_time (epoch seconds) and update_date (datetime) are always
    refreshed.

    :param id: primary key of the ``rag_flow.users_museum`` row.
    :param data: new values keyed by column name.
    :return: the row after the update, via ``get_users_museum_by_id``.
    :raises ValueError: if ``data`` has neither ``user_id`` nor ``museum_id``.
    :raises HTTPException: 404 if no record with this id exists.
    """
    allowed_fields = {
        'user_id': 'user_id = %s',
        'museum_id': 'museum_id = %s'
    }

    update_fields = []
    params = []

    now = int(datetime.now().timestamp())
    update_date = datetime.fromtimestamp(now)

    # Collect only the whitelisted fields the caller supplied.
    for field, sql_part in allowed_fields.items():
        if field in data:
            update_fields.append(sql_part)
            params.append(data[field])

    # At least one business field is required.
    if not update_fields:
        raise ValueError("未提供有效更新字段")

    update_fields.extend(['update_time = %s', 'update_date = %s'])
    params.extend([now, update_date])

    # Column names come only from the whitelist, so the f-string is safe.
    set_clause = ", ".join(update_fields)
    sql = f"""
        UPDATE rag_flow.users_museum
        SET {set_clause}
        WHERE id = %s
    """
    params.append(id)

    rowcount = execute_query(sql, tuple(params))

    if rowcount == 0:
        # rowcount counts *changed* rows (unless CLIENT.FOUND_ROWS is set), so
        # 0 may just mean "already had these values". Only a missing record
        # is a 404.
        exists = execute_query(
            "SELECT 1 FROM rag_flow.users_museum WHERE id = %s", (id,)
        )
        if not exists:
            # NOTE(review): HTTPException is not imported in this module, so
            # this line raises NameError as written. Presumably
            # fastapi.HTTPException.
            raise HTTPException(status_code=404, detail="授权记录不存在")

    return get_users_museum_by_id(id)

# Delete an authorization record
def delete_users_museum(id: int):
    """Delete the ``rag_flow.users_museum`` row whose primary key is ``id``.

    Always returns the same confirmation dict. Whether a row was actually
    deleted is not checked.
    """
    statement = "DELETE FROM rag_flow.users_museum WHERE id = %s"
    execute_query(statement, (id,))
    return {"message": "User museum authorization deleted"}

# Check authorization for many museums at once
def check_auth_batch(user_id: str, museum_ids: list):
    """Return the museum ids from ``museum_ids`` that ``user_id`` is authorized for.

    :param user_id: user to check.
    :param museum_ids: ids to check. Any iterable (list, tuple, set, generator)
        or None is accepted.
    :return: list of ``museum_id`` values from matching
        ``rag_flow.users_museum`` rows, one per row, in the order the database
        returns them. Empty or None input returns ``[]`` without querying.
    """
    # Materialize once so tuples, sets and generators all work, and so an
    # exhausted or empty generator is detected before building SQL.
    ids = list(museum_ids or ())
    if not ids:
        return []

    # One placeholder per id; the values themselves are bound by the driver.
    placeholders = ','.join(['%s'] * len(ids))
    sql = f"""
        SELECT museum_id
        FROM rag_flow.users_museum
        WHERE user_id = %s
        AND museum_id IN ({placeholders})
    """

    params = [user_id, *ids]
    result = execute_query(sql, params)
    return [item['museum_id'] for item in result]

from datetime import datetime
# Create a user
def create_user(data: dict):
    """Insert one row into ``rag_flow.users_info``.

    ``data['user_id']`` is required (KeyError otherwise). openid, phone, email,
    token and last_login_time default to None. balance defaults to 0 and
    status to 1 (presumably "active"; confirm against the schema). The create
    and update timestamps are set to "now", as an integer epoch and as a
    datetime.

    :return: ``execute_query``'s result for the INSERT (``cursor.lastrowid``).
    """
    # 12 columns -> 12 placeholders. The note lives here in Python rather than
    # inside the SQL text.
    sql = """
        INSERT INTO rag_flow.users_info
        (user_id, openid, phone, email, token, balance, status,
        last_login_time, create_time, create_date, update_time, update_date)
        VALUES
        (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    """

    now = int(datetime.now().timestamp())
    params = (
        data['user_id'],
        data.get('openid'),
        data.get('phone'),
        data.get('email'),
        data.get('token'),
        data.get('balance', 0),  # default balance 0
        data.get('status', 1),  # default status 1
        data.get('last_login_time'),
        now,
        datetime.fromtimestamp(now),
        now,
        datetime.fromtimestamp(now)
    )

    return execute_query(sql, params)

# Query users (multiple optional filters)
def get_users(status: int = None, email: str = None, phone: str = None):
    """List ``rag_flow.users_info`` rows, newest ``create_time`` first.

    :param status: when not None, filter on status (so 0 is a valid filter).
    :param email: when truthy, filter on exact email.
    :param phone: when truthy, filter on exact phone.
    :return: rows from ``execute_query``.
    """
    # (SQL fragment, bound value, whether the filter applies), in SQL order.
    filters = (
        (" AND status = %s", status, status is not None),
        (" AND email = %s", email, bool(email)),
        (" AND phone = %s", phone, bool(phone)),
    )

    sql = "SELECT * FROM rag_flow.users_info WHERE 1=1"
    args = []
    for fragment, value, active in filters:
        if active:
            sql += fragment
            args.append(value)
    sql += " ORDER BY create_time DESC"

    return execute_query(sql, tuple(args))

# Look up one user by user_id
def get_user_by_id(user_id: str):
    """Return the first ``rag_flow.users_info`` row for ``user_id``, or None."""
    rows = execute_query(
        "SELECT * FROM rag_flow.users_info WHERE user_id = %s", (user_id,)
    )
    if not rows:
        return None
    return rows[0]

# Update user information
def update_user(user_id: str, data: dict):
    """Partially update a user, touching only the fields present in ``data``.

    Recognized keys are phone, email, token, balance, status and
    last_login_time; other keys are ignored. update_time (epoch seconds) and
    update_date (datetime) are always refreshed.

    :param user_id: ``user_id`` of the ``rag_flow.users_info`` row.
    :param data: new values keyed by column name.
    :return: the row after the update, via ``get_user_by_id``.
    :raises ValueError: if ``data`` contains none of the recognized keys.
    :raises HTTPException: 404 if no user with ``user_id`` exists.
    """
    # Whitelist of updatable columns
    allowed_fields = {
        'phone': 'phone = %s',
        'email': 'email = %s',
        'token': 'token = %s',
        'balance': 'balance = %s',
        'status': 'status = %s',
        'last_login_time': 'last_login_time = %s'
    }

    update_fields = []
    params = []

    now = int(datetime.now().timestamp())
    update_date = datetime.fromtimestamp(now)

    # Collect only the whitelisted fields the caller supplied.
    for field, sql_part in allowed_fields.items():
        if field in data:
            update_fields.append(sql_part)
            params.append(data[field])

    # At least one field besides the automatic timestamps is required.
    if not update_fields:
        raise ValueError("没有提供有效更新字段")

    update_fields.extend(['update_time = %s', 'update_date = %s'])
    params.extend([now, update_date])

    # Column names come only from the whitelist, so the f-string is safe.
    set_clause = ", ".join(update_fields)
    sql = f"""
        UPDATE rag_flow.users_info
        SET {set_clause}
        WHERE user_id = %s
    """
    params.append(user_id)  # bound to the WHERE placeholder

    rowcount = execute_query(sql, tuple(params))

    if rowcount == 0:
        # rowcount counts *changed* rows (unless CLIENT.FOUND_ROWS is set), so
        # 0 may just mean "already had these values". Only a missing user is
        # a 404.
        exists = execute_query(
            "SELECT 1 FROM rag_flow.users_info WHERE user_id = %s", (user_id,)
        )
        if not exists:
            # NOTE(review): HTTPException is not imported in this module, so
            # this line raises NameError as written. Presumably
            # fastapi.HTTPException.
            raise HTTPException(status_code=404, detail="用户不存在")

    return get_user_by_id(user_id)

# Delete a user
def delete_user(user_id: str):
    """Delete the ``rag_flow.users_info`` row(s) for ``user_id``.

    Always returns the same confirmation dict. Whether a row was actually
    deleted is not checked.
    """
    statement = "DELETE FROM rag_flow.users_info WHERE user_id = %s"
    execute_query(statement, (user_id,))
    return {"message": "User deleted"}

# Check whether an openid is already registered
def check_openid_exists(openid: str):
    """Return True if at least one ``rag_flow.users_info`` row has this openid."""
    matches = execute_query(
        "SELECT 1 FROM rag_flow.users_info WHERE openid = %s", (openid,)
    )
    return True if matches else False

# Adjust a user's balance
def update_balance(user_id: str, amount: int):
    """Add ``amount`` to the user's balance and refresh the update timestamps.

    The increment is done in SQL (``balance = balance + %s``), not as a
    read-modify-write in Python. No sign or limit check is done here.

    :return: the user row afterwards, via ``get_user_by_id`` (None if absent).
    """
    sql = """
        UPDATE rag_flow.users_info
        SET balance = balance + %s,
        update_time = %s,
        update_date = %s
        WHERE user_id = %s
    """

    epoch = int(datetime.now().timestamp())
    execute_query(sql, (amount, epoch, datetime.fromtimestamp(epoch), user_id))

    return get_user_by_id(user_id)

# Record a login
def update_login_info(user_id: str, token: str):
    """Store ``token`` and stamp last_login_time/update_time/update_date with "now".

    last_login_time and update_time get the integer epoch; update_date gets
    the matching datetime.

    :return: the user row afterwards, via ``get_user_by_id`` (None if absent).
    """
    sql = """
        UPDATE rag_flow.users_info
        SET token = %s,
        last_login_time = %s,
        update_time = %s,
        update_date = %s
        WHERE user_id = %s
    """

    epoch = int(datetime.now().timestamp())
    bound = (token, epoch, epoch, datetime.fromtimestamp(epoch), user_id)
    execute_query(sql, bound)

    return get_user_by_id(user_id)