#!/usr/bin/env python3 """ 数据库基类模块 提供统一的数据库连接管理和上下文管理器 """ import os import sqlite3 from contextlib import contextmanager from typing import Generator, Optional, Any from pathlib import Path from src.config import config from src.logging_config import get_logger logger = get_logger(__name__) class DatabaseConnectionError(Exception): """数据库连接错误""" pass class DatabaseBase: """数据库基类,提供统一的连接管理""" def __init__(self, db_path: Optional[str] = None): """ 初始化数据库基类 参数: db_path: 数据库文件路径,如果为None则使用默认配置 """ self.db_path = db_path or config.DATABASE_PATH self._connection: Optional[sqlite3.Connection] = None self._ensure_directory() def _ensure_directory(self): """确保数据库目录存在""" data_dir = os.path.dirname(self.db_path) if data_dir and not os.path.exists(data_dir): os.makedirs(data_dir) logger.info(f"创建数据库目录: {data_dir}") def _connect(self) -> sqlite3.Connection: """ 创建数据库连接 返回: sqlite3.Connection 对象 异常: DatabaseConnectionError: 连接失败时抛出 """ try: conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row logger.debug(f"数据库连接已建立: {self.db_path}") return conn except sqlite3.Error as e: error_msg = f"数据库连接失败: {self.db_path}, 错误: {e}" logger.error(error_msg) raise DatabaseConnectionError(error_msg) from e @contextmanager def get_connection(self) -> Generator[sqlite3.Connection, None, None]: """ 获取数据库连接的上下文管理器 使用示例: with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(...) 返回: 数据库连接对象 """ conn = None try: conn = self._connect() yield conn except sqlite3.Error as e: logger.error(f"数据库操作失败: {e}") raise finally: if conn: conn.close() logger.debug("数据库连接已关闭") def execute_query(self, query: str, params: tuple = ()) -> list: """ 执行查询并返回结果 参数: query: SQL查询语句 params: 查询参数 返回: 查询结果列表 """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(query, params) return [dict(row) for row in cursor.fetchall()] def execute_update(self, query: str, params: tuple = ()) -> int: """ 执行更新操作 参数: query: SQL更新语句 params: 更新参数 返回: 受影响的行数 """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(query, params) conn.commit() return cursor.rowcount def execute_many(self, query: str, params_list: list) -> int: """ 批量执行操作 参数: query: SQL语句 params_list: 参数列表 返回: 受影响的总行数 """ with self.get_connection() as conn: cursor = conn.cursor() cursor.executemany(query, params_list) conn.commit() return cursor.rowcount def table_exists(self, table_name: str) -> bool: """ 检查表是否存在 参数: table_name: 表名 返回: 表是否存在 """ query = """ SELECT name FROM sqlite_master WHERE type='table' AND name=? """ result = self.execute_query(query, (table_name,)) return len(result) > 0 def get_table_info(self, table_name: str) -> list: """ 获取表结构信息 参数: table_name: 表名 返回: 表结构信息列表 """ with self.get_connection() as conn: cursor = conn.cursor() cursor.execute(f"PRAGMA table_info({table_name})") return [dict(row) for row in cursor.fetchall()] def vacuum(self): """执行数据库整理""" with self.get_connection() as conn: conn.execute("VACUUM") logger.info("数据库整理完成") def backup(self, backup_path: Optional[str] = None): """ 备份数据库 参数: backup_path: 备份文件路径,如果为None则使用默认路径 """ if backup_path is None: backup_dir = "backups" os.makedirs(backup_dir, exist_ok=True) timestamp = os.path.getmtime(self.db_path) from datetime import datetime dt = datetime.fromtimestamp(timestamp) backup_path = os.path.join( backup_dir, f"backup_{dt.strftime('%Y%m%d_%H%M%S')}.db" ) try: with self.get_connection() as src_conn: dest_conn = sqlite3.connect(backup_path) src_conn.backup(dest_conn) dest_conn.close() logger.info(f"数据库备份完成: {backup_path}") except sqlite3.Error as e: logger.error(f"数据库备份失败: {e}") raise # 全局数据库连接池(可选,用于高性能场景) class ConnectionPool: """简单的数据库连接池""" def __init__(self, db_path: str, max_connections: int = 5): self.db_path = db_path self.max_connections = max_connections self._connections: list[sqlite3.Connection] = [] self._in_use: set[sqlite3.Connection] = set() @contextmanager def get_connection(self) -> Generator[sqlite3.Connection, None, None]: """从连接池获取连接""" conn = None try: if self._connections: conn = self._connections.pop() elif len(self._in_use) < self.max_connections: conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row else: raise DatabaseConnectionError("连接池已满") self._in_use.add(conn) yield conn finally: if conn: self._in_use.remove(conn) self._connections.append(conn) if __name__ == '__main__': # 测试数据库基类 db = DatabaseBase() # 测试连接 with db.get_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT sqlite_version()") version = cursor.fetchone()[0] print(f"SQLite版本: {version}") # 测试查询 if db.table_exists("sqlite_master"): print("sqlite_master表存在") # 测试备份 try: db.backup("test_backup.db") print("备份测试完成") except Exception as e: print(f"备份测试失败: {e}")