mirror of
https://devops.liangqichi.top/qichi.liang/Orbitin.git
synced 2026-02-10 15:41:31 +08:00
n1
This commit is contained in:
@@ -1,257 +1,257 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据库基类模块
|
||||
提供统一的数据库连接管理和上下文管理器
|
||||
"""
|
||||
import os
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional, Any
|
||||
from pathlib import Path
|
||||
|
||||
from src.config import config
|
||||
from src.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConnectionError(Exception):
|
||||
"""数据库连接错误"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseBase:
|
||||
"""数据库基类,提供统一的连接管理"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
"""
|
||||
初始化数据库基类
|
||||
|
||||
参数:
|
||||
db_path: 数据库文件路径,如果为None则使用默认配置
|
||||
"""
|
||||
self.db_path = db_path or config.DATABASE_PATH
|
||||
self._connection: Optional[sqlite3.Connection] = None
|
||||
self._ensure_directory()
|
||||
|
||||
def _ensure_directory(self):
|
||||
"""确保数据库目录存在"""
|
||||
data_dir = os.path.dirname(self.db_path)
|
||||
if data_dir and not os.path.exists(data_dir):
|
||||
os.makedirs(data_dir)
|
||||
logger.info(f"创建数据库目录: {data_dir}")
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
"""
|
||||
创建数据库连接
|
||||
|
||||
返回:
|
||||
sqlite3.Connection 对象
|
||||
|
||||
异常:
|
||||
DatabaseConnectionError: 连接失败时抛出
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
logger.debug(f"数据库连接已建立: {self.db_path}")
|
||||
return conn
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"数据库连接失败: {self.db_path}, 错误: {e}"
|
||||
logger.error(error_msg)
|
||||
raise DatabaseConnectionError(error_msg) from e
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""
|
||||
获取数据库连接的上下文管理器
|
||||
|
||||
使用示例:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(...)
|
||||
|
||||
返回:
|
||||
数据库连接对象
|
||||
"""
|
||||
conn = None
|
||||
try:
|
||||
conn = self._connect()
|
||||
yield conn
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"数据库操作失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
logger.debug("数据库连接已关闭")
|
||||
|
||||
def execute_query(self, query: str, params: tuple = ()) -> list:
|
||||
"""
|
||||
执行查询并返回结果
|
||||
|
||||
参数:
|
||||
query: SQL查询语句
|
||||
params: 查询参数
|
||||
|
||||
返回:
|
||||
查询结果列表
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def execute_update(self, query: str, params: tuple = ()) -> int:
|
||||
"""
|
||||
执行更新操作
|
||||
|
||||
参数:
|
||||
query: SQL更新语句
|
||||
params: 更新参数
|
||||
|
||||
返回:
|
||||
受影响的行数
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
def execute_many(self, query: str, params_list: list) -> int:
|
||||
"""
|
||||
批量执行操作
|
||||
|
||||
参数:
|
||||
query: SQL语句
|
||||
params_list: 参数列表
|
||||
|
||||
返回:
|
||||
受影响的总行数
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.executemany(query, params_list)
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""
|
||||
检查表是否存在
|
||||
|
||||
参数:
|
||||
table_name: 表名
|
||||
|
||||
返回:
|
||||
表是否存在
|
||||
"""
|
||||
query = """
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name=?
|
||||
"""
|
||||
result = self.execute_query(query, (table_name,))
|
||||
return len(result) > 0
|
||||
|
||||
def get_table_info(self, table_name: str) -> list:
|
||||
"""
|
||||
获取表结构信息
|
||||
|
||||
参数:
|
||||
table_name: 表名
|
||||
|
||||
返回:
|
||||
表结构信息列表
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def vacuum(self):
|
||||
"""执行数据库整理"""
|
||||
with self.get_connection() as conn:
|
||||
conn.execute("VACUUM")
|
||||
logger.info("数据库整理完成")
|
||||
|
||||
def backup(self, backup_path: Optional[str] = None):
|
||||
"""
|
||||
备份数据库
|
||||
|
||||
参数:
|
||||
backup_path: 备份文件路径,如果为None则使用默认路径
|
||||
"""
|
||||
if backup_path is None:
|
||||
backup_dir = "backups"
|
||||
os.makedirs(backup_dir, exist_ok=True)
|
||||
timestamp = os.path.getmtime(self.db_path)
|
||||
from datetime import datetime
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
backup_path = os.path.join(
|
||||
backup_dir,
|
||||
f"backup_{dt.strftime('%Y%m%d_%H%M%S')}.db"
|
||||
)
|
||||
|
||||
try:
|
||||
with self.get_connection() as src_conn:
|
||||
dest_conn = sqlite3.connect(backup_path)
|
||||
src_conn.backup(dest_conn)
|
||||
dest_conn.close()
|
||||
logger.info(f"数据库备份完成: {backup_path}")
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"数据库备份失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# 全局数据库连接池(可选,用于高性能场景)
|
||||
class ConnectionPool:
|
||||
"""简单的数据库连接池"""
|
||||
|
||||
def __init__(self, db_path: str, max_connections: int = 5):
|
||||
self.db_path = db_path
|
||||
self.max_connections = max_connections
|
||||
self._connections: list[sqlite3.Connection] = []
|
||||
self._in_use: set[sqlite3.Connection] = set()
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""从连接池获取连接"""
|
||||
conn = None
|
||||
try:
|
||||
if self._connections:
|
||||
conn = self._connections.pop()
|
||||
elif len(self._in_use) < self.max_connections:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
else:
|
||||
raise DatabaseConnectionError("连接池已满")
|
||||
|
||||
self._in_use.add(conn)
|
||||
yield conn
|
||||
finally:
|
||||
if conn:
|
||||
self._in_use.remove(conn)
|
||||
self._connections.append(conn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 测试数据库基类
|
||||
db = DatabaseBase()
|
||||
|
||||
# 测试连接
|
||||
with db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT sqlite_version()")
|
||||
version = cursor.fetchone()[0]
|
||||
print(f"SQLite版本: {version}")
|
||||
|
||||
# 测试查询
|
||||
if db.table_exists("sqlite_master"):
|
||||
print("sqlite_master表存在")
|
||||
|
||||
# 测试备份
|
||||
try:
|
||||
db.backup("test_backup.db")
|
||||
print("备份测试完成")
|
||||
except Exception as e:
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据库基类模块
|
||||
提供统一的数据库连接管理和上下文管理器
|
||||
"""
|
||||
import os
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional, Any
|
||||
from pathlib import Path
|
||||
|
||||
from src.config import config
|
||||
from src.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConnectionError(Exception):
|
||||
"""数据库连接错误"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseBase:
|
||||
"""数据库基类,提供统一的连接管理"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
"""
|
||||
初始化数据库基类
|
||||
|
||||
参数:
|
||||
db_path: 数据库文件路径,如果为None则使用默认配置
|
||||
"""
|
||||
self.db_path = db_path or config.DATABASE_PATH
|
||||
self._connection: Optional[sqlite3.Connection] = None
|
||||
self._ensure_directory()
|
||||
|
||||
def _ensure_directory(self):
|
||||
"""确保数据库目录存在"""
|
||||
data_dir = os.path.dirname(self.db_path)
|
||||
if data_dir and not os.path.exists(data_dir):
|
||||
os.makedirs(data_dir)
|
||||
logger.info(f"创建数据库目录: {data_dir}")
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
"""
|
||||
创建数据库连接
|
||||
|
||||
返回:
|
||||
sqlite3.Connection 对象
|
||||
|
||||
异常:
|
||||
DatabaseConnectionError: 连接失败时抛出
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
logger.debug(f"数据库连接已建立: {self.db_path}")
|
||||
return conn
|
||||
except sqlite3.Error as e:
|
||||
error_msg = f"数据库连接失败: {self.db_path}, 错误: {e}"
|
||||
logger.error(error_msg)
|
||||
raise DatabaseConnectionError(error_msg) from e
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""
|
||||
获取数据库连接的上下文管理器
|
||||
|
||||
使用示例:
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(...)
|
||||
|
||||
返回:
|
||||
数据库连接对象
|
||||
"""
|
||||
conn = None
|
||||
try:
|
||||
conn = self._connect()
|
||||
yield conn
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"数据库操作失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
logger.debug("数据库连接已关闭")
|
||||
|
||||
def execute_query(self, query: str, params: tuple = ()) -> list:
|
||||
"""
|
||||
执行查询并返回结果
|
||||
|
||||
参数:
|
||||
query: SQL查询语句
|
||||
params: 查询参数
|
||||
|
||||
返回:
|
||||
查询结果列表
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def execute_update(self, query: str, params: tuple = ()) -> int:
|
||||
"""
|
||||
执行更新操作
|
||||
|
||||
参数:
|
||||
query: SQL更新语句
|
||||
params: 更新参数
|
||||
|
||||
返回:
|
||||
受影响的行数
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params)
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
def execute_many(self, query: str, params_list: list) -> int:
|
||||
"""
|
||||
批量执行操作
|
||||
|
||||
参数:
|
||||
query: SQL语句
|
||||
params_list: 参数列表
|
||||
|
||||
返回:
|
||||
受影响的总行数
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.executemany(query, params_list)
|
||||
conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""
|
||||
检查表是否存在
|
||||
|
||||
参数:
|
||||
table_name: 表名
|
||||
|
||||
返回:
|
||||
表是否存在
|
||||
"""
|
||||
query = """
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name=?
|
||||
"""
|
||||
result = self.execute_query(query, (table_name,))
|
||||
return len(result) > 0
|
||||
|
||||
def get_table_info(self, table_name: str) -> list:
|
||||
"""
|
||||
获取表结构信息
|
||||
|
||||
参数:
|
||||
table_name: 表名
|
||||
|
||||
返回:
|
||||
表结构信息列表
|
||||
"""
|
||||
with self.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def vacuum(self):
|
||||
"""执行数据库整理"""
|
||||
with self.get_connection() as conn:
|
||||
conn.execute("VACUUM")
|
||||
logger.info("数据库整理完成")
|
||||
|
||||
def backup(self, backup_path: Optional[str] = None):
|
||||
"""
|
||||
备份数据库
|
||||
|
||||
参数:
|
||||
backup_path: 备份文件路径,如果为None则使用默认路径
|
||||
"""
|
||||
if backup_path is None:
|
||||
backup_dir = "backups"
|
||||
os.makedirs(backup_dir, exist_ok=True)
|
||||
timestamp = os.path.getmtime(self.db_path)
|
||||
from datetime import datetime
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
backup_path = os.path.join(
|
||||
backup_dir,
|
||||
f"backup_{dt.strftime('%Y%m%d_%H%M%S')}.db"
|
||||
)
|
||||
|
||||
try:
|
||||
with self.get_connection() as src_conn:
|
||||
dest_conn = sqlite3.connect(backup_path)
|
||||
src_conn.backup(dest_conn)
|
||||
dest_conn.close()
|
||||
logger.info(f"数据库备份完成: {backup_path}")
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"数据库备份失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# 全局数据库连接池(可选,用于高性能场景)
|
||||
class ConnectionPool:
|
||||
"""简单的数据库连接池"""
|
||||
|
||||
def __init__(self, db_path: str, max_connections: int = 5):
|
||||
self.db_path = db_path
|
||||
self.max_connections = max_connections
|
||||
self._connections: list[sqlite3.Connection] = []
|
||||
self._in_use: set[sqlite3.Connection] = set()
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""从连接池获取连接"""
|
||||
conn = None
|
||||
try:
|
||||
if self._connections:
|
||||
conn = self._connections.pop()
|
||||
elif len(self._in_use) < self.max_connections:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
else:
|
||||
raise DatabaseConnectionError("连接池已满")
|
||||
|
||||
self._in_use.add(conn)
|
||||
yield conn
|
||||
finally:
|
||||
if conn:
|
||||
self._in_use.remove(conn)
|
||||
self._connections.append(conn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 测试数据库基类
|
||||
db = DatabaseBase()
|
||||
|
||||
# 测试连接
|
||||
with db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT sqlite_version()")
|
||||
version = cursor.fetchone()[0]
|
||||
print(f"SQLite版本: {version}")
|
||||
|
||||
# 测试查询
|
||||
if db.table_exists("sqlite_master"):
|
||||
print("sqlite_master表存在")
|
||||
|
||||
# 测试备份
|
||||
try:
|
||||
db.backup("test_backup.db")
|
||||
print("备份测试完成")
|
||||
except Exception as e:
|
||||
print(f"备份测试失败: {e}")
|
||||
Reference in New Issue
Block a user