Files
Orbitin/src/database/base.py
fuzhou 0a576b04cf n1
2026-02-01 20:56:37 +08:00

257 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
数据库基类模块
提供统一的数据库连接管理和上下文管理器
"""
import os
import sqlite3
from contextlib import contextmanager
from typing import Generator, Optional, Any
from pathlib import Path
from src.config import config
from src.logging_config import get_logger
logger = get_logger(__name__)
class DatabaseConnectionError(Exception):
"""数据库连接错误"""
pass
class DatabaseBase:
"""数据库基类,提供统一的连接管理"""
def __init__(self, db_path: Optional[str] = None):
"""
初始化数据库基类
参数:
db_path: 数据库文件路径如果为None则使用默认配置
"""
self.db_path = db_path or config.DATABASE_PATH
self._connection: Optional[sqlite3.Connection] = None
self._ensure_directory()
def _ensure_directory(self):
"""确保数据库目录存在"""
data_dir = os.path.dirname(self.db_path)
if data_dir and not os.path.exists(data_dir):
os.makedirs(data_dir)
logger.info(f"创建数据库目录: {data_dir}")
def _connect(self) -> sqlite3.Connection:
"""
创建数据库连接
返回:
sqlite3.Connection 对象
异常:
DatabaseConnectionError: 连接失败时抛出
"""
try:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
logger.debug(f"数据库连接已建立: {self.db_path}")
return conn
except sqlite3.Error as e:
error_msg = f"数据库连接失败: {self.db_path}, 错误: {e}"
logger.error(error_msg)
raise DatabaseConnectionError(error_msg) from e
@contextmanager
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
"""
获取数据库连接的上下文管理器
使用示例:
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(...)
返回:
数据库连接对象
"""
conn = None
try:
conn = self._connect()
yield conn
except sqlite3.Error as e:
logger.error(f"数据库操作失败: {e}")
raise
finally:
if conn:
conn.close()
logger.debug("数据库连接已关闭")
def execute_query(self, query: str, params: tuple = ()) -> list:
"""
执行查询并返回结果
参数:
query: SQL查询语句
params: 查询参数
返回:
查询结果列表
"""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def execute_update(self, query: str, params: tuple = ()) -> int:
"""
执行更新操作
参数:
query: SQL更新语句
params: 更新参数
返回:
受影响的行数
"""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
return cursor.rowcount
def execute_many(self, query: str, params_list: list) -> int:
"""
批量执行操作
参数:
query: SQL语句
params_list: 参数列表
返回:
受影响的总行数
"""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.executemany(query, params_list)
conn.commit()
return cursor.rowcount
def table_exists(self, table_name: str) -> bool:
"""
检查表是否存在
参数:
table_name: 表名
返回:
表是否存在
"""
query = """
SELECT name FROM sqlite_master
WHERE type='table' AND name=?
"""
result = self.execute_query(query, (table_name,))
return len(result) > 0
def get_table_info(self, table_name: str) -> list:
"""
获取表结构信息
参数:
table_name: 表名
返回:
表结构信息列表
"""
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(f"PRAGMA table_info({table_name})")
return [dict(row) for row in cursor.fetchall()]
def vacuum(self):
"""执行数据库整理"""
with self.get_connection() as conn:
conn.execute("VACUUM")
logger.info("数据库整理完成")
def backup(self, backup_path: Optional[str] = None):
"""
备份数据库
参数:
backup_path: 备份文件路径如果为None则使用默认路径
"""
if backup_path is None:
backup_dir = "backups"
os.makedirs(backup_dir, exist_ok=True)
timestamp = os.path.getmtime(self.db_path)
from datetime import datetime
dt = datetime.fromtimestamp(timestamp)
backup_path = os.path.join(
backup_dir,
f"backup_{dt.strftime('%Y%m%d_%H%M%S')}.db"
)
try:
with self.get_connection() as src_conn:
dest_conn = sqlite3.connect(backup_path)
src_conn.backup(dest_conn)
dest_conn.close()
logger.info(f"数据库备份完成: {backup_path}")
except sqlite3.Error as e:
logger.error(f"数据库备份失败: {e}")
raise
# 全局数据库连接池(可选,用于高性能场景)
class ConnectionPool:
"""简单的数据库连接池"""
def __init__(self, db_path: str, max_connections: int = 5):
self.db_path = db_path
self.max_connections = max_connections
self._connections: list[sqlite3.Connection] = []
self._in_use: set[sqlite3.Connection] = set()
@contextmanager
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
"""从连接池获取连接"""
conn = None
try:
if self._connections:
conn = self._connections.pop()
elif len(self._in_use) < self.max_connections:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
else:
raise DatabaseConnectionError("连接池已满")
self._in_use.add(conn)
yield conn
finally:
if conn:
self._in_use.remove(conn)
self._connections.append(conn)
if __name__ == '__main__':
# 测试数据库基类
db = DatabaseBase()
# 测试连接
with db.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT sqlite_version()")
version = cursor.fetchone()[0]
print(f"SQLite版本: {version}")
# 测试查询
if db.table_exists("sqlite_master"):
print("sqlite_master表存在")
# 测试备份
try:
db.backup("test_backup.db")
print("备份测试完成")
except Exception as e:
print(f"备份测试失败: {e}")