mirror of
https://github.com/simplex-chat/simplex-chat.git
synced 2026-03-30 20:45:49 +00:00
scripts/db: update documentation, add pg -> sqlite section (#6177)
* scripts/db: update doc * scripts/db: update documentation, add pg2sqlite script * scripts/db/pg2sqlite: formatiing, code quality, modularity * scripts/db: parallel pgloader
This commit is contained in:
899
scripts/db/pg2sqlite.py
Executable file
899
scripts/db/pg2sqlite.py
Executable file
@@ -0,0 +1,899 @@
|
||||
#!/usr/bin/env python3
"""
PostgreSQL -> per-schema SQLite migration with colored, clean logging.

Usage example:
    python pg2sqlite.py 'postgresql://user:pass@host:5432/db' /path/to/sqlite/dir --dry-run

Note: color output will be disabled automatically if stdout is not a TTY or if --no-color is passed.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import re
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Set, List, Tuple, Dict, Optional, NamedTuple
|
||||
|
||||
import psycopg
|
||||
from psycopg import sql
|
||||
|
||||
# ----------------------
|
||||
# Small utilities
|
||||
# ----------------------
|
||||
|
||||
# Store Python date/time values as ISO-8601 text. Explicit adapters are
# registered because the built-in implicit sqlite3 date/datetime adapters
# are deprecated since Python 3.12.
try:
    sqlite3.register_adapter(
        datetime.datetime,
        lambda v: v.isoformat(sep=" ", timespec="microseconds"),
    )
    sqlite3.register_adapter(datetime.date, lambda v: v.isoformat())
    sqlite3.register_adapter(
        datetime.time, lambda v: v.isoformat(timespec="microseconds")
    )
except ValueError:
    # NOTE(review): register_adapter overwrites silently and is not
    # documented to raise ValueError on re-registration — confirm this
    # guard can actually trigger; kept as defensive code.
    pass  # already registered

# ANSI escape sequences used by ColoredFormatter for terminal output.
ANSI = {
    "reset": "\x1b[0m",
    "bold": "\x1b[1m",
    "dim": "\x1b[2m",
    "red": "\x1b[31m",
    "green": "\x1b[32m",
    "yellow": "\x1b[33m",
    "blue": "\x1b[34m",
    "magenta": "\x1b[35m",
    "cyan": "\x1b[36m",
    "gray": "\x1b[90m",
}

# Rows fetched from PostgreSQL (and inserted into SQLite) per round trip.
DEFAULT_BATCH_SIZE = 10000

# Maps a PostgreSQL type keyword (matched as a whole word inside the
# lowercased PG type string) to SQLite declared-type tokens considered
# compatible with it. Used by sqlite_decl_satisfies().
TYPE_COMPATIBILITY = {
    "bytea": ["BLOB", "CHAR", "CLOB", "TEXT", "JSON"],
    "int": ["INT", "NUMERIC"],
    "serial": ["INT", "NUMERIC"],
    "numeric": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "decimal": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "real": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "double": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "float": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "money": ["NUMERIC", "DECIMAL", "REAL", "FLOAT", "DOUBLE"],
    "bool": ["BOOL", "INT", "NUMERIC"],
    "varchar": ["CHAR", "CLOB", "TEXT"],
    "char": ["CHAR", "CLOB", "TEXT"],
    "text": ["CHAR", "CLOB", "TEXT"],
    "citext": ["CHAR", "CLOB", "TEXT"],
    "timestamp": ["DATE", "TIME", "CHAR", "TEXT", "DATETIME"],
    "time": ["DATE", "TIME", "CHAR", "TEXT", "DATETIME"],
    "date": ["DATE", "TIME", "CHAR", "TEXT", "DATETIME"],
    "uuid": ["CHAR", "TEXT", "UUID", "CLOB"],
    "json": ["JSON", "TEXT", "CHAR", "CLOB"],
    "jsonb": ["JSON", "TEXT", "CHAR", "CLOB"],
}
|
||||
|
||||
|
||||
def _sanitize_cursor_name(s: str) -> str:
|
||||
return re.sub(r"[^A-Za-z0-9_]+", "_", s)
|
||||
|
||||
|
||||
def supports_color(force_no: bool) -> bool:
    """Return True when we should emit ANSI colors.

    Color is off when explicitly forced off, when the NO_COLOR convention
    variable is set, when TERM is unset or 'dumb', or when stdout is not
    an interactive terminal.
    """
    if force_no or os.getenv("NO_COLOR"):
        return False
    if os.getenv("TERM", "").lower() in ("", "dumb"):
        return False
    try:
        return sys.stdout.isatty()
    except Exception:
        # A replaced/closed stdout may not expose isatty(); play safe.
        return False
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
    """Render the first word of each message as a bracketed tag.

    'OK foo -> bar' becomes '[OK] foo -> bar'; when color is enabled the
    tag is wrapped in an ANSI color chosen by tag name (TAG_COLORS) or,
    failing that, by log level (LEVEL_COLORS).
    """

    LEVEL_COLORS = {
        logging.DEBUG: ANSI["gray"],
        logging.INFO: ANSI["green"],
        logging.WARNING: ANSI["yellow"],
        logging.ERROR: ANSI["red"],
        logging.CRITICAL: ANSI["red"] + ANSI["bold"],
    }

    TAG_COLORS = {
        "SKIP": ANSI["yellow"],
        "SCHEMA": ANSI["blue"],
        "OK": ANSI["magenta"],
    }

    def __init__(self, use_color: bool = True):
        super().__init__(fmt="%(message)s")
        self.use_color = use_color

    def format(self, record: logging.LogRecord) -> str:
        tag, _, rest = super().format(record).partition(" ")
        label = f"[{tag}]"
        suffix = f" {rest}" if rest else ""

        if not self.use_color:
            return label + suffix

        # Tag name takes precedence over level for color selection.
        color = self.TAG_COLORS.get(
            tag.upper(), self.LEVEL_COLORS.get(record.levelno, "")
        )
        return f"{color}{label}{ANSI['reset']}{suffix}"
|
||||
|
||||
|
||||
def setup_logger(verbose: bool, no_color: bool) -> logging.Logger:
    """Configure and return the shared 'db_migrate' logger.

    Installs a single stream handler with ColoredFormatter (replacing any
    previous handlers) and quiets the psycopg driver logger.
    """
    use_color = supports_color(no_color)
    level = logging.DEBUG if verbose else logging.INFO

    logger = logging.getLogger("db_migrate")
    logger.setLevel(level)

    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(ColoredFormatter(use_color=use_color))
    logger.handlers.clear()
    logger.addHandler(handler)

    # Keep driver internals out of normal output.
    logging.getLogger("psycopg").setLevel(logging.WARNING)

    if verbose:
        logger.debug(f"color_support: {use_color}")
        try:
            isatty = sys.stdout.isatty()
        except Exception:
            isatty = False
        logger.debug(
            f"TERM={os.getenv('TERM', '')!r} "
            f"NO_COLOR={os.getenv('NO_COLOR')!r} "
            f"isatty={isatty}"
        )
    return logger
|
||||
|
||||
|
||||
def quote_sqlite_identifier(name: str) -> str:
    """Wrap *name* in double quotes, doubling embedded double quotes."""
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
|
||||
|
||||
|
||||
def quote_pg_identifier(name: str) -> str:
    """Simple PG identifier quoting for building safe SQL strings."""
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
|
||||
|
||||
|
||||
def sqlite_decl_satisfies(pg_type: str, sqlite_decl: str) -> bool:
    """Return True when a PostgreSQL column type is acceptable for a
    SQLite declared column type.

    Args:
        pg_type: lowercased-ish PG type string (e.g. 'bytea', 'text[]').
        sqlite_decl: declared type from PRAGMA table_info (may be empty).

    Returns:
        True when the pair is considered compatible.
    """
    # An empty/blank SQLite declaration means "no type preference"
    # (BLOB affinity): accept any PG type. The original code walked a
    # chain of PG-type checks here, but every branch returned True, so
    # the whole chain was dead code — collapsed to a single return.
    if not (sqlite_decl or "").strip():
        return True

    decl = sqlite_decl.upper()
    pg = (pg_type or "").lower()

    # PG array types are expected to land in a textual SQLite column.
    if pg.endswith("[]"):
        return any(tok in decl for tok in ("TEXT", "CHAR", "CLOB", "JSON"))

    for key, allowed_types in TYPE_COMPATIBILITY.items():
        if re.search(r"\b" + re.escape(key) + r"\b", pg):
            return any(tok in decl for tok in allowed_types)

    # NOTE(review): information_schema spells types as 'integer' /
    # 'character varying', which do NOT match the short keys above
    # ('int', 'varchar') on word boundaries, so most such columns are
    # checked by this permissive fallback instead — confirm intended.
    return any(
        tok in decl for tok in ("TEXT", "CHAR", "CLOB", "NUMERIC", "BLOB", "INT")
    )
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Postgres helpers
|
||||
# ----------------------
|
||||
|
||||
|
||||
def list_user_schemas(pg_cursor) -> List[str]:
    """Return user schema names, sorted.

    System schemas (pg_*, information_schema) and 'public' are excluded.
    """
    query = """
        SELECT nspname
        FROM pg_namespace
        WHERE nspname NOT LIKE 'pg_%'
          AND nspname != 'information_schema'
          AND nspname != 'public'
        ORDER BY nspname;
    """
    pg_cursor.execute(query)
    return [row[0] for row in pg_cursor.fetchall()]
|
||||
|
||||
|
||||
def list_tables_in_schema(pg_cursor, schema: str) -> List[str]:
    """Return relation names in *schema*, sorted by name.

    NOTE(review): table_type is not filtered, so views are included too —
    confirm that is intended for migration sources.

    The statement has no dynamic identifiers, so a plain parameterized
    string is used (the previous psycopg.sql.SQL wrapper added nothing and
    was inconsistent with the other catalog helpers in this module).
    """
    pg_cursor.execute(
        """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s
        ORDER BY table_name;
        """,
        (schema,),
    )
    return [r[0] for r in pg_cursor.fetchall()]
|
||||
|
||||
|
||||
def get_pg_columns(pg_cursor, schema: str, table: str) -> List[Tuple[str, str]]:
    """Return [(column_name, data_type), ...] for a PG table, in ordinal
    (definition) order.

    Uses a plain parameterized string — the previous psycopg.sql.SQL
    wrapper was unnecessary (no dynamic identifiers) and inconsistent
    with list_user_schemas.
    """
    pg_cursor.execute(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position;
        """,
        (schema, table),
    )
    return [(r[0], r[1]) for r in pg_cursor.fetchall()]
