当前位置: 首页 > news >正文

python连接PostgreSQL 数据库操作类优化

PostgreSQL 数据库操作类优化

以下是对你的 GPDB 类的优化建议,包括性能改进、错误处理和代码结构优化:

import pandas as pd
import psycopg2
import psycopg2.extras
from io import StringIO
import contextlib
from typing import Optional, List, Dict, Any, Union

class GPDB:
    def __init__(self, dbname: str, user: str, password: str, host: str, port: str):
        """
        初始化数据库连接参数
        
        参数:
            dbname: 数据库名
            user: 用户名
            password: 密码
            host: 主机地址
            port: 端口号
        """
        self.dbname = dbname
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self._connection_pool = None  # 可以扩展为连接池

    @contextlib.contextmanager
    def _get_cursor(self, cursor_factory=None):
        """
        上下文管理器,自动处理连接和游标的创建与关闭
        
        参数:
            cursor_factory: 游标工厂,默认为DictCursor
        """
        conn = None
        cursor = None
        try:
            conn = self.gp_connect()
            cursor = conn.cursor(cursor_factory=cursor_factory or psycopg2.extras.DictCursor)
            yield cursor
            conn.commit()
        except Exception as e:
            if conn:
                conn.rollback()
            raise e
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def gp_connect(self):
        """建立数据库连接"""
        try:
            return psycopg2.connect(
                dbname=self.dbname,
                user=self.user,
                password=self.password,
                host=self.host,
                port=self.port,
                connect_timeout=10  # 添加连接超时
            )
        except psycopg2.Error as e:
            raise ConnectionError(f"无法连接到Greenplum服务器: {e}")

    def select_data(self, sql: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
        """
        执行查询并返回结果列表
        
        参数:
            sql: SQL查询语句
            params: SQL参数
            
        返回:
            包含查询结果的字典列表
        """
        with self._get_cursor() as cur:
            cur.execute(sql, params or ())
            return cur.fetchall()

    def execute_sql(self, sql: str, params: Optional[tuple] = None) -> int:
        """
        执行SQL语句(INSERT, UPDATE, DELETE等)
        
        参数:
            sql: SQL语句
            params: SQL参数
            
        返回:
            影响的行数
        """
        with self._get_cursor() as cur:
            cur.execute(sql, params or ())
            return cur.rowcount

    def truncate_table(self, table_name: str, cascade: bool = False) -> None:
        """
        清空表数据
        
        参数:
            table_name: 表名
            cascade: 是否级联清空相关表
        """
        sql = f"TRUNCATE TABLE {table_name}"
        if cascade:
            sql += " CASCADE"
        self.execute_sql(sql)

    def insert_df(self, table_name: str, df: pd.DataFrame, batch_size: int = 10000) -> int:
        """
        使用批量插入方式将DataFrame数据写入数据库
        
        参数:
            table_name: 目标表名
            df: 要插入的DataFrame
            batch_size: 每批插入的行数
            
        返回:
            插入的总行数
        """
        if df.empty:
            return 0

        columns = ', '.join(df.columns)
        placeholders = ', '.join(['%s'] * len(df.columns))
        sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"

        total_rows = 0
        with self._get_cursor() as cur:
            # 分批插入数据
            for i in range(0, len(df), batch_size):
                batch = df.iloc[i:i + batch_size]
                psycopg2.extras.execute_batch(cur, sql, batch.values.tolist())
                total_rows += len(batch)
        
        return total_rows

    def read_df(self, sql: str, params: Optional[tuple] = None) -> pd.DataFrame:
        """
        执行SQL查询并返回DataFrame
        
        参数:
            sql: SQL查询语句
            params: SQL参数
            
        返回:
            包含查询结果的DataFrame
        """
        with self._get_cursor() as cur:
            cur.execute(sql, params or ())
            columns = [desc[0] for desc in cur.description]
            data = cur.fetchall()
            return pd.DataFrame(data, columns=columns)

    def copy_from_df(self, table_name: str, df: pd.DataFrame, sep: str = '\t', null: str = '\\N') -> None:
        """
        使用COPY命令高效导入数据
        
        参数:
            table_name: 目标表名
            df: 要导入的DataFrame
            sep: 分隔符
            null: NULL值的表示方式
        """
        if df.empty:
            return

        with StringIO() as buffer:
            df.to_csv(buffer, sep=sep, index=False, header=False, na_rep=null)
            buffer.seek(0)
            
            with self._get_cursor() as cur:
                cur.copy_from(buffer, table_name, sep=sep, columns=df.columns.tolist(), null=null)

    def copy_from_file(self, table_name: str, file_path: str, sep: str = '\t', columns: Optional[List[str]] = None) -> None:
        """
        从文件导入数据到数据库表
        
        参数:
            table_name: 目标表名
            file_path: 文件路径
            sep: 分隔符
            columns: 要导入的列名列表
        """
        with open(file_path, 'r') as f:
            with self._get_cursor() as cur:
                cur.copy_from(f, table_name, sep=sep, columns=columns)

    def upsert_df(self, table_name: str, df: pd.DataFrame, conflict_columns: List[str], update_columns: List[str]) -> int:
        """
        执行UPSERT操作(存在则更新,不存在则插入)
        
        参数:
            table_name: 目标表名
            df: 要插入/更新的DataFrame
            conflict_columns: 冲突检测列
            update_columns: 需要更新的列
            
        返回:
            影响的总行数
        """
        if df.empty:
            return 0

        columns = ', '.join(df.columns)
        placeholders = ', '.join(['%s'] * len(df.columns))
        update_set = ', '.join([f"{col} = EXCLUDED.{col}" for col in update_columns])
        
        sql = f"""
        INSERT INTO {table_name} ({columns}) 
        VALUES ({placeholders})
        ON CONFLICT ({', '.join(conflict_columns)}) 
        DO UPDATE SET {update_set}
        """

        total_rows = 0
        with self._get_cursor() as cur:
            # 分批执行UPSERT
            for i in range(0, len(df), 10000):
                batch = df.iloc[i:i + 10000]
                psycopg2.extras.execute_batch(cur, sql, batch.values.tolist())
                total_rows += len(batch)
        
        return total_rows

优化说明

  1. 类型提示:添加了类型提示,提高代码可读性和IDE支持

  2. 上下文管理器:使用contextlib.contextmanager创建上下文管理器,自动处理连接和事务

  3. 批量操作

    • 添加了批量插入和批量更新功能
    • 默认分批处理大数据量,避免内存问题
  4. 错误处理

    • 更完善的错误处理和事务回滚
    • 连接超时设置
  5. 新增功能

    • 添加了upsert_df方法实现存在则更新,不存在则插入
    • 添加了通用execute_sql方法
  6. 性能优化

    • 使用execute_batch替代executemany提高批量插入性能
    • 改进了COPY命令的实现
  7. 代码结构

    • 更清晰的文档字符串
    • 更合理的参数命名
    • 分离不同功能的方法

使用示例

# 初始化
db = GPDB(dbname="mydb", user="user", password="pass", host="localhost", port="5432")

# 查询数据
results = db.select_data("SELECT * FROM users WHERE age > %s", (30,))

# 读取为DataFrame
df = db.read_df("SELECT * FROM products")

# 插入DataFrame
db.insert_df("products", df)

# 高效导入大数据
db.copy_from_df("large_table", large_df)

# UPSERT操作
db.upsert_df("users", user_df, conflict_columns=["id"], update_columns=["name", "email"])

这个优化版本提供了更好的性能、更强的健壮性和更清晰的接口设计。

http://www.dtcms.com/a/112412.html

相关文章:

  • Pycharm v2024.3.4 Windows Python开发工具
  • MinIO中的纠删码是什么
  • 正则入门到精通
  • 基于 LangChain 搭建简单 RAG 系统
  • Mysql 中的两阶段提交
  • HTML应用指南:利用POST请求获取三大运营商5G基站位置信息(一)
  • 2025-04-04 Unity 网络基础5——TCP分包与黏包
  • Windows 安装和使用 ElasticSearch
  • Git提交本地项目到Github
  • vue+form实现flappybird
  • 迅饶科技X2Modbus网关-GetUser信息泄露漏洞
  • Mysql 中 B 树 vs B+ 树
  • SQL Server 2022 脏读问题排查与思考
  • HTML5 vs HTML 和 CSS3 vs CSS:全面对比
  • Spring Boot 中使用 Redis:从入门到实战
  • Websoft9分享:在数字化转型中选择开源软件可能遇到的难题
  • 神经网络能不能完全拟合y=x² ???
  • WinForm真入门(7)——Button控件详解
  • 京东运维面试题及参考答案
  • k8s进阶之路:本地集群环境搭建
  • 谷歌 Gemini 2.5 Pro 免费开放
  • 24、 Python Socket编程:从协议解析到多线程实战
  • 如何完整迁移 Git 仓库 ?
  • yum list查询时部分包查找不到流程分析
  • 54.大学生心理健康管理系统(基于springboot项目)
  • 有人DTU使用MQTT协议控制Modbus协议的下位机-含数据库
  • Redis分布式锁详解
  • AWS Langfuse AI用Bedrock模型使用完全教程
  • 【万字总结】前端全方位性能优化指南(八)——Webpack 6调优、模块联邦升级、Tree Shaking突破
  • 安卓离线畅玩的多款棋类单机游戏推荐