Files
ragflow_python/asr-monitor-test/app/database.py

534 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pymysql
from pymysql import Connection
from pymysql.err import OperationalError, InterfaceError
from contextlib import contextmanager
from config import DATABASE_CONFIG
from datetime import datetime,timedelta
import logging
from typing import Optional,Union
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
# Logging setup: configure the root logger at INFO and use a dedicated "DB" logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("DB")


def _log_retry(retry_state):
    """Warn about the attempt that just failed; used as the before_sleep hook."""
    logger.warning(
        f"Retrying ({retry_state.attempt_number}/2) due to: {retry_state.outcome.exception()}"
    )


# Retry policy (deliberately conservative); unpacked into tenacity's @retry.
RETRY_CONFIG = dict(
    stop=stop_after_attempt(3),  # three attempts in total, i.e. at most two retries
    wait=wait_fixed(1),  # fixed one-second pause between attempts
    retry=retry_if_exception_type((OperationalError, InterfaceError)),
    reraise=True,  # re-raise the original exception once attempts run out
    before_sleep=_log_retry,
)
@contextmanager
def get_connection(autocommit: bool = False):
    """Yield a freshly opened database connection and always close it.

    The connect call itself is retried according to RETRY_CONFIG. When
    ``autocommit`` is false the block is treated as one transaction: it is
    committed after the ``with`` body finishes normally and rolled back if
    the body (or the setup) raises. The exception is always re-raised.

    :param autocommit: value passed to ``conn.autocommit``; when true, no
        explicit commit or rollback is issued here.
    """

    @retry(**RETRY_CONFIG)
    def _open():
        # Log every failed attempt, then let tenacity decide whether to retry.
        try:
            return pymysql.connect(**DATABASE_CONFIG)
        except (OperationalError, InterfaceError) as exc:
            logger.error(f"Connection failed: {str(exc)}")
            raise

    transactional = not autocommit
    db = None
    try:
        db = _open()
        db.autocommit(autocommit)
        yield db
        if transactional:
            db.commit()  # single commit point for non-autocommit use
    except Exception:
        if db and transactional:
            db.rollback()
        raise
    finally:
        if db:
            db.close()
def get_query_type(query: str) -> str:
    """Classify a SQL statement by its leading keyword.

    Leading/trailing whitespace and letter case are ignored.

    :param query: SQL text to classify.
    :return: ``"SELECT"`` for statements starting with WITH or SELECT,
        ``"INSERT"`` for INSERT or REPLACE, ``"UPDATE"``, ``"DELETE"``,
        and ``"OTHER"`` for everything else (including an empty string).
    """
    prefix_table = (
        (("WITH", "SELECT"), "SELECT"),
        (("INSERT", "REPLACE"), "INSERT"),
        (("UPDATE",), "UPDATE"),
        (("DELETE",), "DELETE"),
    )
    normalized = query.strip().upper()
    for prefixes, label in prefix_table:
        if normalized.startswith(prefixes):
            return label
    return "OTHER"
def process_result(cursor, query: str) -> Union[int, list[dict[str, any]]]:
    """Pick the return value for an executed statement based on its type.

    NOTE(review): ``any`` in the return annotation is the builtin function,
    not ``typing.Any``; it is harmless at runtime but not a real type.

    :param cursor: cursor on which ``query`` has already been executed.
    :param query: the SQL text that was executed; classified with
        ``get_query_type``.
    :return: ``cursor.fetchall()`` for SELECT-type statements,
        ``cursor.lastrowid`` for INSERT-type statements, and
        ``cursor.rowcount`` for everything else.
    :raises ValueError: if an ``IndexError`` escapes while producing the result.
    """
    try:
        kind = get_query_type(query)
        if kind == "SELECT":
            return cursor.fetchall()
        if kind == "INSERT":
            return cursor.lastrowid  # id generated by the insert
        # UPDATE, DELETE and any other statement: number of affected rows.
        return cursor.rowcount
    except IndexError:
        raise ValueError("Invalid SQL query format")