|
||||
|
||||
|
||||
# ----------------------
|
||||
# SQLite helpers
|
||||
# ----------------------
|
||||
|
||||
|
||||
class SQLiteCol(NamedTuple):
    """One row of `PRAGMA table_info(...)`, in PRAGMA column order."""

    cid: int  # column index within the table
    name: str  # column name
    type: str  # declared type; may be an empty string in SQLite
    notnull: int  # 1 when a NOT NULL constraint is present
    dflt_value: str  # default value expression (None when absent)
    pk: int  # >0 when the column is part of the primary key
|
||||
|
||||
|
||||
def get_sqlite_table_info(sqlite_cursor, table_name: str) -> list[SQLiteCol]:
    """Read PRAGMA table_info for *table_name* as SQLiteCol records."""
    pragma = f"PRAGMA table_info({quote_sqlite_identifier(table_name)});"
    sqlite_cursor.execute(pragma)
    rows = sqlite_cursor.fetchall()
    return [SQLiteCol(*r) for r in rows]
|
||||
|
||||
|
||||
def sqlite_table_has_autoincrement(sqlite_cursor, table_name: str) -> bool:
    """True when *table_name*'s CREATE TABLE statement uses AUTOINCREMENT.

    Missing tables (or tables without stored SQL) report False.
    """
    sqlite_cursor.execute(
        "SELECT sql FROM sqlite_master WHERE type='table' AND name = ?;", (table_name,)
    )
    row = sqlite_cursor.fetchone()
    if row and row[0]:
        return "AUTOINCREMENT" in row[0].upper()
    return False
