"""PostgreSQL database connection handling"""
import asyncio
import logging
from typing import Any, Dict, List, Optional, Union
from contextlib import asynccontextmanager
import asyncpg
from asyncpg import Pool, Connection, Record
from .config import DatabaseConfig
# Module-scoped logger shared by everything defined in this file.
logger = logging.getLogger(__name__)


class DatabaseError(Exception):
    """Error raised by this module's database layer for connection and query failures."""
class PostgreSQLManager:
    """PostgreSQL connection and query manager.

    Wraps a lazily created ``asyncpg`` connection pool and offers helpers for
    parameterized queries and schema introspection. Failures are reported to
    callers as ``DatabaseError``, chained (``__cause__``) to the underlying
    exception.
    """

    def __init__(self, config: DatabaseConfig):
        """Store *config*; no connection is made until :meth:`initialize`.

        Args:
            config: Connection settings. This class reads its ``host``,
                ``port``, ``connection_string``, ``min_connections`` and
                ``max_connections`` attributes.
        """
        self.config = config
        self.pool: Optional[Pool] = None
        # Guards creation/detachment of ``self.pool`` so concurrent callers
        # cannot create two pools, and close() cannot miss a pool that
        # initialize() is still creating.
        self._lock = asyncio.Lock()

    async def initialize(self) -> None:
        """Create the connection pool if one is not already open.

        Calling this while a pool exists is a no-op.

        Raises:
            DatabaseError: If creating the pool fails; the original exception
                is attached as ``__cause__``.
        """
        async with self._lock:
            if self.pool is not None:
                return
            try:
                logger.info(
                    "Connecting to PostgreSQL at %s:%s",
                    self.config.host,
                    self.config.port,
                )
                self.pool = await asyncpg.create_pool(
                    self.config.connection_string,
                    min_size=self.config.min_connections,
                    max_size=self.config.max_connections,
                    command_timeout=30,
                )
                logger.info("Database connection pool initialized")
            except Exception as e:
                logger.error("Failed to connect to database: %s", e)
                raise DatabaseError(f"Database connection failed: {e}") from e

    async def close(self) -> None:
        """Close the connection pool, if one is open.

        The pool is detached from ``self.pool`` *before* ``Pool.close()`` is
        awaited. This stops new work from acquiring connections from a pool
        that is shutting down; such work creates a fresh pool instead. The
        lock is held only for the detach, not across ``Pool.close()``. That
        call waits for checked-out connections to be released, and holding
        the lock there could deadlock a caller that needs a new pool to
        finish its work.
        """
        async with self._lock:
            pool, self.pool = self.pool, None
        if pool is None:
            return
        await pool.close()
        logger.info("Database connection pool closed")

    @asynccontextmanager
    async def get_connection(self):
        """Yield a pooled connection, creating the pool on first use.

        Exceptions raised inside the ``async with`` body are logged and
        re-raised as ``DatabaseError`` chained to the original. An exception
        that is already a ``DatabaseError`` (e.g. from :meth:`execute_query`)
        propagates unchanged, so it is not wrapped or logged a second time.
        """
        if self.pool is None:
            await self.initialize()
        async with self.pool.acquire() as connection:
            try:
                yield connection
            except DatabaseError:
                raise
            except Exception as e:
                logger.error("Database operation failed: %s", e)
                raise DatabaseError(f"Database operation failed: {e}") from e

    async def execute_query(
        self,
        query: str,
        params: Optional[List[Any]] = None,
        fetch_mode: str = "all",
    ) -> Any:
        """
        Execute a SQL query.

        Args:
            query: SQL query string, using ``$1``, ``$2``... placeholders.
            params: Positional query parameters (default: none).
            fetch_mode: "all", "one", "none", or "scalar". Any other value is
                treated as "all".

        Returns:
            - "all": a list of row dicts.
            - "one": the first row as a dict, or None if there were no rows.
            - "scalar": the first column of the first row (``fetchval``).
            - "none": None (the statement is only executed).

        Raises:
            DatabaseError: On any PostgreSQL or unexpected error, chained to
                the original exception.
        """
        params = params or []
        async with self.get_connection() as conn:
            try:
                logger.debug("Executing query: %s...", query[:100])
                if fetch_mode == "none":
                    await conn.execute(query, *params)
                    return None
                if fetch_mode == "scalar":
                    return await conn.fetchval(query, *params)
                if fetch_mode == "one":
                    record = await conn.fetchrow(query, *params)
                    # Compare with None: a zero-column row is falsy but is
                    # still a row and should map to {}.
                    return dict(record) if record is not None else None
                # "all" -- also the fallback for unrecognized modes.
                records = await conn.fetch(query, *params)
                return [dict(record) for record in records]
            except asyncpg.PostgresError as e:
                logger.error("PostgreSQL error: %s", e)
                raise DatabaseError(f"PostgreSQL error: {e}") from e
            except Exception as e:
                logger.error("Unexpected error: %s", e)
                raise DatabaseError(f"Unexpected error: {e}") from e

    async def get_table_info(self, schema: str = "public") -> List[Dict[str, Any]]:
        """List the relations in *schema* reported by information_schema.tables.

        PostgreSQL's ``information_schema.tables`` has no ``table_comment``
        column (that is a MySQL extension), so the comment is read with
        ``obj_description`` from the matching ``pg_class`` row. The LEFT
        JOINs keep every listed relation; ``table_comment`` is None when no
        comment is set.

        Returns:
            One dict per relation with keys ``table_name``, ``table_type`` and
            ``table_comment``, ordered by ``table_name``.
        """
        query = """
            SELECT
                t.table_name,
                t.table_type,
                obj_description(c.oid, 'pg_class') AS table_comment
            FROM information_schema.tables AS t
            LEFT JOIN pg_catalog.pg_namespace AS n
                ON n.nspname = t.table_schema
            LEFT JOIN pg_catalog.pg_class AS c
                ON c.relnamespace = n.oid AND c.relname = t.table_name
            WHERE t.table_schema = $1
            ORDER BY t.table_name
        """
        return await self.execute_query(query, [schema])

    async def get_column_info(self, table_name: str, schema: str = "public") -> List[Dict[str, Any]]:
        """Return column metadata for ``schema.table_name`` in ordinal order.

        Returns:
            One dict per column with keys ``column_name``, ``data_type``,
            ``is_nullable``, ``column_default``, ``character_maximum_length``,
            ``numeric_precision`` and ``numeric_scale``.
        """
        query = """
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length,
                numeric_precision,
                numeric_scale
            FROM information_schema.columns
            WHERE table_schema = $1 AND table_name = $2
            ORDER BY ordinal_position
        """
        return await self.execute_query(query, [schema, table_name])

    async def get_schemas(self) -> List[str]:
        """Return the names of user-visible schemas, sorted by name.

        System schemas are excluded: ``information_schema``, ``pg_catalog``,
        ``pg_toast``, and the per-session temporary schemas ``pg_temp_N`` /
        ``pg_toast_temp_N``.
        """
        query = """
            SELECT schema_name
            FROM information_schema.schemata
            WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
              AND schema_name !~ '^pg_(toast_)?temp_[0-9]+$'
            ORDER BY schema_name
        """
        results = await self.execute_query(query)
        return [row['schema_name'] for row in results]

    async def test_connection(self) -> Optional[Dict[str, Any]]:
        """Run a trivial query and return server info.

        Returns:
            A dict with keys ``version``, ``database`` and ``user``.

        Raises:
            DatabaseError: If the connection or query fails.
        """
        query = "SELECT version() as version, current_database() as database, current_user as user"
        result = await self.execute_query(query, fetch_mode="one")
        return result