@retry(**RETRY_CONFIG)
def execute_query(
    query: str,
    params: tuple | list | dict | None = None,
    *,
    connection: Optional[Connection] = None,
    autocommit: bool = False
):
    """Execute one SQL statement and return a result suited to its type.

    The whole call is retried according to RETRY_CONFIG (on
    ``OperationalError`` / ``InterfaceError``).

    :param query: SQL text with ``%s`` / ``%(name)s`` placeholders.
    :param params: values bound to the placeholders, or ``None``.
    :param connection: optional caller-owned connection. When given, it is
        used as-is and never closed here; otherwise a new connection is
        opened through ``get_connection`` for this single statement.
    :param autocommit: with a new connection, passed to ``get_connection``.
        With a caller-owned connection, triggers an explicit commit after
        any statement that does not start with ``SELECT``.
    :return: whatever ``process_result`` selects: fetched rows, the last
        insert id, or the affected-row count.
    :raises ValueError: if ``query`` is empty or whitespace only.
    :raises PermissionError: if the query text contains DROP, TRUNCATE or
        GRANT. This is a plain substring check, so identifiers that merely
        contain one of these words are rejected too.
    """
    # Input safety checks.
    if not query.strip():
        raise ValueError("Empty query")
    forbidden_keywords = ['DROP', 'TRUNCATE', 'GRANT']
    upper_query = query.upper()
    if any(kw in upper_query for kw in forbidden_keywords):
        raise PermissionError("Dangerous operation detected")

    if connection:  # caller-owned connection
        cursor = connection.cursor()
        try:
            cursor.execute(query, params)
            if autocommit and not query.strip().upper().startswith("SELECT"):
                connection.commit()
            # Fix: process_result requires the query text to classify the
            # statement; the previous call omitted it and raised TypeError.
            return process_result(cursor, query)
        finally:
            cursor.close()

    # No connection supplied: open one for this statement only.
    with get_connection(autocommit=autocommit) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            return process_result(cursor, query)
def create_museum(data: dict):
    """Insert one row into ``mesum_overview``.

    ``name`` and ``chat_id`` are required keys of ``data``; the other
    columns are read with ``dict.get`` (``free`` defaults to 0). Create and
    update timestamps are both set to the current time.

    :param data: column values for the new row.
    :return: the result of ``execute_query`` for the INSERT.
    :raises KeyError: if ``name`` or ``chat_id`` is missing.
    """
    insert_sql = (
        "\n"
        "INSERT INTO mesum_overview\n"
        "(name, brief, chat_id, photo_url, longitude, latitude, category,\n"
        "create_time, create_date, update_time, update_date, address, free)\n"
        "VALUES\n"
        "(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)\n"
    )
    # Epoch seconds and the matching datetime, shared by create_* and update_*.
    stamp = int(datetime.now().timestamp())
    stamp_dt = datetime.fromtimestamp(stamp)
    row = (
        data['name'],
        data.get('brief'),
        data['chat_id'],
        data.get('photo_url'),
        data.get('longitude'),
        data.get('latitude'),
        data.get('category'),
        stamp,
        stamp_dt,
        stamp,
        stamp_dt,
        data.get('address'),
        data.get('free', 0),
    )
    return execute_query(insert_sql, row)
def get_museums(search: str = None, free: int = None):
    """List rows of ``mesum_overview``, newest ``create_time`` first.

    :param search: when truthy, keep rows whose ``name`` or ``brief``
        contains this text (SQL ``LIKE %search%``).
    :param free: when not ``None``, keep rows whose ``free`` equals it.
    :return: the result of ``execute_query`` for the SELECT.
    """
    fragments = ["SELECT * FROM mesum_overview WHERE 1=1"]
    args = []
    if search:
        pattern = f"%{search}%"
        fragments.append(" AND (name LIKE %s OR brief LIKE %s)")
        args += [pattern, pattern]
    if free is not None:
        fragments.append(" AND free = %s")
        args.append(free)
    fragments.append(" ORDER BY create_time DESC")
    return execute_query("".join(fragments), tuple(args))