|
||||
|
||||
|
||||
def sqlite_sequence_table_exists(sqlite_cursor) -> bool:
    """True when SQLite's internal sqlite_sequence table exists.

    SQLite creates it lazily, once any AUTOINCREMENT table exists in the
    database; its absence is logged at debug level.
    """
    sqlite_cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='sqlite_sequence';"
    )
    present = sqlite_cursor.fetchone() is not None
    if not present:
        logging.getLogger("db_migrate").debug("sqlite_sequence table not present")
    return present
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Core migration logic
|
||||
# ----------------------
|
||||
|
||||
|
||||
def build_select_for_sqlite_columns(
    schema: str,
    table: str,
    sqlite_cols: list[SQLiteCol],
    pg_columns_map: Dict[str, Tuple[str, str]],
) -> str:
    """Return a plain SQL string (no trailing semicolon) selecting PG columns in the order of sqlite_cols.

    Columns with no PG counterpart select NULL; timestamp/time columns are
    cast to text so SQLite receives ISO-ish strings. Uses simple identifier
    quoting rather than psycopg.sql objects.
    """
    exprs = []
    for col in sqlite_cols:
        entry = pg_columns_map.get(col.name.lower())
        if entry is None:
            # SQLite-only column: fill with NULL to keep positional order.
            exprs.append("NULL")
            continue
        pg_name, pg_type = entry
        quoted = quote_pg_identifier(pg_name)
        if any(tok in (pg_type or "").lower() for tok in ("timestamp", "time")):
            exprs.append(f"{quoted}::text")
        else:
            exprs.append(quoted)

    select_list = ", ".join(exprs)
    qualified = f"{quote_pg_identifier(schema)}.{quote_pg_identifier(table)}"
    return f"SELECT {select_list} FROM {qualified}"
