scripts/db: update documentation, add pg -> sqlite section (#6177)

* scripts/db: update doc

* scripts/db: update documentation, add pg2sqlite script

* scripts/db/pg2sqlite: formatting, code quality, modularity

* scripts/db: parallel pgloader
This commit is contained in:
sh
2025-09-15 13:06:39 +00:00
committed by GitHub
parent e9961ebb8e
commit 7c92c7666c
3 changed files with 1197 additions and 63 deletions

899
scripts/db/pg2sqlite.py Executable file
View File

@@ -0,0 +1,899 @@
#!/usr/bin/env python3
"""
PostgreSQL -> per-schema SQLite migration with colored, clean logging.
Usage example:
python db_migrate.py 'postgresql://user:pass@host:5432/db' /path/to/sqlite/dir --dry-run
Note: color output will be disabled automatically if stdout is not a TTY or if --no-color is passed.
"""
from __future__ import annotations
import argparse
import logging
import os
import sqlite3
import sys
import re
import datetime
from pathlib import Path
from typing import Set, List, Tuple, Dict, Optional, NamedTuple
import psycopg
from psycopg import sql
# ----------------------
# Small utilities
# ----------------------
# Register adapters so datetime/date/time values are stored in SQLite as
# ISO-8601 text (space-separated, microsecond precision for datetime/time).
# Python 3.12+ deprecates the implicit default datetime adapters, so we
# register ours explicitly at import time.
try:
    sqlite3.register_adapter(
        datetime.datetime,
        lambda v: v.isoformat(sep=" ", timespec="microseconds"),
    )
    sqlite3.register_adapter(datetime.date, lambda v: v.isoformat())
    sqlite3.register_adapter(
        datetime.time, lambda v: v.isoformat(timespec="microseconds")
    )
except ValueError:
    # NOTE(review): register_adapter normally overwrites silently; this
    # guard looks purely defensive ("already registered") — confirm needed.
    pass  # already registered
# ANSI escape sequences used by ColoredFormatter for terminal output.
ANSI = {
    "reset": "\x1b[0m",
    "bold": "\x1b[1m",
    "dim": "\x1b[2m",
    "red": "\x1b[31m",
    "green": "\x1b[32m",
    "yellow": "\x1b[33m",
    "blue": "\x1b[34m",
    "magenta": "\x1b[35m",
    "cyan": "\x1b[36m",
    "gray": "\x1b[90m",
}
# Rows fetched from Postgres (and inserted into SQLite) per batch.
DEFAULT_BATCH_SIZE = 10000
# Maps a PostgreSQL type keyword (matched as a whole word inside the
# lowercase information_schema data_type) to the SQLite declared-type
# tokens considered compatible by sqlite_decl_satisfies().
TYPE_COMPATIBILITY = {
    "bytea": ["BLOB", "CHAR", "CLOB", "TEXT", "JSON"],
    "int": ["INT", "NUMERIC"],
    "serial": ["INT", "NUMERIC"],
    "numeric": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "decimal": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "real": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "double": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "float": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "money": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "bool": ["BOOL", "INT", "NUMERIC"],
    "varchar": ["CHAR", "CLOB", "TEXT"],
    "char": ["CHAR", "CLOB", "TEXT"],
    "text": ["CHAR", "CLOB", "TEXT"],
    "citext": ["CHAR", "CLOB", "TEXT"],
    "timestamp": ["DATE", "TIME", "CHAR", "TEXT", "DATETIME"],
    "time": ["DATE", "TIME", "CHAR", "TEXT", "DATETIME"],
    "date": ["DATE", "TIME", "CHAR", "TEXT", "DATETIME"],
    "uuid": ["CHAR", "TEXT", "UUID", "CLOB"],
    "json": ["JSON", "TEXT", "CHAR", "CLOB"],
    "jsonb": ["JSON", "TEXT", "CHAR", "CLOB"],
}
def _sanitize_cursor_name(s: str) -> str:
return re.sub(r"[^A-Za-z0-9_]+", "_", s)
def supports_color(force_no: bool) -> bool:
    """Return True when we should emit ANSI colors."""
    # Explicit opt-outs first: CLI flag, then the NO_COLOR convention.
    if force_no or os.getenv("NO_COLOR"):
        return False
    term = os.getenv("TERM", "")
    if not term or term.lower() == "dumb":
        return False
    # Finally, only color when stdout is really a terminal.
    try:
        return sys.stdout.isatty()
    except Exception:
        return False
class ColoredFormatter(logging.Formatter):
    """Render the first word of each message as a bracketed, colored tag.

    The tag color comes from TAG_COLORS when the word is a known tag,
    otherwise from the record's log level; colors are skipped entirely
    when use_color is False.
    """

    LEVEL_COLORS = {
        logging.DEBUG: ANSI["gray"],
        logging.INFO: ANSI["green"],
        logging.WARNING: ANSI["yellow"],
        logging.ERROR: ANSI["red"],
        logging.CRITICAL: ANSI["red"] + ANSI["bold"],
    }
    TAG_COLORS = {
        "SKIP": ANSI["yellow"],
        "SCHEMA": ANSI["blue"],
        "OK": ANSI["magenta"],
    }

    def __init__(self, use_color: bool = True):
        super().__init__(fmt="%(message)s")
        self.use_color = use_color

    def format(self, record: logging.LogRecord) -> str:
        rendered = super().format(record)
        tag, _, rest = rendered.partition(" ")
        label = f"[{tag}]"
        suffix = f" {rest}" if rest else ""
        if not self.use_color:
            return label + suffix
        color = self.TAG_COLORS.get(
            tag.upper(), self.LEVEL_COLORS.get(record.levelno, "")
        )
        return f"{color}{label}{ANSI['reset']}{suffix}"
def setup_logger(verbose: bool, no_color: bool) -> logging.Logger:
    """Configure and return the 'db_migrate' logger.

    Existing handlers are replaced so repeated calls stay idempotent; the
    psycopg driver's own logging is quieted to WARNING.
    """
    use_color = supports_color(no_color)
    level = logging.DEBUG if verbose else logging.INFO
    logger = logging.getLogger("db_migrate")
    logger.setLevel(level)
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(ColoredFormatter(use_color=use_color))
    logger.handlers.clear()
    logger.addHandler(handler)
    logging.getLogger("psycopg").setLevel(logging.WARNING)
    if verbose:
        # Surface the environment facts behind the color decision.
        logger.debug(f"color_support: {use_color}")
        try:
            isatty = sys.stdout.isatty()
        except Exception:
            isatty = False
        logger.debug(
            f"TERM={os.getenv('TERM', '')!r} "
            f"NO_COLOR={os.getenv('NO_COLOR')!r} "
            f"isatty={isatty}"
        )
    return logger
def quote_sqlite_identifier(name: str) -> str:
    """Double-quote *name* for SQLite, doubling embedded double quotes."""
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
def quote_pg_identifier(name: str) -> str:
    """Simple PG identifier quoting for building safe SQL strings."""
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
def sqlite_decl_satisfies(pg_type: str, sqlite_decl: str) -> bool:
    """Return True when the SQLite declared type can hold the PG type.

    A blank SQLite declaration means flexible (BLOB) affinity, so it is
    accepted for every PostgreSQL type.  Postgres arrays map onto textual
    affinity.  Otherwise the first TYPE_COMPATIBILITY keyword found (as a
    whole word) in the PG type decides which SQLite tokens are acceptable;
    unknown PG types fall back to accepting any common affinity keyword.
    """
    decl_raw = sqlite_decl or ""
    if not decl_raw.strip():
        # Blank declared type: SQLite stores anything.  (The per-PG-type
        # checks the original ran here were dead code — every branch,
        # including the fallback, returned True.)
        return True
    decl = decl_raw.upper()
    pg = (pg_type or "").lower()
    # Array types -> textual affinity on the SQLite side.
    if pg.endswith("[]"):
        return any(tok in decl for tok in ("TEXT", "CHAR", "CLOB", "JSON"))
    for key, allowed_types in TYPE_COMPATIBILITY.items():
        if re.search(r"\b" + re.escape(key) + r"\b", pg):
            return any(tok in decl for tok in allowed_types)
    # Unknown PG type: accept any common affinity keyword.
    return any(
        tok in decl for tok in ("TEXT", "CHAR", "CLOB", "NUMERIC", "BLOB", "INT")
    )
# ----------------------
# Postgres helpers
# ----------------------
def list_user_schemas(pg_cursor) -> List[str]:
    """Return all non-system, non-public schema names, sorted by name."""
    query = """
        SELECT nspname
        FROM pg_namespace
        WHERE nspname NOT LIKE 'pg_%'
        AND nspname != 'information_schema'
        AND nspname != 'public'
        ORDER BY nspname;
    """
    pg_cursor.execute(query)
    return [row[0] for row in pg_cursor.fetchall()]
def list_tables_in_schema(pg_cursor, schema: str) -> List[str]:
    """Return table names in *schema*, sorted by name.

    Note: information_schema.tables also lists views; callers filter the
    result against existing SQLite target tables.
    """
    # A plain query string is enough here — wrapping a constant in
    # sql.SQL() added nothing and was inconsistent with the other helpers.
    pg_cursor.execute(
        """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s
        ORDER BY table_name;
        """,
        (schema,),
    )
    return [r[0] for r in pg_cursor.fetchall()]
def get_pg_columns(pg_cursor, schema: str, table: str) -> List[Tuple[str, str]]:
    """Return (column_name, data_type) pairs for *schema*.*table* in
    ordinal (declaration) order."""
    # A plain query string is enough here — wrapping a constant in
    # sql.SQL() added nothing and was inconsistent with the other helpers.
    pg_cursor.execute(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position;
        """,
        (schema, table),
    )
    return [(r[0], r[1]) for r in pg_cursor.fetchall()]
# ----------------------
# SQLite helpers
# ----------------------
class SQLiteCol(NamedTuple):
    """One row of SQLite's ``PRAGMA table_info(<table>)`` output."""

    cid: int          # column position (0-based)
    name: str         # column name
    type: str         # declared type text; may be empty
    notnull: int      # 1 if NOT NULL constraint present
    dflt_value: str   # default-value expression, or None
    pk: int           # >0 if part of the primary key
def get_sqlite_table_info(sqlite_cursor, table_name: str) -> list[SQLiteCol]:
    """Return PRAGMA table_info rows for *table_name* as SQLiteCol tuples."""
    quoted = quote_sqlite_identifier(table_name)
    sqlite_cursor.execute(f"PRAGMA table_info({quoted});")
    rows = sqlite_cursor.fetchall()
    return [SQLiteCol(*r) for r in rows]
def sqlite_table_has_autoincrement(sqlite_cursor, table_name: str) -> bool:
    """True when *table_name*'s CREATE TABLE used the AUTOINCREMENT keyword."""
    sqlite_cursor.execute(
        "SELECT sql FROM sqlite_master WHERE type='table' AND name = ?;", (table_name,)
    )
    row = sqlite_cursor.fetchone()
    ddl = row[0] if row else None
    return bool(ddl) and "AUTOINCREMENT" in ddl.upper()
def sqlite_sequence_table_exists(sqlite_cursor) -> bool:
    """True when the internal sqlite_sequence bookkeeping table exists.

    SQLite creates it automatically once any AUTOINCREMENT table exists.
    """
    sqlite_cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='sqlite_sequence';"
    )
    found = sqlite_cursor.fetchone() is not None
    if not found:
        logging.getLogger("db_migrate").debug("sqlite_sequence table not present")
    return found
# ----------------------
# Core migration logic
# ----------------------
def build_select_for_sqlite_columns(
    schema: str,
    table: str,
    sqlite_cols: list[SQLiteCol],
    pg_columns_map: Dict[str, Tuple[str, str]],
) -> str:
    """Return a plain SQL string (no trailing semicolon) selecting PG columns
    in the order of *sqlite_cols*.

    Columns absent on the PG side are selected as NULL so the positional
    placeholders still line up; timestamp/time columns are cast to text so
    SQLite receives ISO strings.  Simple identifier quoting is used to avoid
    psycopg.sql objects which may vary by driver version.
    """
    exprs = []
    for col in sqlite_cols:
        entry = pg_columns_map.get(col.name.lower())
        if entry is None:
            exprs.append("NULL")
            continue
        pg_name, pg_type = entry
        quoted = quote_pg_identifier(pg_name)
        # "timestamp..." contains "time", so one substring test covers both.
        if "time" in (pg_type or "").lower():
            exprs.append(f"{quoted}::text")
        else:
            exprs.append(quoted)
    select_list = ", ".join(exprs)
    # Quote schema and table names simply (not comprehensive for every
    # corner case, but avoids psycopg.sql use).
    return f"SELECT {select_list} FROM {quote_pg_identifier(schema)}.{quote_pg_identifier(table)}"
def validate_column_compatibility(
    sqlite_cols: list[SQLiteCol],
    pg_columns_map: Dict[str, Tuple[str, str]],
    schema: str,
    table: str,
) -> None:
    """Raise ValueError when a shared column's PG and SQLite types clash.

    Columns that exist only on the SQLite side are ignored here.
    """
    for col in sqlite_cols:
        entry = pg_columns_map.get(col.name.lower())
        if entry is None:
            continue
        pg_type = entry[1]
        if sqlite_decl_satisfies(pg_type, col.type):
            continue
        raise ValueError(
            f"Type mismatch for {schema}.{table}.{col.name}: "
            f"PG='{pg_type}' vs SQLite='{col.type}'"
        )
def process_row_for_sqlite(row: Tuple, sqlite_cols: list[SQLiteCol]) -> Tuple:
    """Convert one Postgres row into SQLite-storable values.

    Binary values (memoryview/bytes/bytearray) are kept as BLOBs when the
    SQLite column is declared BLOB (or has a blank declaration); otherwise
    they are decoded as UTF-8 text.  All other values pass through.

    Raises ValueError when a binary value cannot be decoded as UTF-8 for a
    textual column.
    """
    out = []
    for i, val in enumerate(row):
        col = sqlite_cols[i]
        decl_raw = col.type or ""
        decl = decl_raw.upper()
        decl_is_empty = decl_raw.strip() == ""
        # memoryview and bytes/bytearray previously had duplicated,
        # identical handling; normalize to bytes once and share the logic.
        if isinstance(val, (memoryview, bytes, bytearray)):
            b = val.tobytes() if isinstance(val, memoryview) else bytes(val)
            if "BLOB" in decl or decl_is_empty:
                out.append(sqlite3.Binary(b))
            else:
                try:
                    out.append(b.decode("utf-8"))
                except UnicodeDecodeError as e:
                    raise ValueError(
                        f"UTF-8 decode failed for column '{col.name}': {e}"
                    )
        else:
            out.append(val)
    return tuple(out)
def fetch_and_validate_row_batches(
    pg_cursor,
    select_sql: str,
    sqlite_cols: list[SQLiteCol],
    batch_size: int = DEFAULT_BATCH_SIZE,
):
    """Execute *select_sql* and yield lists of validated, SQLite-ready rows.

    Raises ValueError (with a 1-based row number) when any row fails
    process_row_for_sqlite validation.
    """
    pg_cursor.execute(select_sql)
    seen = 0
    while True:
        chunk = pg_cursor.fetchmany(batch_size)
        if not chunk:
            return
        validated = []
        for raw in chunk:
            seen += 1
            try:
                validated.append(process_row_for_sqlite(raw, sqlite_cols))
            except ValueError as e:
                raise ValueError(f"Row {seen} validation error: {e}")
        yield validated
def _get_postgres_sequence_info(
pg_cursor, schema: str, table: str, pk_col: str
) -> Optional[tuple[str, int]]:
"""Get PostgreSQL sequence name and last_value for a table primary key, if any.
Returns (sequence_name_text, last_value) or None.
"""
pg_cursor.execute(
"""
SELECT data_type
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s AND column_name = %s;
""",
(schema, table, pk_col),
)
r = pg_cursor.fetchone()
if not r:
return None
pg_type = (r[0] or "").lower()
# Match whole words to avoid matching 'interval' etc.
if not re.search(r"\b(?:int|serial|bigint)\b", pg_type):
return None
# Get sequence name text (NULL if none)
pg_cursor.execute(
"SELECT pg_get_serial_sequence(%s, %s);",
(f"{schema}.{table}", pk_col),
)
row = pg_cursor.fetchone()
seq_name = row[0] if row else None
if not seq_name:
return None
try:
# Read last_value by casting the pg_get_serial_sequence result to regclass.
pg_cursor.execute(
"SELECT last_value FROM pg_get_serial_sequence(%s, %s)::regclass;",
(f"{schema}.{table}", pk_col),
)
rr = pg_cursor.fetchone()
if rr and rr[0] is not None:
return (seq_name, int(rr[0]))
except Exception:
# Swallow errors (best-effort); callers treat missing info as absent
pass
return None
def _get_sqlite_max_pk_value(sqlite_cursor, table: str, pk_col: str) -> Optional[int]:
    """Return MAX(pk_col) from the SQLite table, or None on error/empty table."""
    logger = logging.getLogger("db_migrate")
    query = (
        f"SELECT MAX({quote_sqlite_identifier(pk_col)}) "
        f"FROM {quote_sqlite_identifier(table)};"
    )
    try:
        sqlite_cursor.execute(query)
        row = sqlite_cursor.fetchone()
        if row and row[0] is not None:
            return int(row[0])
    except Exception:
        # Best-effort: a missing table/column just means "no value".
        logger.debug("Failed to read sqlite max pk", exc_info=True)
    return None
def _update_sqlite_sequence(
sqlite_conn: sqlite3.Connection, table: str, sequence_value: int
) -> bool:
"""Update SQLite sqlite_sequence with new value.
IMPORTANT: This function does not commit; caller must commit/rollback.
"""
s_cur = sqlite_conn.cursor()
s_cur.execute(
"UPDATE sqlite_sequence SET seq = ? WHERE name = ?;",
(sequence_value, table),
)
if s_cur.rowcount == 0:
s_cur.execute(
"INSERT INTO sqlite_sequence(name, seq) VALUES (?, ?);",
(table, sequence_value),
)
return True
def try_update_sqlite_sequence(
    pg_cursor,
    sqlite_conn: sqlite3.Connection,
    schema: str,
    pg_table: str,  # PostgreSQL table name
    sqlite_table: str,  # SQLite table name
    sqlite_cols: list[SQLiteCol],
) -> bool:
    """Sync sqlite_sequence for *sqlite_table* from the PG sequence / max pk.

    Applies only to tables with a single-column AUTOINCREMENT primary key;
    the counter is set to the larger of the PG sequence value and the
    current SQLite MAX(pk).  Returns True when an update was written.
    """
    logger = logging.getLogger("db_migrate")
    # Only a single-column primary key can participate in AUTOINCREMENT.
    pk_cols = [c for c in sqlite_cols if c.pk]
    if len(pk_cols) != 1:
        return False
    pk_name = pk_cols[0].name
    s_cur = sqlite_conn.cursor()
    # Check the SQLite side (use sqlite_table for both lookups).
    if not sqlite_table_has_autoincrement(s_cur, sqlite_table):
        return False
    if not sqlite_sequence_table_exists(s_cur):
        return False
    # Postgres sequence info (use pg_table).
    seq_info = _get_postgres_sequence_info(pg_cursor, schema, pg_table, pk_name)
    pg_last = seq_info[1] if seq_info else None
    local_max = _get_sqlite_max_pk_value(s_cur, sqlite_table, pk_name)
    candidates = [v for v in (pg_last, local_max) if v is not None]
    if not candidates:
        return False
    updated = _update_sqlite_sequence(sqlite_conn, sqlite_table, max(candidates))
    if updated:
        try:
            sqlite_conn.commit()
        except Exception:
            logger.debug("Failed to commit sqlite_sequence update", exc_info=True)
            # Let caller proceed; treat as best-effort
    return updated
# ----------------------
# Flow: per-schema migration
# ----------------------
def migrate_schema(
    pg_conn,
    sqlite_dir: str,
    schema: str,
    skipped_tables: Set[str],
    logger: logging.Logger,
    dry_run: bool = False,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> Tuple[int, int]:
    """Migrate a single schema. Returns (tables_migrated, rows_inserted_total).

    The target SQLite file is ``<schema minus a trailing '_schema'>.db``
    inside *sqlite_dir*; if it does not exist the schema is skipped.  Each
    Postgres table that matches a SQLite table (case-insensitive) is fully
    replaced: DELETE then batched INSERTs inside one SQLite transaction.
    In *dry_run* mode the SQLite file is opened read-only and rows are only
    validated, never written.  Per-table Postgres errors are logged and the
    table is skipped; the schema keeps going.
    """
    # "foo_schema" -> foo.db; any other schema name maps to "<name>.db".
    processed_schema = schema[:-7] if schema.endswith("_schema") else schema
    sqlite_db_path = Path(sqlite_dir) / f"{processed_schema}.db"
    if not sqlite_db_path.is_file():
        logger.error(f"Missing SQLite DB: {sqlite_db_path}")
        return 0, 0
    tables_migrated = 0
    rows_inserted = 0
    sqlite_path = sqlite_db_path.resolve()
    if dry_run:
        # Read-only URI connection guarantees a dry run cannot write.
        uri = sqlite_path.as_uri() + "?mode=ro"
        conn_args = {"database": uri, "uri": True}
    else:
        conn_args = {"database": str(sqlite_path)}
    # NOTE(review): sqlite3's connection context manager handles the
    # transaction, not closing — the connection itself is left to GC.
    with sqlite3.connect(**conn_args) as sqlite_conn:
        sqlite_cur = sqlite_conn.cursor()
        sqlite_cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
        sqlite_tables = [r[0] for r in sqlite_cur.fetchall()]
        # O(1) lookup map for matching by lowercase name
        sqlite_table_map = {t.lower(): t for t in sqlite_tables}
        with pg_conn.cursor() as pg_cur:
            # Ensure we start in a clean state for this connection
            try:
                pg_conn.rollback()
            except Exception:
                # Ignore rollback failure; connection is newly opened most likely
                logger.debug(
                    "pg_conn.rollback() at schema start ignored", exc_info=True
                )
            # Get table list for this schema, but catch/rollback on failure
            try:
                tables = list_tables_in_schema(pg_cur, schema)
            except Exception as e:
                logger.error("ERROR %s (list tables): %s", schema, e)
                logger.debug("Traceback (list tables):", exc_info=True)
                try:
                    pg_conn.rollback()
                except Exception:
                    logger.debug(
                        "pg_conn.rollback() failed after list tables error",
                        exc_info=True,
                    )
                return 0, 0
            for table in tables:
                # Defensive: ensure connection is in a clean state before any new PG work.
                # A prior error can leave the connection in an aborted transaction; calling
                # rollback() clears that and allows subsequent SELECTs to run.
                try:
                    pg_conn.rollback()
                except Exception:
                    # ignore: best-effort cleanup
                    logger.debug(
                        "pg_conn.rollback() ignored at start of table loop",
                        exc_info=True,
                    )
                # Skip tables the operator excluded or that have no target.
                if table.lower() in skipped_tables:
                    logger.info("SKIP %s.%s (explicit)", schema, table)
                    continue
                if table.lower() not in sqlite_table_map:
                    logger.info("SKIP %s.%s (no target table)", schema, table)
                    continue
                matched_table = sqlite_table_map[table.lower()]
                logger.info("OK %s.%s -> %s", schema, table, matched_table)
                # Fetch PG columns, but defend against aborted transaction here
                try:
                    pg_cols = get_pg_columns(pg_cur, schema, table)
                except Exception as e:
                    logger.error("ERROR %s.%s (get columns): %s", schema, table, e)
                    logger.debug("Traceback (get columns):", exc_info=True)
                    try:
                        pg_conn.rollback()
                    except Exception:
                        logger.debug(
                            "pg_conn.rollback() failed after get columns error",
                            exc_info=True,
                        )
                    continue
                # lowercase name -> (original PG name, PG data type)
                pg_map: Dict[str, Tuple[str, str]] = {
                    name.lower(): (name, dtype) for name, dtype in pg_cols
                }
                sqlite_info = get_sqlite_table_info(sqlite_cur, matched_table)
                if not sqlite_info:
                    logger.warning(f"SKIP {schema}.{table} (no sqlite info)")
                    continue
                # Warn once about schema drift (extra columns)
                pg_col_names = {name.lower() for name, _ in pg_cols}
                extra_in_sqlite = {
                    c.name for c in sqlite_info if c.name.lower() not in pg_col_names
                }
                if extra_in_sqlite:
                    logger.warning(
                        f"Schema drift: {schema}.{table} has extra SQLite columns {extra_in_sqlite}"
                    )
                try:
                    validate_column_compatibility(sqlite_info, pg_map, schema, table)
                except ValueError as e:
                    logger.error(f"ERROR {schema}.{table} (type mismatch): {e}")
                    continue
                select_sql = build_select_for_sqlite_columns(
                    schema, table, sqlite_info, pg_map
                )
                csr_name = _sanitize_cursor_name(f"csr_{schema}_{table}")
                try:
                    # Named (server-side) cursor streams rows batch-by-batch
                    # instead of loading the whole table into memory.
                    with pg_conn.cursor(name=csr_name) as data_cur:
                        batch_gen = fetch_and_validate_row_batches(
                            data_cur, select_sql, sqlite_info, batch_size=batch_size
                        )
                        first_batch = next(batch_gen, None)
                        if not first_batch:
                            logger.info(f"SKIP {schema}.{table} (no rows)")
                            # data_cur will be closed automatically on leaving the 'with'
                            continue
                        if dry_run:
                            # Drain the generator to validate every row.
                            count = len(first_batch)
                            for batch in batch_gen:
                                count += len(batch)
                            inserted = 0
                            logger.info(
                                f"DRY {schema}.{table} rows_validated={count}"
                            )
                        else:
                            cur = sqlite_conn.cursor()
                            quoted_table = quote_sqlite_identifier(matched_table)
                            col_names = [c.name for c in sqlite_info]
                            quoted_cols = ", ".join(
                                quote_sqlite_identifier(c) for c in col_names
                            )
                            placeholders = ", ".join(["?"] * len(col_names))
                            # One explicit transaction per table: full replace
                            # (DELETE then batched INSERTs), all-or-nothing.
                            cur.execute("BEGIN;")
                            try:
                                cur.execute(f"DELETE FROM {quoted_table};")
                                inserted = 0
                                if first_batch:
                                    cur.executemany(
                                        f"INSERT INTO {quoted_table} ({quoted_cols}) VALUES ({placeholders});",
                                        first_batch,
                                    )
                                    inserted += len(first_batch)
                                for batch in batch_gen:
                                    if not batch:
                                        continue
                                    cur.executemany(
                                        f"INSERT INTO {quoted_table} ({quoted_cols}) VALUES ({placeholders});",
                                        batch,
                                    )
                                    inserted += len(batch)
                            except Exception:
                                cur.execute("ROLLBACK;")
                                raise
                            else:
                                cur.execute("COMMIT;")
                except ValueError as e:
                    logger.error("ERROR %s.%s (row validation): %s", schema, table, e)
                    try:
                        pg_conn.rollback()
                    except Exception:
                        logger.debug(
                            "pg_conn.rollback() failed after row validation error",
                            exc_info=True,
                        )
                    continue
                except Exception as e:
                    logger.error("ERROR %s.%s (select failed): %s", schema, table, e)
                    logger.debug("Traceback (select failed):", exc_info=True)
                    try:
                        pg_conn.rollback()
                    except Exception:
                        logger.debug(
                            "pg_conn.rollback() failed after select failed",
                            exc_info=True,
                        )
                    continue
                rows_inserted += inserted
                tables_migrated += 1
                logger.info(f"DONE {schema}.{table} rows={inserted}")
                if not dry_run:
                    # Best-effort: keep the SQLite AUTOINCREMENT counter in
                    # sync with the PG sequence so future inserts don't collide.
                    try:
                        if try_update_sqlite_sequence(
                            pg_cur,
                            sqlite_conn,
                            schema,
                            table,
                            matched_table,
                            sqlite_info,
                        ):
                            logger.info(
                                "SEQ %s.%s sqlite_sequence updated", schema, table
                            )
                    except Exception as e:
                        logger.warning(
                            "SEQ %s.%s update failed (ignored): %s", schema, table, e
                        )
    return tables_migrated, rows_inserted
def migrate_data(
    pg_conn_str: str,
    sqlite_dir: str,
    schema_filter: Optional[str],
    dry_run: bool,
    skip_tables: str,
    logger: logging.Logger,
    batch_size: int,
) -> None:
    """Drive the migration across all user schemas (or one, if filtered).

    Each schema gets its own short-lived Postgres connection so one schema's
    failure (or aborted transaction) cannot poison the next.
    """
    skip_set = {name.strip().lower() for name in skip_tables.split(",") if name.strip()}
    migrated_total = 0
    rows_total = 0
    error_count = 0
    # List schemas once with a short-lived connection
    with psycopg.connect(pg_conn_str) as tmp_conn:
        with tmp_conn.cursor() as cur:
            schemas = list_user_schemas(cur)
    if schema_filter:
        schemas = [s for s in schemas if s == schema_filter]
    if not schemas:
        logger.error("No schemas to process")
        return
    for schema in schemas:
        logger.info("SCHEMA %s", schema)
        if dry_run:
            logger.info("(dry-run) validating only — no writes will be performed")
        # Use a fresh connection per-schema to isolate failures/aborted transactions
        try:
            with psycopg.connect(pg_conn_str) as pg_conn:
                migrated_tables, inserted_rows = migrate_schema(
                    pg_conn,
                    sqlite_dir,
                    schema,
                    skip_set,
                    logger,
                    dry_run=dry_run,
                    batch_size=batch_size,
                )
        except Exception as e:
            logger.error("Schema %s failed: %s", schema, e)
            logger.debug("Traceback (schema failure):", exc_info=True)
            error_count += 1
            continue
        migrated_total += migrated_tables
        rows_total += inserted_rows
        logger.info("---")
    logger.info(
        f"SUMMARY: schemas={len(schemas)} "
        f"tables_migrated={migrated_total} "
        f"rows_inserted={rows_total} "
        f"errors={error_count}"
    )
# ----------------------
# CLI
# ----------------------
def parse_args(argv):
    """Build the CLI argument parser and parse argv[1:]."""
    parser = argparse.ArgumentParser(
        description="Migrate Postgres data into per-schema SQLite DBs (non-invasive)."
    )
    # Positional arguments.
    parser.add_argument("pg_conn", help="Postgres connection string")
    parser.add_argument(
        "sqlite_dir", help="Directory containing per-schema sqlite .db files"
    )
    # Optional behavior switches.
    parser.add_argument(
        "--schema", default=None, help="Only migrate this schema (exact match)"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Validate and report only; do not write to sqlite",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Verbose logging (debug)"
    )
    parser.add_argument(
        "--no-color", action="store_true", help="Disable ANSI color output"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help="Batch size for data fetching",
    )
    parser.add_argument(
        "--skip-tables",
        default="migrations,servers_stats",
        help="Comma-separated list of tables to skip",
    )
    return parser.parse_args(argv[1:])
def main(argv):
    """CLI entry point: validate args, run the migration, map failures to exit codes."""
    args = parse_args(argv)
    logger = setup_logger(args.verbose, args.no_color)
    sqlite_dir = Path(args.sqlite_dir)
    if not sqlite_dir.is_dir():
        logger.error("SQLite directory does not exist or is not a directory.")
        raise SystemExit(1)
    try:
        migrate_data(
            args.pg_conn,
            str(sqlite_dir),
            args.schema,
            args.dry_run,
            args.skip_tables,
            logger,
            args.batch_size,
        )
    except Exception as e:
        # Any unhandled failure becomes exit code 2 after being logged.
        logger.error(f"Fatal error: {e}")
        raise SystemExit(2)
if __name__ == "__main__":
    # Invoked as a script: delegate to main() with the raw argv.
    main(sys.argv)