def update_museum(museum_id: int, data: dict):
    """Update only the whitelisted museum columns that appear in ``data``.

    ``update_time`` / ``update_date`` are always refreshed alongside them.

    :param museum_id: ``id`` of the row in ``mesum_overview``.
    :param data: new values; keys outside the whitelist are ignored.
    :return: the row as returned by ``get_museum_by_id`` after the update.
    :raises ValueError: if ``data`` contains no whitelisted column.
    :raises HTTPException: 404 when the UPDATE reports zero affected rows.
    """
    # NOTE(review): HTTPException is not imported in the visible part of this
    # module - confirm it is in scope before relying on the 404 path.
    updatable = (
        'name', 'brief', 'photo_url', 'longitude',
        'latitude', 'category', 'address', 'free',
    )
    stamp = int(datetime.now().timestamp())
    stamp_dt = datetime.fromtimestamp(stamp)

    chosen = [column for column in updatable if column in data]
    # At least one business column must be supplied.
    if not chosen:
        raise ValueError("未提供有效更新字段")

    assignments = [f"{column} = %s" for column in chosen]
    assignments += ['update_time = %s', 'update_date = %s']
    values = [data[column] for column in chosen]
    values += [stamp, stamp_dt, museum_id]

    set_clause = ", ".join(assignments)
    sql = (
        "\n"
        "UPDATE mesum_overview\n"
        f"SET {set_clause}\n"
        "WHERE id = %s\n"
    )
    affected = execute_query(sql, tuple(values))
    if affected == 0:
        raise HTTPException(status_code=404, detail="博物馆不存在")
    return get_museum_by_id(museum_id)
def get_museum_by_id(museum_id: int):
    """Fetch one museum row by primary key.

    :param museum_id: ``id`` of the row in ``mesum_overview``.
    :return: the first matching row, or ``None`` when nothing matches.
    """
    sql = "SELECT * FROM mesum_overview WHERE id = %s"
    rows = execute_query(sql, (museum_id,))
    # No isinstance assertion here: asserts are stripped under -O, and an
    # empty result set may come back as an empty tuple rather than a list,
    # which would turn "not found" into an AssertionError. A plain truthiness
    # check (as get_user_by_id does) covers both cases.
    return rows[0] if rows else None
def delete_museum(museum_id: int):
    """Delete one museum row by primary key.

    :param museum_id: ``id`` of the row in ``mesum_overview``.
    :return: a dict with a ``message`` stating how many rows were deleted.
    :raises HTTPException: 404 when the DELETE reports zero affected rows.
    """
    # NOTE(review): HTTPException is not imported in the visible part of this
    # module - confirm it is in scope before relying on the 404 path.
    deleted = execute_query(
        "DELETE FROM mesum_overview WHERE id = %s",
        (museum_id,),
    )
    if deleted == 0:
        raise HTTPException(404, "Delete failed")
    message = f"Deleted {deleted} museums"
    return {"message": message}
# Create an authorization record.
def create_users_museum(data: dict):
    """Insert one user-to-museum authorization row.

    :param data: must contain ``user_id`` and ``museum_id``.
    :return: the result of ``execute_query`` for the INSERT.
    :raises KeyError: if ``user_id`` or ``museum_id`` is missing.
    """
    insert_sql = (
        "\n"
        "INSERT INTO rag_flow.users_museum\n"
        "(user_id, museum_id, create_time, create_date, update_time, update_date)\n"
        "VALUES\n"
        "(%s, %s, %s, %s, %s, %s)\n"
    )
    # Create and update stamps start out identical.
    stamp = int(datetime.now().timestamp())
    stamp_dt = datetime.fromtimestamp(stamp)
    row = (
        data['user_id'],
        data['museum_id'],
        stamp,
        stamp_dt,
        stamp,
        stamp_dt,
    )
    return execute_query(insert_sql, row)