|
||||
|
||||
|
||||
def validate_column_compatibility(
    sqlite_cols: list[SQLiteCol],
    pg_columns_map: Dict[str, Tuple[str, str]],
    schema: str,
    table: str,
) -> None:
    """Raise ValueError on the first SQLite column whose declared type is
    incompatible with its PG counterpart; SQLite-only columns are skipped."""
    for col in sqlite_cols:
        entry = pg_columns_map.get(col.name.lower())
        if entry is None:
            continue
        pg_type = entry[1]
        if sqlite_decl_satisfies(pg_type, col.type):
            continue
        raise ValueError(
            f"Type mismatch for {schema}.{table}.{col.name}: "
            f"PG='{pg_type}' vs SQLite='{col.type}'"
        )
|
||||
|
||||
|
||||
def process_row_for_sqlite(row: Tuple, sqlite_cols: list[SQLiteCol]) -> Tuple:
    """Coerce one PG result row into values storable in SQLite.

    Binary values (memoryview / bytes / bytearray) are kept as BLOBs when
    the target column is declared BLOB or has no declared type (BLOB
    affinity); otherwise they must decode as UTF-8 and are stored as text.
    All other values pass through unchanged.

    Raises:
        ValueError: when a non-BLOB column receives non-UTF-8 bytes.
    """
    out = []
    for i, val in enumerate(row):
        col = sqlite_cols[i]
        # The memoryview and bytes/bytearray branches were duplicated
        # verbatim in the original; bytes() normalizes all three.
        if isinstance(val, (memoryview, bytes, bytearray)):
            out.append(_coerce_binary(bytes(val), col))
        else:
            out.append(val)
    return tuple(out)


def _coerce_binary(b: bytes, col: SQLiteCol):
    """Return *b* as a BLOB or UTF-8 text, per the column declaration."""
    decl_raw = col.type or ""
    if "BLOB" in decl_raw.upper() or not decl_raw.strip():
        return sqlite3.Binary(b)
    try:
        return b.decode("utf-8")
    except UnicodeDecodeError as e:
        raise ValueError(
            f"UTF-8 decode failed for column '{col.name}': {e}"
        ) from e
|
||||
|
||||
|
||||
def fetch_and_validate_row_batches(
    pg_cursor,
    select_sql: str,
    sqlite_cols: list[SQLiteCol],
    batch_size: int = DEFAULT_BATCH_SIZE,
):
    """Execute *select_sql* and yield lists of SQLite-ready rows.

    Rows are pulled with fetchmany(batch_size) so the table is never fully
    materialized. A conversion ValueError is re-raised annotated with the
    absolute (1-based) row number.
    """
    pg_cursor.execute(select_sql)
    row_number = 0
    while True:
        chunk = pg_cursor.fetchmany(batch_size)
        if not chunk:
            break
        batch = []
        for row in chunk:
            row_number += 1
            try:
                batch.append(process_row_for_sqlite(row, sqlite_cols))
            except ValueError as e:
                raise ValueError(f"Row {row_number} validation error: {e}")
        yield batch
