2026-02-01 20:56:37 +08:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
"""
|
|
|
|
|
|
数据库基类模块
|
|
|
|
|
|
提供统一的数据库连接管理和上下文管理器
|
|
|
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
|
|
|
import sqlite3
|
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
from typing import Generator, Optional, Any
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
from src.config import config
|
|
|
|
|
|
from src.logging_config import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatabaseConnectionError(Exception):
|
|
|
|
|
|
"""数据库连接错误"""
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatabaseBase:
|
|
|
|
|
|
"""数据库基类,提供统一的连接管理"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, db_path: Optional[str] = None):
|
|
|
|
|
|
"""
|
|
|
|
|
|
初始化数据库基类
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
db_path: 数据库文件路径,如果为None则使用默认配置
|
|
|
|
|
|
"""
|
|
|
|
|
|
self.db_path = db_path or config.DATABASE_PATH
|
|
|
|
|
|
self._connection: Optional[sqlite3.Connection] = None
|
|
|
|
|
|
self._ensure_directory()
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_directory(self):
|
|
|
|
|
|
"""确保数据库目录存在"""
|
|
|
|
|
|
data_dir = os.path.dirname(self.db_path)
|
|
|
|
|
|
if data_dir and not os.path.exists(data_dir):
|
|
|
|
|
|
os.makedirs(data_dir)
|
|
|
|
|
|
logger.info(f"创建数据库目录: {data_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
def _connect(self) -> sqlite3.Connection:
|
|
|
|
|
|
"""
|
|
|
|
|
|
创建数据库连接
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
sqlite3.Connection 对象
|
|
|
|
|
|
|
|
|
|
|
|
异常:
|
|
|
|
|
|
DatabaseConnectionError: 连接失败时抛出
|
|
|
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
conn = sqlite3.connect(self.db_path)
|
|
|
|
|
|
conn.row_factory = sqlite3.Row
|
|
|
|
|
|
logger.debug(f"数据库连接已建立: {self.db_path}")
|
|
|
|
|
|
return conn
|
|
|
|
|
|
except sqlite3.Error as e:
|
|
|
|
|
|
error_msg = f"数据库连接失败: {self.db_path}, 错误: {e}"
|
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
|
raise DatabaseConnectionError(error_msg) from e
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
|
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取数据库连接的上下文管理器
|
|
|
|
|
|
|
|
|
|
|
|
使用示例:
|
|
|
|
|
|
with self.get_connection() as conn:
|
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
cursor.execute(...)
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
数据库连接对象
|
|
|
|
|
|
"""
|
|
|
|
|
|
conn = None
|
|
|
|
|
|
try:
|
|
|
|
|
|
conn = self._connect()
|
|
|
|
|
|
yield conn
|
|
|
|
|
|
except sqlite3.Error as e:
|
|
|
|
|
|
logger.error(f"数据库操作失败: {e}")
|
|
|
|
|
|
raise
|
|
|
|
|
|
finally:
|
|
|
|
|
|
if conn:
|
|
|
|
|
|
conn.close()
|
|
|
|
|
|
logger.debug("数据库连接已关闭")
|
|
|
|
|
|
|
|
|
|
|
|
def execute_query(self, query: str, params: tuple = ()) -> list:
|
|
|
|
|
|
"""
|
|
|
|
|
|
执行查询并返回结果
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
query: SQL查询语句
|
|
|
|
|
|
params: 查询参数
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
查询结果列表
|
|
|
|
|
|
"""
|
|
|
|
|
|
with self.get_connection() as conn:
|
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
cursor.execute(query, params)
|
|
|
|
|
|
return [dict(row) for row in cursor.fetchall()]
|
|
|
|
|
|
|
|
|
|
|
|
def execute_update(self, query: str, params: tuple = ()) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
执行更新操作
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
query: SQL更新语句
|
|
|
|
|
|
params: 更新参数
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
受影响的行数
|
|
|
|
|
|
"""
|
|
|
|
|
|
with self.get_connection() as conn:
|
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
cursor.execute(query, params)
|
|
|
|
|
|
conn.commit()
|
|
|
|
|
|
return cursor.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
def execute_many(self, query: str, params_list: list) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
批量执行操作
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
query: SQL语句
|
|
|
|
|
|
params_list: 参数列表
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
受影响的总行数
|
|
|
|
|
|
"""
|
|
|
|
|
|
with self.get_connection() as conn:
|
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
cursor.executemany(query, params_list)
|
|
|
|
|
|
conn.commit()
|
|
|
|
|
|
return cursor.rowcount
|
|
|
|
|
|
|
|
|
|
|
|
def table_exists(self, table_name: str) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
检查表是否存在
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
table_name: 表名
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
表是否存在
|
|
|
|
|
|
"""
|
|
|
|
|
|
query = """
|
|
|
|
|
|
SELECT name FROM sqlite_master
|
|
|
|
|
|
WHERE type='table' AND name=?
|
|
|
|
|
|
"""
|
|
|
|
|
|
result = self.execute_query(query, (table_name,))
|
|
|
|
|
|
return len(result) > 0
|
|
|
|
|
|
|
|
|
|
|
|
def get_table_info(self, table_name: str) -> list:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取表结构信息
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
table_name: 表名
|
|
|
|
|
|
|
|
|
|
|
|
返回:
|
|
|
|
|
|
表结构信息列表
|
|
|
|
|
|
"""
|
|
|
|
|
|
with self.get_connection() as conn:
|
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
cursor.execute(f"PRAGMA table_info({table_name})")
|
|
|
|
|
|
return [dict(row) for row in cursor.fetchall()]
|
|
|
|
|
|
|
|
|
|
|
|
def vacuum(self):
|
|
|
|
|
|
"""执行数据库整理"""
|
|
|
|
|
|
with self.get_connection() as conn:
|
|
|
|
|
|
conn.execute("VACUUM")
|
|
|
|
|
|
logger.info("数据库整理完成")
|
|
|
|
|
|
|
|
|
|
|
|
def backup(self, backup_path: Optional[str] = None):
|
|
|
|
|
|
"""
|
|
|
|
|
|
备份数据库
|
|
|
|
|
|
|
|
|
|
|
|
参数:
|
|
|
|
|
|
backup_path: 备份文件路径,如果为None则使用默认路径
|
|
|
|
|
|
"""
|
|
|
|
|
|
if backup_path is None:
|
|
|
|
|
|
backup_dir = "backups"
|
|
|
|
|
|
os.makedirs(backup_dir, exist_ok=True)
|
|
|
|
|
|
timestamp = os.path.getmtime(self.db_path)
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
dt = datetime.fromtimestamp(timestamp)
|
|
|
|
|
|
backup_path = os.path.join(
|
|
|
|
|
|
backup_dir,
|
|
|
|
|
|
f"backup_{dt.strftime('%Y%m%d_%H%M%S')}.db"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
with self.get_connection() as src_conn:
|
|
|
|
|
|
dest_conn = sqlite3.connect(backup_path)
|
|
|
|
|
|
src_conn.backup(dest_conn)
|
|
|
|
|
|
dest_conn.close()
|
|
|
|
|
|
logger.info(f"数据库备份完成: {backup_path}")
|
|
|
|
|
|
except sqlite3.Error as e:
|
|
|
|
|
|
logger.error(f"数据库备份失败: {e}")
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 全局数据库连接池(可选,用于高性能场景)
|
|
|
|
|
|
class ConnectionPool:
|
|
|
|
|
|
"""简单的数据库连接池"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, db_path: str, max_connections: int = 5):
|
|
|
|
|
|
self.db_path = db_path
|
|
|
|
|
|
self.max_connections = max_connections
|
|
|
|
|
|
self._connections: list[sqlite3.Connection] = []
|
|
|
|
|
|
self._in_use: set[sqlite3.Connection] = set()
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
|
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
|
|
|
|
|
"""从连接池获取连接"""
|
|
|
|
|
|
conn = None
|
|
|
|
|
|
try:
|
|
|
|
|
|
if self._connections:
|
|
|
|
|
|
conn = self._connections.pop()
|
|
|
|
|
|
elif len(self._in_use) < self.max_connections:
|
|
|
|
|
|
conn = sqlite3.connect(self.db_path)
|
|
|
|
|
|
conn.row_factory = sqlite3.Row
|
|
|
|
|
|
else:
|
|
|
|
|
|
raise DatabaseConnectionError("连接池已满")
|
|
|
|
|
|
|
|
|
|
|
|
self._in_use.add(conn)
|
|
|
|
|
|
yield conn
|
|
|
|
|
|
finally:
|
|
|
|
|
|
if conn:
|
|
|
|
|
|
self._in_use.remove(conn)
|
|
|
|
|
|
self._connections.append(conn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
# 测试数据库基类
|
|
|
|
|
|
db = DatabaseBase()
|
|
|
|
|
|
|
|
|
|
|
|
# 测试连接
|
|
|
|
|
|
with db.get_connection() as conn:
|
|
|
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
cursor.execute("SELECT sqlite_version()")
|
|
|
|
|
|
version = cursor.fetchone()[0]
|
|
|
|
|
|
print(f"SQLite版本: {version}")
|
|
|
|
|
|
|
|
|
|
|
|
# 测试查询
|
|
|
|
|
|
if db.table_exists("sqlite_master"):
|
|
|
|
|
|
print("sqlite_master表存在")
|
|
|
|
|
|
|
|
|
|
|
|
# 测试备份
|
|
|
|
|
|
try:
|
|
|
|
|
|
db.backup("test_backup.db")
|
|
|
|
|
|
print("备份测试完成")
|
|
|
|
|
|
except Exception as e:
|
2025-12-31 02:04:16 +08:00
|
|
|
|
print(f"备份测试失败: {e}")
|