# Query authorization records (multiple optional filters).
def get_users_museums(user_id: str = None, museum_id: int = None, get_all: bool = False):
    """List authorization rows, newest ``create_time`` first.

    :param user_id: when truthy, keep rows for this user.
    :param museum_id: when not ``None``, keep rows for this museum.
    :param get_all: when true, both filters above are ignored.
    :return: the result of ``execute_query`` for the SELECT.
    """
    fragments = ["SELECT * FROM rag_flow.users_museum WHERE 1=1"]
    args = []
    filtering = not get_all
    if filtering and user_id:
        fragments.append(" AND user_id = %s")
        args.append(user_id)
    if filtering and museum_id is not None:
        fragments.append(" AND museum_id = %s")
        args.append(museum_id)
    fragments.append(" ORDER BY create_time DESC")
    return execute_query("".join(fragments), tuple(args))
# Fetch the authorization records of one user.
def get_users_museums_by_user_id(user_id: str):
    """Return every ``users_museum`` row whose ``user_id`` matches.

    :param user_id: user whose authorizations are looked up.
    :return: the result of ``execute_query`` for the SELECT (all matching
        rows, not just one).
    """
    return execute_query(
        "SELECT * FROM rag_flow.users_museum WHERE user_id = %s",
        (user_id,),
    )
# Update an authorization record.
def update_users_museum(id: int, data: dict):
    """Update ``user_id`` and/or ``museum_id`` of one authorization row.

    Only whitelisted keys present in ``data`` are written;
    ``update_time`` / ``update_date`` are always refreshed alongside them.

    :param id: primary key of the row in ``rag_flow.users_museum``.
    :param data: new values; keys outside the whitelist are ignored.
    :return: the updated row, or ``None`` if it can no longer be read.
    :raises ValueError: if ``data`` contains no whitelisted column.
    :raises HTTPException: 404 when the UPDATE reports zero affected rows.
    """
    # NOTE(review): HTTPException is not imported in the visible part of this
    # module - confirm it is in scope before relying on the 404 path.
    updatable = ('user_id', 'museum_id')
    chosen = [column for column in updatable if column in data]
    # At least one business column must be supplied.
    if not chosen:
        raise ValueError("未提供有效更新字段")

    now = int(datetime.now().timestamp())
    update_date = datetime.fromtimestamp(now)

    assignments = [f"{column} = %s" for column in chosen]
    assignments += ['update_time = %s', 'update_date = %s']
    params = [data[column] for column in chosen]
    params += [now, update_date, id]

    set_clause = ", ".join(assignments)
    sql = (
        "\n"
        "UPDATE rag_flow.users_museum\n"
        f"SET {set_clause}\n"
        "WHERE id = %s\n"
    )
    rowcount = execute_query(sql, tuple(params))
    if rowcount == 0:
        raise HTTPException(status_code=404, detail="授权记录不存在")

    # Re-read the row inline: the by-id getter this function used to call is
    # not defined in the visible part of this module, so the call failed with
    # NameError right after a successful update.
    rows = execute_query(
        "SELECT * FROM rag_flow.users_museum WHERE id = %s",
        (id,),
    )
    return rows[0] if rows else None
# Delete an authorization record.
def delete_users_museum(id: int):
    """Delete one authorization row by primary key.

    The affected-row count is not inspected, so the same message is
    returned whether or not a row was actually removed.

    :param id: primary key of the row in ``rag_flow.users_museum``.
    :return: a dict with a fixed confirmation ``message``.
    """
    delete_sql = "DELETE FROM rag_flow.users_museum WHERE id = %s"
    execute_query(delete_sql, (id,))
    confirmation = {"message": "User museum authorization deleted"}
    return confirmation