|
||||
|
||||
|
||||
def _get_postgres_sequence_info(
|
||||
pg_cursor, schema: str, table: str, pk_col: str
|
||||
) -> Optional[tuple[str, int]]:
|
||||
"""Get PostgreSQL sequence name and last_value for a table primary key, if any.
|
||||
|
||||
Returns (sequence_name_text, last_value) or None.
|
||||
"""
|
||||
pg_cursor.execute(
|
||||
"""
|
||||
SELECT data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s AND column_name = %s;
|
||||
""",
|
||||
(schema, table, pk_col),
|
||||
)
|
||||
r = pg_cursor.fetchone()
|
||||
if not r:
|
||||
return None
|
||||
|
||||
pg_type = (r[0] or "").lower()
|
||||
# Match whole words to avoid matching 'interval' etc.
|
||||
if not re.search(r"\b(?:int|serial|bigint)\b", pg_type):
|
||||
return None
|
||||
|
||||
# Get sequence name text (NULL if none)
|
||||
pg_cursor.execute(
|
||||
"SELECT pg_get_serial_sequence(%s, %s);",
|
||||
(f"{schema}.{table}", pk_col),
|
||||
)
|
||||
row = pg_cursor.fetchone()
|
||||
seq_name = row[0] if row else None
|
||||
if not seq_name:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Read last_value by casting the pg_get_serial_sequence result to regclass.
|
||||
pg_cursor.execute(
|
||||
"SELECT last_value FROM pg_get_serial_sequence(%s, %s)::regclass;",
|
||||
(f"{schema}.{table}", pk_col),
|
||||
)
|
||||
rr = pg_cursor.fetchone()
|
||||
if rr and rr[0] is not None:
|
||||
return (seq_name, int(rr[0]))
|
||||
except Exception:
|
||||
# Swallow errors (best-effort); callers treat missing info as absent
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_sqlite_max_pk_value(sqlite_cursor, table: str, pk_col: str) -> Optional[int]:
    """Best-effort MAX(pk) from the SQLite table; None on any failure."""
    logger = logging.getLogger("db_migrate")
    query = (
        f"SELECT MAX({quote_sqlite_identifier(pk_col)}) "
        f"FROM {quote_sqlite_identifier(table)};"
    )
    try:
        sqlite_cursor.execute(query)
        result = sqlite_cursor.fetchone()
        if result and result[0] is not None:
            return int(result[0])
    except Exception:
        logger.debug("Failed to read sqlite max pk", exc_info=True)
    return None
