From a6c91ca90fad27071cb882f1a41d52d6a363402e Mon Sep 17 00:00:00 2001 From: Ankit Wasankar Date: Sun, 12 Apr 2026 10:37:34 +0530 Subject: [PATCH] perf: eliminate N+1 SQL queries in _compute_summaries Replace per-community and per-node query loops with precomputed aggregate queries + executemany batch inserts. community_summaries: 2 queries per community -> 3 total - single GROUP BY query for top symbols across all communities - single flat SELECT for all file paths across all communities - executemany batch insert replaces per-community INSERT risk_index: 3 queries per node -> 3 total - single GROUP BY query for all caller counts - single SELECT for all tested qualified names - executemany batch insert replaces per-node INSERT On a repo with ~17k nodes and ~500 communities this reduces ~54,000 SQL round-trips to ~6, eliminating the build hang. --- code_review_graph/tools/build.py | 136 ++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 49 deletions(-) diff --git a/code_review_graph/tools/build.py b/code_review_graph/tools/build.py index e92f8e3c..37de95b8 100644 --- a/code_review_graph/tools/build.py +++ b/code_review_graph/tools/build.py @@ -138,43 +138,67 @@ def _compute_summaries(store: Any) -> None: # -- community_summaries -- try: + from os.path import commonprefix as _commonprefix conn.execute("BEGIN IMMEDIATE") conn.execute("DELETE FROM community_summaries") - rows = conn.execute( + communities = conn.execute( "SELECT id, name, size, dominant_language FROM communities" ).fetchall() - for r in rows: + logger.info("Computing community summaries for %d communities...", len(communities)) + + # Precompute top-5 symbols per community in a single aggregate query + # (replaces 1 LEFT JOIN query per community). + top_sym_rows = conn.execute( + "SELECT n.community_id, n.name, " + "COUNT(e1.id) + COUNT(e2.id) AS edge_count " + "FROM nodes n " + "LEFT JOIN edges e1 ON e1.source_qualified = n.qualified_name " + "LEFT JOIN edges e2 ON e2.target_qualified = n.qualified_name " + "WHERE n.community_id IS NOT NULL AND n.kind != 'File' " + "GROUP BY n.community_id, n.id " + "ORDER BY n.community_id, edge_count DESC" + ).fetchall() + top_syms_by_comm: dict[int, list[str]] = {} + for sym_row in top_sym_rows: + comm_id = sym_row[0] + if comm_id not in top_syms_by_comm: + top_syms_by_comm[comm_id] = [] + if len(top_syms_by_comm[comm_id]) < 5: + top_syms_by_comm[comm_id].append(sym_row[1]) + + # Precompute file paths per community in a single query + # (replaces 1 SELECT DISTINCT query per community). + file_path_rows = conn.execute( + "SELECT community_id, file_path FROM nodes WHERE community_id IS NOT NULL" + ).fetchall() + paths_by_comm: dict[int, list[str]] = {} + for fp_row in file_path_rows: + comm_id = fp_row[0] + if comm_id not in paths_by_comm: + paths_by_comm[comm_id] = [] + paths_by_comm[comm_id].append(fp_row[1]) + + # Build all rows in Python, then batch-insert in one statement. + rows_to_insert = [] + for r in communities: cid, cname, csize, clang = r[0], r[1], r[2], r[3] - # Top 5 symbols by in+out edge count - top_symbols = conn.execute( - "SELECT n.name FROM nodes n " - "LEFT JOIN edges e1 ON e1.source_qualified = n.qualified_name " - "LEFT JOIN edges e2 ON e2.target_qualified = n.qualified_name " - "WHERE n.community_id = ? AND n.kind != 'File' " - "GROUP BY n.id ORDER BY COUNT(e1.id) + COUNT(e2.id) DESC " - "LIMIT 5", - (cid,), - ).fetchall() - key_syms = _json.dumps([s[0] for s in top_symbols]) - # Auto-generate purpose from common file path prefix - file_rows = conn.execute( - "SELECT DISTINCT file_path FROM nodes WHERE community_id = ? LIMIT 20", - (cid,), - ).fetchall() - paths = [fr[0] for fr in file_rows] + key_syms = _json.dumps(top_syms_by_comm.get(cid, [])) + paths = paths_by_comm.get(cid, []) purpose = "" if paths: - from os.path import commonprefix - prefix = commonprefix(paths) + prefix = _commonprefix(paths) if "/" in prefix: - purpose = prefix.rsplit("/", 1)[0].split("/")[-1] if "/" in prefix else "" - conn.execute( - "INSERT OR REPLACE INTO community_summaries " - "(community_id, name, purpose, key_symbols, size, dominant_language) " - "VALUES (?, ?, ?, ?, ?, ?)", - (cid, cname, purpose, key_syms, csize, clang or ""), - ) + purpose = prefix.rsplit("/", 1)[0].split("/")[-1] + rows_to_insert.append((cid, cname, purpose, key_syms, csize, clang or "")) + + conn.executemany( + "INSERT OR REPLACE INTO community_summaries " + "(community_id, name, purpose, key_symbols, size, dominant_language) " + "VALUES (?, ?, ?, ?, ?, ?)", + rows_to_insert, + ) conn.commit() + logger.info("Community summaries: %d rows written.", len(rows_to_insert)) except sqlite3.OperationalError: conn.rollback() # Table may not exist yet @@ -233,7 +257,6 @@ def _compute_summaries(store: Any) -> None: try: conn.execute("BEGIN IMMEDIATE") conn.execute("DELETE FROM risk_index") - # Per-node risk: caller_count, test coverage, security keywords nodes = conn.execute( "SELECT id, qualified_name, name FROM nodes " "WHERE kind IN ('Function', 'Class', 'Test')" @@ -242,23 +265,35 @@ def _compute_summaries(store: Any) -> None: "auth", "login", "password", "token", "session", "crypt", "secret", "credential", "permission", "sql", "execute", } + logger.info("Computing risk index for %d nodes...", len(nodes)) + + # Precompute caller counts for all nodes in one GROUP BY query + # (replaces 1 COUNT query per node). + caller_counts: dict[str, int] = { + row[0]: row[1] + for row in conn.execute( + "SELECT target_qualified, COUNT(*) FROM edges " + "WHERE kind = 'CALLS' GROUP BY target_qualified" + ).fetchall() + } + + # Precompute all tested qualified names in one query + # (replaces 1 COUNT query per node). + tested_qns: set[str] = { + row[0] + for row in conn.execute( + "SELECT source_qualified FROM edges WHERE kind = 'TESTED_BY'" + ).fetchall() + } + + # Compute risk scores in Python, then batch-insert in one statement. + risk_rows = [] for n in nodes: nid, qn, name = n[0], n[1], n[2] - # Count callers - caller_count = conn.execute( - "SELECT COUNT(*) FROM edges WHERE target_qualified = ? " - "AND kind = 'CALLS'", (qn,), - ).fetchone()[0] - # Test coverage - tested = conn.execute( - "SELECT COUNT(*) FROM edges WHERE source_qualified = ? " - "AND kind = 'TESTED_BY'", (qn,), - ).fetchone()[0] - coverage = "tested" if tested > 0 else "untested" - # Security relevance + caller_count = caller_counts.get(qn, 0) + coverage = "tested" if qn in tested_qns else "untested" name_lower = name.lower() sec_relevant = 1 if any(kw in name_lower for kw in security_kw) else 0 - # Compute risk score risk = 0.0 if caller_count > 10: risk += 0.3 @@ -269,14 +304,17 @@ def _compute_summaries(store: Any) -> None: if sec_relevant: risk += 0.4 risk = min(risk, 1.0) - conn.execute( - "INSERT OR REPLACE INTO risk_index " - "(node_id, qualified_name, risk_score, caller_count, " - "test_coverage, security_relevant, last_computed) " - "VALUES (?, ?, ?, ?, ?, ?, datetime('now'))", - (nid, qn, risk, caller_count, coverage, sec_relevant), - ) + risk_rows.append((nid, qn, risk, caller_count, coverage, sec_relevant)) + + conn.executemany( + "INSERT OR REPLACE INTO risk_index " + "(node_id, qualified_name, risk_score, caller_count, " + "test_coverage, security_relevant, last_computed) " + "VALUES (?, ?, ?, ?, ?, ?, datetime('now'))", + risk_rows, + ) conn.commit() + logger.info("Risk index: %d rows written.", len(risk_rows)) except sqlite3.OperationalError: conn.rollback()