# Check authorization status for a batch of museums.
def check_auth_batch(user_id: str, museum_ids: list):
    """Return the subset of ``museum_ids`` the user is authorized for.

    :param user_id: user whose authorizations are checked.
    :param museum_ids: museum ids to test. Any iterable is accepted (list,
        tuple, set, generator); ``None`` or an empty iterable yields ``[]``
        without touching the database.
    :return: list of ``museum_id`` values found in ``rag_flow.users_museum``
        for this user.
    """
    # Materialise once so that non-list inputs work: the previous
    # ``[user_id] + museum_ids`` raised TypeError for anything but a list.
    ids = list(museum_ids or ())
    if not ids:
        return []

    placeholders = ','.join(['%s'] * len(ids))
    sql = (
        "\n"
        "SELECT museum_id\n"
        "FROM rag_flow.users_museum\n"
        "WHERE user_id = %s\n"
        f"AND museum_id IN ({placeholders})\n"
    )
    rows = execute_query(sql, (user_id, *ids))
    return [row['museum_id'] for row in rows]
from datetime import datetime
# Create a user.
def create_user(data: dict):
    """Insert one row into ``rag_flow.users_info``.

    ``user_id`` is the only required key of ``data``; the other columns are
    read with ``dict.get`` (``balance`` defaults to 0, ``status`` to 1).

    :param data: column values for the new user.
    :return: the result of ``execute_query`` for the INSERT.
    :raises KeyError: if ``user_id`` is missing.
    """
    insert_sql = (
        "\n"
        "INSERT INTO rag_flow.users_info\n"
        "(user_id, openid, phone, email, token, balance, status,\n"
        "last_login_time, create_time, create_date, update_time, update_date)\n"
        "VALUES\n"
        "(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) # 12个占位符\n"
    )
    # Create and update stamps start out identical.
    stamp = int(datetime.now().timestamp())
    stamp_dt = datetime.fromtimestamp(stamp)
    row = (
        data['user_id'],
        data.get('openid'),
        data.get('phone'),
        data.get('email'),
        data.get('token'),
        data.get('balance', 0),  # default balance: 0
        data.get('status', 1),  # default status: 1
        data.get('last_login_time'),
        stamp,
        stamp_dt,
        stamp,
        stamp_dt,
    )
    return execute_query(insert_sql, row)
# Query users (multiple optional filters).
def get_users(status: int = None, email: str = None, phone: str = None):
    """List rows of ``rag_flow.users_info``, newest ``create_time`` first.

    :param status: when not ``None``, keep users with this status.
    :param email: when truthy, keep users with exactly this email.
    :param phone: when truthy, keep users with exactly this phone.
    :return: the result of ``execute_query`` for the SELECT.
    """
    fragments = ["SELECT * FROM rag_flow.users_info WHERE 1=1"]
    args = []
    if status is not None:
        fragments.append(" AND status = %s")
        args.append(status)
    if email:
        fragments.append(" AND email = %s")
        args.append(email)
    if phone:
        fragments.append(" AND phone = %s")
        args.append(phone)
    fragments.append(" ORDER BY create_time DESC")
    return execute_query("".join(fragments), tuple(args))
# Fetch a user by user ID.
def get_user_by_id(user_id: str):
    """Fetch one row of ``rag_flow.users_info`` by ``user_id``.

    :param user_id: identifier to look up.
    :return: the first matching row, or ``None`` when nothing matches.
    """
    rows = execute_query(
        "SELECT * FROM rag_flow.users_info WHERE user_id = %s",
        (user_id,),
    )
    if rows:
        return rows[0]
    return None
