python连接PostgreSQL 数据库操作类优化
PostgreSQL 数据库操作类优化
以下是对你的 GPDB
类的优化建议,包括性能改进、错误处理和代码结构优化:
import pandas as pd
import psycopg2
import psycopg2.extras
from io import StringIO
import contextlib
from typing import Optional, List, Dict, Any, Union
class GPDB:
def __init__(self, dbname: str, user: str, password: str, host: str, port: str):
"""
初始化数据库连接参数
参数:
dbname: 数据库名
user: 用户名
password: 密码
host: 主机地址
port: 端口号
"""
self.dbname = dbname
self.user = user
self.password = password
self.host = host
self.port = port
self._connection_pool = None # 可以扩展为连接池
@contextlib.contextmanager
def _get_cursor(self, cursor_factory=None):
"""
上下文管理器,自动处理连接和游标的创建与关闭
参数:
cursor_factory: 游标工厂,默认为DictCursor
"""
conn = None
cursor = None
try:
conn = self.gp_connect()
cursor = conn.cursor(cursor_factory=cursor_factory or psycopg2.extras.DictCursor)
yield cursor
conn.commit()
except Exception as e:
if conn:
conn.rollback()
raise e
finally:
if cursor:
cursor.close()
if conn:
conn.close()
def gp_connect(self):
"""建立数据库连接"""
try:
return psycopg2.connect(
dbname=self.dbname,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
connect_timeout=10 # 添加连接超时
)
except psycopg2.Error as e:
raise ConnectionError(f"无法连接到Greenplum服务器: {e}")
def select_data(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
"""
执行查询并返回结果列表
参数:
sql: SQL查询语句
params: SQL参数
返回:
包含查询结果的字典列表
"""
with self._get_cursor() as cur:
cur.execute(sql, params or ())
return cur.fetchall()
def execute_sql(self, sql: str, params: Optional[tuple] = None) -> int:
"""
执行SQL语句(INSERT, UPDATE, DELETE等)
参数:
sql: SQL语句
params: SQL参数
返回:
影响的行数
"""
with self._get_cursor() as cur:
cur.execute(sql, params or ())
return cur.rowcount
def truncate_table(self, table_name: str, cascade: bool = False) -> None:
"""
清空表数据
参数:
table_name: 表名
cascade: 是否级联清空相关表
"""
sql = f"TRUNCATE TABLE {table_name}"
if cascade:
sql += " CASCADE"
self.execute_sql(sql)
def insert_df(self, table_name: str, df: pd.DataFrame, batch_size: int = 10000) -> int:
"""
使用批量插入方式将DataFrame数据写入数据库
参数:
table_name: 目标表名
df: 要插入的DataFrame
batch_size: 每批插入的行数
返回:
插入的总行数
"""
if df.empty:
return 0
columns = ', '.join(df.columns)
placeholders = ', '.join(['%s'] * len(df.columns))
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
total_rows = 0
with self._get_cursor() as cur:
# 分批插入数据
for i in range(0, len(df), batch_size):
batch = df.iloc[i:i + batch_size]
psycopg2.extras.execute_batch(cur, sql, batch.values.tolist())
total_rows += len(batch)
return total_rows
def read_df(self, sql: str, params: Optional[tuple] = None) -> pd.DataFrame:
"""
执行SQL查询并返回DataFrame
参数:
sql: SQL查询语句
params: SQL参数
返回:
包含查询结果的DataFrame
"""
with self._get_cursor() as cur:
cur.execute(sql, params or ())
columns = [desc[0] for desc in cur.description]
data = cur.fetchall()
return pd.DataFrame(data, columns=columns)
def copy_from_df(self, table_name: str, df: pd.DataFrame, sep: str = '\t', null: str = '\\N') -> None:
"""
使用COPY命令高效导入数据
参数:
table_name: 目标表名
df: 要导入的DataFrame
sep: 分隔符
null: NULL值的表示方式
"""
if df.empty:
return
with StringIO() as buffer:
df.to_csv(buffer, sep=sep, index=False, header=False, na_rep=null)
buffer.seek(0)
with self._get_cursor() as cur:
cur.copy_from(buffer, table_name, sep=sep, columns=df.columns.tolist(), null=null)
def copy_from_file(self, table_name: str, file_path: str, sep: str = '\t', columns: Optional[List[str]] = None) -> None:
"""
从文件导入数据到数据库表
参数:
table_name: 目标表名
file_path: 文件路径
sep: 分隔符
columns: 要导入的列名列表
"""
with open(file_path, 'r') as f:
with self._get_cursor() as cur:
cur.copy_from(f, table_name, sep=sep, columns=columns)
def upsert_df(self, table_name: str, df: pd.DataFrame, conflict_columns: List[str], update_columns: List[str]) -> int:
"""
执行UPSERT操作(存在则更新,不存在则插入)
参数:
table_name: 目标表名
df: 要插入/更新的DataFrame
conflict_columns: 冲突检测列
update_columns: 需要更新的列
返回:
影响的总行数
"""
if df.empty:
return 0
columns = ', '.join(df.columns)
placeholders = ', '.join(['%s'] * len(df.columns))
update_set = ', '.join([f"{col} = EXCLUDED.{col}" for col in update_columns])
sql = f"""
INSERT INTO {table_name} ({columns})
VALUES ({placeholders})
ON CONFLICT ({', '.join(conflict_columns)})
DO UPDATE SET {update_set}
"""
total_rows = 0
with self._get_cursor() as cur:
# 分批执行UPSERT
for i in range(0, len(df), 10000):
batch = df.iloc[i:i + 10000]
psycopg2.extras.execute_batch(cur, sql, batch.values.tolist())
total_rows += len(batch)
return total_rows
优化说明
-
类型提示:添加了类型提示,提高代码可读性和IDE支持
-
上下文管理器:使用
contextlib.contextmanager
创建上下文管理器,自动处理连接和事务 -
批量操作:
- 添加了批量插入和批量更新功能
- 默认分批处理大数据量,避免内存问题
-
错误处理:
- 更完善的错误处理和事务回滚
- 连接超时设置
-
新增功能:
- 添加了
upsert_df
方法实现存在则更新,不存在则插入 - 添加了通用
execute_sql
方法
- 添加了
-
性能优化:
- 使用
execute_batch
替代executemany
提高批量插入性能 - 改进了COPY命令的实现
- 使用
-
代码结构:
- 更清晰的文档字符串
- 更合理的参数命名
- 分离不同功能的方法
使用示例
# 初始化
db = GPDB(dbname="mydb", user="user", password="pass", host="localhost", port="5432")
# 查询数据
results = db.select_data("SELECT * FROM users WHERE age > %s", (30,))
# 读取为DataFrame
df = db.read_df("SELECT * FROM products")
# 插入DataFrame
db.insert_df("products", df)
# 高效导入大数据
db.copy_from_df("large_table", large_df)
# UPSERT操作
db.upsert_df("users", user_df, conflict_columns=["id"], update_columns=["name", "email"])
这个优化版本提供了更好的性能、更强的健壮性和更清晰的接口设计。