|
||||
|
||||
|
||||
def _update_sqlite_sequence(
|
||||
sqlite_conn: sqlite3.Connection, table: str, sequence_value: int
|
||||
) -> bool:
|
||||
"""Update SQLite sqlite_sequence with new value.
|
||||
|
||||
IMPORTANT: This function does not commit; caller must commit/rollback.
|
||||
"""
|
||||
s_cur = sqlite_conn.cursor()
|
||||
s_cur.execute(
|
||||
"UPDATE sqlite_sequence SET seq = ? WHERE name = ?;",
|
||||
(sequence_value, table),
|
||||
)
|
||||
if s_cur.rowcount == 0:
|
||||
s_cur.execute(
|
||||
"INSERT INTO sqlite_sequence(name, seq) VALUES (?, ?);",
|
||||
(table, sequence_value),
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def try_update_sqlite_sequence(
    pg_cursor,
    sqlite_conn: sqlite3.Connection,
    schema: str,
    pg_table: str,  # PostgreSQL table name
    sqlite_table: str,  # SQLite table name
    sqlite_cols: list[SQLiteCol],
) -> bool:
    """Update SQLite sequence table based on PostgreSQL sequence values.

    Only applies to tables with a single-column primary key whose SQLite
    definition uses AUTOINCREMENT. The stored value is the max of the PG
    sequence last_value and the current SQLite MAX(pk), so generated ids
    never go backwards. Returns True when sqlite_sequence was updated.
    Best-effort: commit failures are logged and ignored.
    """
    logger = logging.getLogger("db_migrate")
    # Check if table has a single primary key
    pks = [c for c in sqlite_cols if c.pk]
    if len(pks) != 1:
        return False
    pk_col = pks[0].name

    # Check if SQLite has AUTOINCREMENT (use sqlite_table)
    s_cur = sqlite_conn.cursor()
    if not sqlite_table_has_autoincrement(s_cur, sqlite_table):
        return False
    if not sqlite_sequence_table_exists(s_cur):
        return False

    # Get PostgreSQL sequence info (use pg_table)
    seq_info = _get_postgres_sequence_info(pg_cursor, schema, pg_table, pk_col)
    pg_last_value = seq_info[1] if seq_info else None

    # Get SQLite max PK value (use sqlite_table)
    sqlite_max = _get_sqlite_max_pk_value(s_cur, sqlite_table, pk_col)

    # Determine the value to use: prefer the larger of the two sources;
    # bail out when neither side has a usable value.
    if pg_last_value is not None and sqlite_max is not None:
        candidate = max(pg_last_value, sqlite_max)
    elif pg_last_value is not None:
        candidate = pg_last_value
    elif sqlite_max is not None:
        candidate = sqlite_max
    else:
        return False

    updated = _update_sqlite_sequence(sqlite_conn, sqlite_table, candidate)
    if updated:
        try:
            sqlite_conn.commit()
        except Exception:
            logger.debug("Failed to commit sqlite_sequence update", exc_info=True)
            # Let caller proceed; treat as best-effort
    return updated
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Flow: per-schema migration
|
||||
# ----------------------
|
||||
|
||||
|
||||
def migrate_schema(
    pg_conn,
    sqlite_dir: str,
    schema: str,
    skipped_tables: Set[str],
    logger: logging.Logger,
    dry_run: bool = False,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> Tuple[int, int]:
    """Migrate a single schema. Returns (tables_migrated, rows_inserted_total).

    The target SQLite file is <sqlite_dir>/<schema without trailing
    '_schema'>.db and must already exist with its tables created; this
    function only copies data (DELETE + INSERT per table, one SQLite
    transaction per table). In dry_run mode the SQLite DB is opened
    read-only and rows are only validated, never written. Per-table
    errors are logged and skipped; the loop keeps going.
    """
    # Convention: PG schema 'foo_schema' maps to SQLite file 'foo.db'.
    processed_schema = schema[:-7] if schema.endswith("_schema") else schema
    sqlite_db_path = Path(sqlite_dir) / f"{processed_schema}.db"
    if not sqlite_db_path.is_file():
        logger.error(f"Missing SQLite DB: {sqlite_db_path}")
        return 0, 0

    tables_migrated = 0
    rows_inserted = 0

    sqlite_path = sqlite_db_path.resolve()
    if dry_run:
        # Read-only URI connection guarantees a dry run cannot write.
        uri = sqlite_path.as_uri() + "?mode=ro"
        conn_args = {"database": uri, "uri": True}
    else:
        conn_args = {"database": str(sqlite_path)}

    # NOTE(review): sqlite3's connection context manager commits/rolls back
    # but does NOT close the connection on exit — confirm acceptable here.
    with sqlite3.connect(**conn_args) as sqlite_conn:
        sqlite_cur = sqlite_conn.cursor()
        sqlite_cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
        sqlite_tables = [r[0] for r in sqlite_cur.fetchall()]
        # O(1) lookup map for matching by lowercase name
        sqlite_table_map = {t.lower(): t for t in sqlite_tables}

        with pg_conn.cursor() as pg_cur:
            # Ensure we start in a clean state for this connection
            try:
                pg_conn.rollback()
            except Exception:
                # Ignore rollback failure; connection is newly opened most likely
                logger.debug(
                    "pg_conn.rollback() at schema start ignored", exc_info=True
                )

            # Get table list for this schema, but catch/rollback on failure
            try:
                tables = list_tables_in_schema(pg_cur, schema)
            except Exception as e:
                logger.error("ERROR %s (list tables): %s", schema, e)
                logger.debug("Traceback (list tables):", exc_info=True)
                try:
                    pg_conn.rollback()
                except Exception:
                    logger.debug(
                        "pg_conn.rollback() failed after list tables error",
                        exc_info=True,
                    )
                return 0, 0

            for table in tables:
                # Defensive: ensure connection is in a clean state before any new PG work.
                # A prior error can leave the connection in an aborted transaction; calling
                # rollback() clears that and allows subsequent SELECTs to run.
                try:
                    pg_conn.rollback()
                except Exception:
                    # ignore: best-effort cleanup
                    logger.debug(
                        "pg_conn.rollback() ignored at start of table loop",
                        exc_info=True,
                    )

                if table.lower() in skipped_tables:
                    logger.info("SKIP %s.%s (explicit)", schema, table)
                    continue
                if table.lower() not in sqlite_table_map:
                    logger.info("SKIP %s.%s (no target table)", schema, table)
                    continue

                matched_table = sqlite_table_map[table.lower()]
                logger.info("OK %s.%s -> %s", schema, table, matched_table)

                # Fetch PG columns, but defend against aborted transaction here
                try:
                    pg_cols = get_pg_columns(pg_cur, schema, table)
                except Exception as e:
                    logger.error("ERROR %s.%s (get columns): %s", schema, table, e)
                    logger.debug("Traceback (get columns):", exc_info=True)
                    try:
                        pg_conn.rollback()
                    except Exception:
                        logger.debug(
                            "pg_conn.rollback() failed after get columns error",
                            exc_info=True,
                        )
                    continue

                # lowercase name -> (original PG name, PG data type)
                pg_map: Dict[str, Tuple[str, str]] = {
                    name.lower(): (name, dtype) for name, dtype in pg_cols
                }

                sqlite_info = get_sqlite_table_info(sqlite_cur, matched_table)
                if not sqlite_info:
                    logger.warning(f"SKIP {schema}.{table} (no sqlite info)")
                    continue

                # Warn once about schema drift (extra columns)
                pg_col_names = {name.lower() for name, _ in pg_cols}
                extra_in_sqlite = {
                    c.name for c in sqlite_info if c.name.lower() not in pg_col_names
                }
                if extra_in_sqlite:
                    logger.warning(
                        f"Schema drift: {schema}.{table} has extra SQLite columns {extra_in_sqlite}"
                    )

                # Abort this table early on an incompatible type pairing.
                try:
                    validate_column_compatibility(sqlite_info, pg_map, schema, table)
                except ValueError as e:
                    logger.error(f"ERROR {schema}.{table} (type mismatch): {e}")
                    continue

                select_sql = build_select_for_sqlite_columns(
                    schema, table, sqlite_info, pg_map
                )

                # Named (server-side) cursor so large tables stream in batches.
                csr_name = _sanitize_cursor_name(f"csr_{schema}_{table}")
                try:
                    with pg_conn.cursor(name=csr_name) as data_cur:
                        batch_gen = fetch_and_validate_row_batches(
                            data_cur, select_sql, sqlite_info, batch_size=batch_size
                        )
                        first_batch = next(batch_gen, None)

                        if not first_batch:
                            logger.info(f"SKIP {schema}.{table} (no rows)")
                            # data_cur will be closed automatically on leaving the 'with'
                            continue

                        if dry_run:
                            # Drain the generator so every row gets validated.
                            count = len(first_batch)
                            for batch in batch_gen:
                                count += len(batch)
                            inserted = 0
                            logger.info(
                                f"DRY {schema}.{table} rows_validated={count}"
                            )
                        else:
                            cur = sqlite_conn.cursor()
                            quoted_table = quote_sqlite_identifier(matched_table)
                            col_names = [c.name for c in sqlite_info]
                            quoted_cols = ", ".join(
                                quote_sqlite_identifier(c) for c in col_names
                            )
                            placeholders = ", ".join(["?"] * len(col_names))

                            # One explicit SQLite transaction per table:
                            # full replace (DELETE then batch INSERTs).
                            cur.execute("BEGIN;")
                            try:
                                cur.execute(f"DELETE FROM {quoted_table};")
                                inserted = 0
                                if first_batch:
                                    cur.executemany(
                                        f"INSERT INTO {quoted_table} ({quoted_cols}) VALUES ({placeholders});",
                                        first_batch,
                                    )
                                    inserted += len(first_batch)
                                for batch in batch_gen:
                                    if not batch:
                                        continue
                                    cur.executemany(
                                        f"INSERT INTO {quoted_table} ({quoted_cols}) VALUES ({placeholders});",
                                        batch,
                                    )
                                    inserted += len(batch)
                            except Exception:
                                cur.execute("ROLLBACK;")
                                raise
                            else:
                                cur.execute("COMMIT;")
                except ValueError as e:
                    # Raised by row validation inside the batch generator.
                    logger.error("ERROR %s.%s (row validation): %s", schema, table, e)
                    try:
                        pg_conn.rollback()
                    except Exception:
                        logger.debug(
                            "pg_conn.rollback() failed after row validation error",
                            exc_info=True,
                        )
                    continue
                except Exception as e:
                    logger.error("ERROR %s.%s (select failed): %s", schema, table, e)
                    logger.debug("Traceback (select failed):", exc_info=True)
                    try:
                        pg_conn.rollback()
                    except Exception:
                        logger.debug(
                            "pg_conn.rollback() failed after select failed",
                            exc_info=True,
                        )
                    continue

                rows_inserted += inserted
                tables_migrated += 1
                logger.info(f"DONE {schema}.{table} rows={inserted}")

                # Keep SQLite AUTOINCREMENT counters in sync (best-effort).
                if not dry_run:
                    try:
                        if try_update_sqlite_sequence(
                            pg_cur,
                            sqlite_conn,
                            schema,
                            table,
                            matched_table,
                            sqlite_info,
                        ):
                            logger.info(
                                "SEQ %s.%s sqlite_sequence updated", schema, table
                            )
                    except Exception as e:
                        logger.warning(
                            "SEQ %s.%s update failed (ignored): %s", schema, table, e
                        )

    return tables_migrated, rows_inserted
|
||||
|
||||
|
||||
def migrate_data(
    pg_conn_str: str,
    sqlite_dir: str,
    schema_filter: Optional[str],
    dry_run: bool,
    skip_tables: str,
    logger: logging.Logger,
    batch_size: int,
) -> None:
    """Top-level driver: migrate every user schema into its SQLite DB.

    skip_tables is a comma-separated, case-insensitive list of table
    names excluded in every schema. A failed schema is counted and the
    loop continues; a summary line is logged at the end.
    """
    skipped_tables = {t.strip().lower() for t in skip_tables.split(",") if t.strip()}
    total_tables = 0
    total_rows = 0
    total_errors = 0

    # List schemas once with a short-lived connection
    with psycopg.connect(pg_conn_str) as tmp_conn:
        with tmp_conn.cursor() as cur:
            schemas = list_user_schemas(cur)

    if schema_filter:
        schemas = [s for s in schemas if s == schema_filter]
    if not schemas:
        logger.error("No schemas to process")
        return

    for schema in schemas:
        logger.info("SCHEMA %s", schema)
        if dry_run:
            logger.info("(dry-run) validating only — no writes will be performed")

        # Use a fresh connection per-schema to isolate failures/aborted transactions
        try:
            with psycopg.connect(pg_conn_str) as pg_conn:
                migrated_tables, inserted_rows = migrate_schema(
                    pg_conn,
                    sqlite_dir,
                    schema,
                    skipped_tables,
                    logger,
                    dry_run=dry_run,
                    batch_size=batch_size,
                )
        except Exception as e:
            logger.error("Schema %s failed: %s", schema, e)
            logger.debug("Traceback (schema failure):", exc_info=True)
            total_errors += 1
            continue

        total_tables += migrated_tables
        total_rows += inserted_rows

    logger.info("---")
    logger.info(
        f"SUMMARY: schemas={len(schemas)} "
        f"tables_migrated={total_tables} "
        f"rows_inserted={total_rows} "
        f"errors={total_errors}"
    )
|
||||
|
||||
|
||||
# ----------------------
|
||||
# CLI
|
||||
# ----------------------
|
||||
|
||||
|
||||
def parse_args(argv):
    """Parse the CLI arguments; argv[0] (the program name) is dropped."""
    parser = argparse.ArgumentParser(
        description="Migrate Postgres data into per-schema SQLite DBs (non-invasive)."
    )
    parser.add_argument("pg_conn", help="Postgres connection string")
    parser.add_argument(
        "sqlite_dir", help="Directory containing per-schema sqlite .db files"
    )
    parser.add_argument(
        "--schema", help="Only migrate this schema (exact match)", default=None
    )
    parser.add_argument(
        "--dry-run",
        help="Validate and report only; do not write to sqlite",
        action="store_true",
    )
    parser.add_argument(
        "--verbose", help="Verbose logging (debug)", action="store_true"
    )
    parser.add_argument(
        "--no-color", help="Disable ANSI color output", action="store_true"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help="Batch size for data fetching",
    )
    parser.add_argument(
        "--skip-tables",
        help="Comma-separated list of tables to skip",
        default="migrations,servers_stats",
    )
    return parser.parse_args(argv[1:])
|
||||
|
||||
|
||||
def main(argv):
    """CLI entry point: validate arguments, then run the migration."""
    args = parse_args(argv)
    logger = setup_logger(args.verbose, args.no_color)

    sqlite_dir = Path(args.sqlite_dir)
    if not sqlite_dir.is_dir():
        logger.error("SQLite directory does not exist or is not a directory.")
        raise SystemExit(1)

    try:
        migrate_data(
            args.pg_conn,
            str(sqlite_dir),
            args.schema,
            args.dry_run,
            args.skip_tables,
            logger,
            args.batch_size,
        )
    except Exception as e:
        # Log and exit non-zero rather than dumping a traceback at the user.
        logger.error(f"Fatal error: {e}")
        raise SystemExit(2)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main(sys.argv)
|
||||
Reference in New Issue
Block a user