# Update user information.
def update_user(user_id: str, data: dict):
    """Update only the whitelisted user columns that appear in ``data``.

    ``update_time`` / ``update_date`` are always refreshed alongside them.

    :param user_id: ``user_id`` of the row in ``rag_flow.users_info``.
    :param data: new values; keys outside the whitelist are ignored.
    :return: the row as returned by ``get_user_by_id`` after the update.
    :raises ValueError: if ``data`` contains no whitelisted column.
    :raises HTTPException: 404 when the UPDATE reports zero affected rows.
    """
    # NOTE(review): HTTPException is not imported in the visible part of this
    # module - confirm it is in scope before relying on the 404 path.
    # Whitelist of columns callers may change.
    updatable = (
        'phone', 'email', 'token',
        'balance', 'status', 'last_login_time',
    )
    stamp = int(datetime.now().timestamp())
    stamp_dt = datetime.fromtimestamp(stamp)

    chosen = [column for column in updatable if column in data]
    # At least one column besides the automatic timestamps is required.
    if not chosen:
        raise ValueError("没有提供有效更新字段")

    assignments = [f"{column} = %s" for column in chosen]
    assignments += ['update_time = %s', 'update_date = %s']
    # The user id goes last, matching the WHERE placeholder.
    values = [data[column] for column in chosen]
    values += [stamp, stamp_dt, user_id]

    set_clause = ", ".join(assignments)
    sql = (
        "\n"
        "UPDATE rag_flow.users_info\n"
        f"SET {set_clause}\n"
        "WHERE user_id = %s\n"
    )
    affected = execute_query(sql, tuple(values))
    if affected == 0:
        raise HTTPException(status_code=404, detail="用户不存在")
    return get_user_by_id(user_id)
# Delete a user.
def delete_user(user_id: str):
    """Delete one row of ``rag_flow.users_info`` by ``user_id``.

    The affected-row count is not inspected, so the same message is
    returned whether or not a row was actually removed.

    :param user_id: identifier of the user to delete.
    :return: a dict with a fixed confirmation ``message``.
    """
    delete_sql = "DELETE FROM rag_flow.users_info WHERE user_id = %s"
    execute_query(delete_sql, (user_id,))
    confirmation = {"message": "User deleted"}
    return confirmation
# Check whether an openid already exists.
def check_openid_exists(openid: str):
    """Tell whether any ``rag_flow.users_info`` row has this ``openid``.

    :param openid: value to look for.
    :return: ``True`` if the SELECT returned at least one row, else ``False``.
    """
    return bool(
        execute_query(
            "SELECT 1 FROM rag_flow.users_info WHERE openid = %s",
            (openid,),
        )
    )
# Update a user's balance.
def update_balance(user_id: str, amount: int):
    """Add ``amount`` to a user's balance and refresh the update stamps.

    The increment is done in SQL (``balance = balance + %s``). The
    affected-row count is not inspected.

    :param user_id: ``user_id`` of the row in ``rag_flow.users_info``.
    :param amount: value added to the current balance.
    :return: the row as returned by ``get_user_by_id`` after the update.
    """
    update_sql = (
        "\n"
        "UPDATE rag_flow.users_info\n"
        "SET balance = balance + %s,\n"
        "update_time = %s,\n"
        "update_date = %s\n"
        "WHERE user_id = %s\n"
    )
    stamp = int(datetime.now().timestamp())
    stamp_dt = datetime.fromtimestamp(stamp)
    execute_query(update_sql, (amount, stamp, stamp_dt, user_id))
    return get_user_by_id(user_id)
# Update login information.
def update_login_info(user_id: str, token: str):
    """Store a new token and stamp the login time for one user.

    ``last_login_time`` and ``update_time`` receive the same epoch-seconds
    value; ``update_date`` receives the matching datetime. The affected-row
    count is not inspected.

    :param user_id: ``user_id`` of the row in ``rag_flow.users_info``.
    :param token: token value to store.
    :return: the row as returned by ``get_user_by_id`` after the update.
    """
    update_sql = (
        "\n"
        "UPDATE rag_flow.users_info\n"
        "SET token = %s,\n"
        "last_login_time = %s,\n"
        "update_time = %s,\n"
        "update_date = %s\n"
        "WHERE user_id = %s\n"
    )
    stamp = int(datetime.now().timestamp())
    stamp_dt = datetime.fromtimestamp(stamp)
    execute_query(update_sql, (token, stamp, stamp, stamp_dt, user_id))
    return get_user_by_id(user_id)