diff --git a/.gitignore b/.gitignore index 8b3614cb..6b58f64c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ test/test_data/predictions/af* build/ *.egg-info/ tmp/ +test_logs/ diff --git a/.gitmodules b/.gitmodules index 954cdc1e..08c43b92 100644 --- a/.gitmodules +++ b/.gitmodules @@ -16,5 +16,5 @@ branch = main [submodule "alphafold3"] path = alphafold3 - url = https://github.com/google-deepmind/alphafold3.git + url = https://github.com/KosinskiLab/alphafold3.git branch = main diff --git a/README.md b/README.md index 8109865c..77bdd2b0 100644 --- a/README.md +++ b/README.md @@ -66,9 +66,10 @@ The same copy/range syntax also works with AlphaFold 3 JSON features. Examples: When a workflow or wrapper maps a logical token such as `Q8I2G6:1-100:150-200` to `Q8I2G6_af3_input.json:1-100:150-200`, AlphaPulldown preserves the region -selection and expands the AF3 JSON feature input into separate cropped chain(s). -For the AlphaFold 3 backend, discontinuous regions are modeled as separate -chains, so they are not connected by a peptide bond. +selection and keeps the AF3 JSON feature input as one discontinuous polymer +chain with preserved residue-number gaps. For the AlphaFold 3 backend this +means chopped regions stay intra-chain, so template contacts between retained +fragments are not masked as inter-chain interactions. For workflow deployments, make sure the execution environment also carries `alphapulldown-input-parser>=0.4.0`. diff --git a/alphafold3 b/alphafold3 index 2e3703e8..6ad1a659 160000 --- a/alphafold3 +++ b/alphafold3 @@ -1 +1 @@ -Subproject commit 2e3703e82a9592efbb3fa76ca9e0714aedabacdb +Subproject commit 6ad1a65994c2111d291a386cdc048d8c9bfae4af diff --git a/alphapulldown/folding_backend/alphafold3_backend.py b/alphapulldown/folding_backend/alphafold3_backend.py index 0c518883..ffe6c84d 100644 --- a/alphapulldown/folding_backend/alphafold3_backend.py +++ b/alphapulldown/folding_backend/alphafold3_backend.py @@ -9,15 +9,17 @@ import csv import dataclasses import functools +import hashlib import inspect import json import logging import os import pathlib +import re import time import typing from collections.abc import Sequence -from typing import List, Dict, Union, overload +from typing import Any, List, Dict, Union, overload import alphafold3.cpp import haiku as hk @@ -32,6 +34,7 @@ from alphafold3.model import features, params, post_processing from alphafold3.model import model from alphafold3.model.components import utils +from alphafold3.structure import mmcif as af3_mmcif from jax import numpy as jnp from alphafold.common import residue_constants @@ -186,9 +189,14 @@ def write_outputs( for sample_idx, result in enumerate(results_for_seed.inference_results): sample_dir = os.path.join(output_dir, f'seed-{seed}_sample-{sample_idx}') os.makedirs(sample_dir, exist_ok=True) + result = _make_viewer_compatible_inference_result(result) post_processing.write_output( inference_result=result, output_dir=sample_dir ) + _augment_confidence_json_with_author_numbering( + os.path.join(sample_dir, 'confidences.json'), + result, + ) ranking_score = float(result.metadata['ranking_score']) ranking_scores.append((seed, sample_idx, ranking_score)) if max_ranking_score is None or ranking_score > max_ranking_score: @@ -218,12 +226,332 @@ def write_outputs( terms_of_use=output_terms, name=job_name, ) + _augment_confidence_json_with_author_numbering( + os.path.join(output_dir, f'{job_name}_confidences.json'), + max_ranking_result, + ) with open(os.path.join(output_dir, 'ranking_scores.csv'), 
'wt') as f: writer = csv.writer(f) writer.writerow(['seed', 'sample', 'ranking_score']) writer.writerows(ranking_scores) +def _duplicate_occurrence_to_insertion_code( + occurrence_index: int, + *, + strict: bool = True, +) -> str: + """Maps the Nth occurrence of a residue ID to a mmCIF insertion code.""" + if occurrence_index <= 1: + return '.' + offset = occurrence_index - 2 + if offset >= 26: + if strict: + raise ValueError( + 'More than 27 repeated residue occurrences in one chain are not ' + 'supported for mmCIF insertion-code output.' + ) + return '.' + return chr(ord('A') + offset) + + +def _normalise_output_name_fragment(raw_name: str) -> str: + """Normalises one output-name fragment while preserving readable IDs.""" + cleaned = re.sub(r"\s+", "_", raw_name.strip()) + cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "_", cleaned) + cleaned = cleaned.strip("_.-") + return cleaned or "ranked_0" + + +def _collapse_repeated_name_fragments( + fragments: Sequence[str], +) -> list[str]: + """Collapses consecutive identical fragments into a readable count suffix.""" + if not fragments: + return [] + + collapsed: list[str] = [] + current_fragment = fragments[0] + current_count = 1 + + for fragment in fragments[1:]: + if fragment == current_fragment: + current_count += 1 + continue + + collapsed.append( + current_fragment + if current_count == 1 + else f"{current_fragment}__x{current_count}" + ) + current_fragment = fragment + current_count = 1 + + collapsed.append( + current_fragment + if current_count == 1 + else f"{current_fragment}__x{current_count}" + ) + return collapsed + + +def _compact_output_job_name(job_name: str, *, max_chars: int = 200) -> str: + """Keeps job names readable while staying below common filename limits.""" + if len(job_name) <= max_chars: + return job_name + + digest = hashlib.sha1(job_name.encode("utf-8")).hexdigest()[:12] + suffix = f"__{digest}" + prefix = job_name[: max_chars - len(suffix)].rstrip("_.-") + if not prefix: + return f"job{suffix}" + return f"{prefix}{suffix}" + + +def _compact_existing_compound_name(raw_name: str) -> str: + """Compacts already-joined `_and_` names such as multimer descriptions.""" + parts = [ + _normalise_output_name_fragment(part) + for part in raw_name.split("_and_") + if part.strip("_.-") + ] + if not parts: + return "ranked_0" + return _compact_output_job_name( + "_and_".join(_collapse_repeated_name_fragments(parts)) + ) + + +def _regions_to_name_fragment(regions: Sequence[tuple[int, int]]) -> str: + """Returns a readable name fragment for a set of closed residue intervals.""" + return "_".join(f"{start}-{end}" for start, end in regions) + + +def _json_input_basename(json_path: str) -> str: + """Returns a readable basename for an AF3 JSON input path.""" + stem = pathlib.Path(json_path).stem + for suffix in ("_af3_input", "_input"): + if stem.endswith(suffix): + stem = stem[: -len(suffix)] + break + return stem or pathlib.Path(json_path).stem + + +def _object_name_fragment(obj: typing.Any) -> str: + """Builds a deterministic output-name fragment for one modelling object.""" + if isinstance(obj, dict) and "json_input" in obj: + fragment = _json_input_basename(obj["json_input"]) + regions = obj.get("regions") + if isinstance(regions, Sequence) and regions: + fragment = f"{fragment}__{_regions_to_name_fragment(regions)}" + return _normalise_output_name_fragment(fragment) + + if isinstance(obj, MultimericObject): + return _compact_existing_compound_name(obj.description or "multimer") + + if isinstance(obj, (MonomericObject, ChoppedObject)): + 
return _normalise_output_name_fragment(obj.description or "monomer") + + if isinstance(obj, folding_input.Input): + return _compact_existing_compound_name(obj.name) + + return _normalise_output_name_fragment(type(obj).__name__) + + +def _build_output_job_name(objects_to_model: Sequence[dict]) -> str: + """Builds a readable AF3 job name from the requested modelling objects.""" + fragments: list[str] = [] + for entry in objects_to_model: + object_to_model = entry["object"] + if isinstance(object_to_model, list): + fragments.extend(_object_name_fragment(obj) for obj in object_to_model) + else: + fragments.append(_object_name_fragment(object_to_model)) + fragments = [fragment for fragment in fragments if fragment] + if not fragments: + return "ranked_0" + readable_name = "_and_".join(_collapse_repeated_name_fragments(fragments)) + return _compact_output_job_name(readable_name) + + +def _residue_author_ids(struc) -> list[str]: + """Returns author-facing residue IDs, falling back to residue IDs if unset.""" + author_residue_ids = [str(residue_id) for residue_id in struc.residues_table.auth_seq_id] + if all(residue_id in {".", "?"} for residue_id in author_residue_ids): + return [str(int(residue_id)) for residue_id in struc.residues_table.id] + return author_residue_ids + + +def _existing_insertion_codes(struc) -> list[str]: + """Returns normalised residue insertion codes from a structure.""" + return [ + "." + if insertion_code in {".", "?", ""} + else str(insertion_code) + for insertion_code in struc.residues_table.insertion_code + ] + + +def _author_ids_with_insertion_codes( + chain_ids: Sequence[str], + author_residue_ids: Sequence[str], + existing_insertion_codes: Sequence[str] | None = None, + *, + strict: bool = True, +) -> tuple[list[str], list[str], list[str]]: + """Returns author IDs, insertion codes, and combined author labels.""" + occurrence_count_by_residue: dict[tuple[str, str], int] = {} + insertion_codes: list[str] = [] + combined_labels: list[str] = [] + + for index, (chain_id, residue_id) in enumerate( + zip(chain_ids, author_residue_ids, strict=True) + ): + explicit_insertion_code = "." + overflow_occurrence = 0 + if existing_insertion_codes is not None: + explicit_insertion_code = existing_insertion_codes[index] + + if explicit_insertion_code not in {".", "?", ""}: + insertion_code = explicit_insertion_code + else: + key = (chain_id, residue_id) + occurrence = occurrence_count_by_residue.get(key, 0) + 1 + occurrence_count_by_residue[key] = occurrence + insertion_code = _duplicate_occurrence_to_insertion_code( + occurrence, + strict=strict, + ) + if insertion_code == "." 
and occurrence > 1: + overflow_occurrence = occurrence + + insertion_codes.append(insertion_code) + if insertion_code == ".": + if overflow_occurrence and not strict: + combined_labels.append(f"{residue_id}[{overflow_occurrence}]") + else: + combined_labels.append(residue_id) + else: + combined_labels.append(f"{residue_id}{insertion_code}") + + return list(author_residue_ids), insertion_codes, combined_labels + + +def _coerce_json_scalar(value: str) -> int | str: + """Converts a stringified integer back to int where possible.""" + try: + return int(value) + except (TypeError, ValueError): + return value + + +def _augment_confidence_json_with_author_numbering( + confidences_path: os.PathLike[str] | str, + inference_result: model.InferenceResult, +) -> None: + """Adds preserved author numbering to the confidence sidecar JSON.""" + token_auth_res_ids = inference_result.metadata.get("token_auth_res_ids") + token_pdb_ins_codes = inference_result.metadata.get("token_pdb_ins_codes") + token_auth_res_labels = inference_result.metadata.get("token_auth_res_labels") + if ( + token_auth_res_ids is None + or token_pdb_ins_codes is None + or token_auth_res_labels is None + ): + return + + with open(confidences_path, "rt", encoding="utf-8") as handle: + confidence_data = json.load(handle) + + confidence_data["token_label_seq_ids"] = [ + int(token_id) for token_id in confidence_data.get("token_res_ids", []) + ] + confidence_data["token_auth_res_ids"] = [ + _coerce_json_scalar(str(token_id)) for token_id in token_auth_res_ids + ] + confidence_data["token_pdb_ins_codes"] = [str(code) for code in token_pdb_ins_codes] + confidence_data["token_auth_res_labels"] = [ + str(label) for label in token_auth_res_labels + ] + + with open(confidences_path, "wt", encoding="utf-8") as handle: + json.dump(confidence_data, handle, indent=1) + handle.write("\n") + + +def _make_viewer_compatible_inference_result( + inference_result: model.InferenceResult, +) -> model.InferenceResult: + """Creates a viewer-safe copy with sequential label IDs and preserved auth IDs.""" + struc = inference_result.predicted_structure + residue_chain_ids = [ + str(chain_id) + for chain_id in struc.chains_table.apply_array_to_column( + column_name='id', + arr=struc.residues_table.chain_key, + ) + ] + author_residue_ids = _residue_author_ids(struc) + existing_insertion_codes = _existing_insertion_codes(struc) + + sequential_label_ids = np.asarray( + _sequential_residue_ids_per_chain(residue_chain_ids), + dtype=np.int32, + ) + ( + author_residue_ids, + insertion_codes, + _, + ) = _author_ids_with_insertion_codes( + residue_chain_ids, + author_residue_ids, + existing_insertion_codes, + ) + + viewer_structure = struc.copy_and_update_residues( + res_id=sequential_label_ids, + res_auth_seq_id=np.asarray(author_residue_ids, dtype=object), + res_insertion_code=np.asarray(insertion_codes, dtype=object), + ) + + metadata = dict(inference_result.metadata) + token_chain_ids = [ + str(chain_id) + for chain_id in metadata.get("token_chain_ids", []) + ] + if token_chain_ids and "token_res_ids" in metadata: + token_author_ids = [str(token_id) for token_id in metadata["token_res_ids"]] + ( + token_author_ids, + token_insertion_codes, + token_author_labels, + ) = _author_ids_with_insertion_codes( + token_chain_ids, + token_author_ids, + strict=False, + ) + metadata["token_res_ids"] = _sequential_residue_ids_per_chain(token_chain_ids) + metadata["token_auth_res_ids"] = token_author_ids + metadata["token_pdb_ins_codes"] = token_insertion_codes + 
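+            # Descriptive note: the combined labels (e.g. "100" for a plain
+            # author ID, or "100A" once an insertion code is assigned) are the
+            # values that _augment_confidence_json_with_author_numbering later
+            # copies into the confidence sidecar JSON.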
metadata["token_auth_res_labels"] = token_author_labels + return dataclasses.replace( + inference_result, + predicted_structure=viewer_structure, + metadata=metadata, + ) + + +def _sequential_residue_ids_per_chain(chain_ids: Sequence[str]) -> list[int]: + """Returns sequential residue IDs that are unique within each chain.""" + next_residue_id_by_chain: dict[str, int] = {} + residue_ids = [] + for chain_id in chain_ids: + next_residue_id = next_residue_id_by_chain.get(chain_id, 0) + 1 + next_residue_id_by_chain[chain_id] = next_residue_id + residue_ids.append(next_residue_id) + return residue_ids + + def predict_structure( fold_input: folding_input.Input, model_runner: ModelRunner, @@ -510,6 +838,286 @@ def _construct_chain(chain_cls: type, **kwargs): } return chain_cls(**filtered_kwargs) + def _chain_input_sequence( + chain: ( + folding_input.ProteinChain + | folding_input.RnaChain + | folding_input.DnaChain + ), + ) -> str: + canonical_sequence = getattr(chain, "_sequence", None) + if isinstance(canonical_sequence, str): + return canonical_sequence + return chain.sequence + + def _clone_chain_with_id(chain, new_id: str): + description = getattr(chain, "description", None) + if isinstance(chain, folding_input.ProteinChain): + return _construct_chain( + folding_input.ProteinChain, + id=new_id, + sequence=_chain_input_sequence(chain), + description=description, + residue_ids=getattr(chain, "residue_ids", None), + ptms=chain.ptms, + paired_msa=chain.paired_msa, + unpaired_msa=chain.unpaired_msa, + templates=chain.templates, + ) + if isinstance(chain, folding_input.RnaChain): + return _construct_chain( + folding_input.RnaChain, + id=new_id, + sequence=_chain_input_sequence(chain), + description=description, + residue_ids=getattr(chain, "residue_ids", None), + modifications=chain.modifications, + unpaired_msa=chain.unpaired_msa, + ) + if isinstance(chain, folding_input.DnaChain): + return _construct_chain( + folding_input.DnaChain, + id=new_id, + sequence=_chain_input_sequence(chain), + description=description, + residue_ids=getattr(chain, "residue_ids", None), + modifications=chain.modifications(), + ) + if isinstance(chain, folding_input.Ligand): + return _construct_chain( + folding_input.Ligand, + id=new_id, + ccd_ids=chain.ccd_ids, + smiles=chain.smiles, + description=description, + ) + raise TypeError(f"Unsupported chain type: {type(chain)}") + + def _adjacent_duplicate_keep_indices( + residue_ids: Sequence[int] | None, + ) -> np.ndarray | None: + if residue_ids is None: + return None + residue_ids_array = np.asarray(residue_ids, dtype=np.int32).reshape(-1) + if residue_ids_array.size <= 1: + return None + keep_mask = np.ones(residue_ids_array.shape[0], dtype=bool) + keep_mask[1:] = residue_ids_array[1:] != residue_ids_array[:-1] + if np.all(keep_mask): + return None + return np.flatnonzero(keep_mask) + + def _slice_sequence_like( + value: str | bytes | bytearray, + keep_indices: np.ndarray, + *, + tolerate_shorter_input: bool = False, + context: str = "sequence", + ) -> str | bytes: + as_text = ( + value.decode("utf-8") + if isinstance(value, (bytes, bytearray)) + else str(value) + ) + if keep_indices.size: + max_index = int(np.max(keep_indices)) + if max_index >= len(as_text): + if tolerate_shorter_input: + logging.warning( + "Leaving %s unsliced because length %d is shorter " + "than requested keep index %d", + context, + len(as_text), + max_index, + ) + return value + raise IndexError( + f"{context} length {len(as_text)} is shorter than keep index {max_index}" + ) + sliced = 
"".join(as_text[int(index)] for index in keep_indices.tolist()) + if isinstance(value, (bytes, bytearray)): + return sliced.encode("utf-8") + return sliced + + def _normalise_monomeric_inputs_for_af3( + mono_obj: Union[MonomericObject, ChoppedObject], + ) -> tuple[str, Dict[str, Any], list[int] | None, bool]: + sequence = mono_obj.sequence + feature_dict = mono_obj.feature_dict + residue_index = feature_dict.get("residue_index") + if residue_index is None: + return sequence, feature_dict, None, False + + residue_ids = ( + np.asarray(residue_index, dtype=np.int32).reshape(-1) + 1 + ).astype(int).tolist() + keep_indices = _adjacent_duplicate_keep_indices(residue_ids) + if keep_indices is None: + return sequence, feature_dict, residue_ids, False + + if len(sequence) != len(residue_ids): + logging.warning( + "Skipping adjacent-duplicate AF3 normalization for %s because " + "sequence length %d != residue ID count %d", + mono_obj.description, + len(sequence), + len(residue_ids), + ) + return sequence, feature_dict, residue_ids, False + + normalized_feature_dict = dict(feature_dict) + normalized_sequence = _slice_sequence_like(sequence, keep_indices) + if not isinstance(normalized_sequence, str): + normalized_sequence = normalized_sequence.decode("utf-8") + normalized_residue_ids = [residue_ids[int(index)] for index in keep_indices.tolist()] + + for key in ("aatype", "between_segment_residues", "residue_index"): + if key in normalized_feature_dict: + normalized_feature_dict[key] = np.asarray( + normalized_feature_dict[key] + )[keep_indices] + + for key in ( + "msa", + "deletion_matrix_int", + "deletion_matrix", + "msa_all_seq", + "deletion_matrix_int_all_seq", + "deletion_matrix_all_seq", + ): + if key in normalized_feature_dict: + normalized_feature_dict[key] = np.asarray( + normalized_feature_dict[key] + )[:, keep_indices] + + for key in ( + "template_aatype", + "template_all_atom_masks", + "template_confidence_scores", + ): + if key in normalized_feature_dict: + normalized_feature_dict[key] = np.asarray( + normalized_feature_dict[key] + )[:, keep_indices] + + if "template_all_atom_positions" in normalized_feature_dict: + normalized_feature_dict["template_all_atom_positions"] = np.asarray( + normalized_feature_dict["template_all_atom_positions"] + )[:, keep_indices, :, :] + + if "template_sequence" in normalized_feature_dict: + normalized_feature_dict["template_sequence"] = np.array( + [ + _slice_sequence_like( + template_sequence, + keep_indices, + tolerate_shorter_input=True, + context="template_sequence", + ) + for template_sequence in normalized_feature_dict["template_sequence"] + ], + dtype=object, + ) + + if "sequence" in normalized_feature_dict: + normalized_feature_dict["sequence"] = np.array( + [normalized_sequence.encode("utf-8")] + ) + if "seq_length" in normalized_feature_dict: + normalized_feature_dict["seq_length"] = np.full( + len(normalized_sequence), + len(normalized_sequence), + dtype=np.int32, + ) + if "num_alignments" in normalized_feature_dict: + num_alignments = int( + np.asarray(normalized_feature_dict["num_alignments"]).reshape(-1)[0] + ) + normalized_feature_dict["num_alignments"] = np.full( + len(normalized_sequence), + num_alignments, + dtype=np.int32, + ) + + logging.info( + "Collapsed %d adjacent duplicate residue position(s) for %s to " + "satisfy AF3 input constraints", + len(residue_ids) - len(normalized_residue_ids), + mono_obj.description, + ) + return ( + normalized_sequence, + normalized_feature_dict, + normalized_residue_ids, + True, + ) + + 
@functools.lru_cache(maxsize=None) + def _validate_json_template_mmcif( + mmcif_string: str, + ) -> tuple[bool, str | None]: + try: + af3_mmcif.from_string(mmcif_string) + except Exception as exc: + return False, str(exc) + return True, None + + def _sanitize_json_input_chain_templates( + chain: ( + folding_input.ProteinChain + | folding_input.RnaChain + | folding_input.DnaChain + | folding_input.Ligand + ), + *, + json_path: str, + ): + if not isinstance(chain, folding_input.ProteinChain): + return chain + if chain.templates is None: + return chain + + valid_templates = [] + dropped_template_count = 0 + for template_index, template in enumerate(chain.templates): + is_valid, error_message = _validate_json_template_mmcif( + template.mmcif + ) + if is_valid: + valid_templates.append(template) + continue + + dropped_template_count += 1 + logging.warning( + "Skipping invalid template %d for JSON input %s chain %s: %s", + template_index, + json_path, + chain.id, + error_message, + ) + + if dropped_template_count == 0: + return chain + + logging.info( + "Kept %d/%d template(s) for JSON input %s chain %s after validation", + len(valid_templates), + len(chain.templates), + json_path, + chain.id, + ) + return _construct_chain( + folding_input.ProteinChain, + id=chain.id, + sequence=_chain_input_sequence(chain), + description=getattr(chain, "description", None), + residue_ids=getattr(chain, "residue_ids", None), + ptms=chain.ptms, + paired_msa=chain.paired_msa, + unpaired_msa=chain.unpaired_msa, + templates=valid_templates, + ) + def _slice_a3m_row_to_region(a3m_row: str, start: int, end: int) -> str: query_position = 0 sliced_chars: list[str] = [] @@ -555,6 +1163,33 @@ def _slice_a3m_to_region( + "\n" ) + def _slice_a3m_to_regions( + a3m_text: str | None, + regions: Sequence[tuple[int, int]], + ) -> str | None: + if a3m_text in (None, "") or not regions: + return a3m_text + + sequences, descriptions = af3_parsers.parse_fasta(a3m_text) + sliced_sequences = [ + "".join( + _slice_a3m_row_to_region(sequence, start, end) + for start, end in regions + ) + for sequence in sequences + ] + return ( + "\n".join( + f">{description}\n{sequence}" + for description, sequence in zip( + descriptions, + sliced_sequences, + strict=True, + ) + ) + + "\n" + ) + def _slice_templates_to_region( templates: Sequence[folding_input.Template] | None, start: int, @@ -591,7 +1226,75 @@ def _slice_positioned_modifications_to_region( if start <= modification_position <= end ] - def _slice_af3_chain_to_region( + def _slice_positioned_modifications_to_regions( + modifications: Sequence[tuple[str, int]], + regions: Sequence[tuple[int, int]], + ) -> list[tuple[str, int]]: + sliced_modifications: list[tuple[str, int]] = [] + offset = 0 + for start, end in regions: + sliced_modifications.extend( + ( + modification_name, + offset + modification_position - start + 1, + ) + for modification_name, modification_position in modifications + if start <= modification_position <= end + ) + offset += end - start + 1 + return sliced_modifications + + def _slice_templates_to_regions( + templates: Sequence[folding_input.Template] | None, + regions: Sequence[tuple[int, int]], + ) -> Sequence[folding_input.Template] | None: + if templates is None: + return None + + sliced_templates = [] + for template in templates: + remapped_indices = {} + offset = 0 + for start, end in regions: + start_index = start - 1 + remapped_indices.update({ + offset + (query_index - start_index): template_index + for query_index, template_index in 
template.query_to_template_map.items() + if start_index <= query_index < end + }) + offset += end - start + 1 + if remapped_indices: + sliced_templates.append( + folding_input.Template( + mmcif=template.mmcif, + query_to_template_map=remapped_indices, + ) + ) + return sliced_templates + + def _chain_residue_ids( + chain: ( + folding_input.ProteinChain + | folding_input.RnaChain + | folding_input.DnaChain + | folding_input.Ligand + ), + ) -> list[int] | None: + residue_ids = getattr(chain, "residue_ids", None) + if residue_ids is not None: + return [int(residue_id) for residue_id in residue_ids] + if isinstance( + chain, + ( + folding_input.ProteinChain, + folding_input.RnaChain, + folding_input.DnaChain, + ), + ): + return list(range(1, len(_chain_input_sequence(chain)) + 1)) + return None + + def _slice_af3_chain_to_regions( chain: ( folding_input.ProteinChain | folding_input.RnaChain @@ -599,8 +1302,7 @@ def _slice_af3_chain_to_region( | folding_input.Ligand ), *, - start: int, - end: int, + regions: Sequence[tuple[int, int]], json_path: str, ): if isinstance(chain, folding_input.Ligand): @@ -608,38 +1310,45 @@ def _slice_af3_chain_to_region( f"Region ranges are not supported for ligand AF3 JSON inputs: {json_path}" ) - sequence_length = len(chain.sequence) - if not 1 <= start <= end <= sequence_length: - raise ValueError( - f"Requested region {start}-{end} is outside the sequence " - f"length {sequence_length} for AF3 JSON input {json_path}." - ) + chain_sequence = _chain_input_sequence(chain) + sequence_length = len(chain_sequence) + for start, end in regions: + if not 1 <= start <= end <= sequence_length: + raise ValueError( + f"Requested region {start}-{end} is outside the sequence " + f"length {sequence_length} for AF3 JSON input {json_path}." 
+ ) - sliced_sequence = chain.sequence[start - 1:end] + sliced_sequence = "".join( + chain_sequence[start - 1:end] + for start, end in regions + ) + sliced_residue_ids = [ + residue_id + for start, end in regions + for residue_id in (_chain_residue_ids(chain) or [])[start - 1:end] + ] if isinstance(chain, folding_input.ProteinChain): return _construct_chain( folding_input.ProteinChain, id=chain.id, sequence=sliced_sequence, - ptms=_slice_positioned_modifications_to_region( + ptms=_slice_positioned_modifications_to_regions( chain.ptms, - start, - end, + regions, ), - unpaired_msa=_slice_a3m_to_region( + residue_ids=sliced_residue_ids, + unpaired_msa=_slice_a3m_to_regions( chain.unpaired_msa, - start, - end, + regions, ), - paired_msa=_slice_a3m_to_region( + paired_msa=_slice_a3m_to_regions( chain.paired_msa, - start, - end, + regions, ), - templates=_slice_templates_to_region( + templates=_slice_templates_to_regions( chain.templates, - start, - end, + regions, ), ) @@ -648,15 +1357,14 @@ def _slice_af3_chain_to_region( folding_input.RnaChain, id=chain.id, sequence=sliced_sequence, - modifications=_slice_positioned_modifications_to_region( + modifications=_slice_positioned_modifications_to_regions( chain.modifications, - start, - end, + regions, ), - unpaired_msa=_slice_a3m_to_region( + residue_ids=sliced_residue_ids, + unpaired_msa=_slice_a3m_to_regions( chain.unpaired_msa, - start, - end, + regions, ), ) @@ -665,11 +1373,11 @@ def _slice_af3_chain_to_region( folding_input.DnaChain, id=chain.id, sequence=sliced_sequence, - modifications=_slice_positioned_modifications_to_region( + modifications=_slice_positioned_modifications_to_regions( chain.modifications(), - start, - end, + regions, ), + residue_ids=sliced_residue_ids, ) raise TypeError(f"Unsupported chain type for AF3 JSON slicing: {type(chain)}") @@ -689,33 +1397,31 @@ def _expand_json_input_chains( | folding_input.RnaChain | folding_input.DnaChain | folding_input.Ligand - ]: + ]: + sanitized_chains = [ + _sanitize_json_input_chain_templates(chain, json_path=json_path) + for chain in chains + ] if not regions: - return list(chains) + return list(sanitized_chains) - if len(chains) != 1: + if len(sanitized_chains) != 1: raise ValueError( "Region ranges for AF3 JSON feature inputs require exactly " - f"one chain per file, but {json_path} contains {len(chains)} chains." + f"one chain per file, but {json_path} contains {len(sanitized_chains)} chains." 
) - base_chain = chains[0] - expanded_chains = [ - _slice_af3_chain_to_region( - base_chain, - start=start, - end=end, - json_path=json_path, - ) - for start, end in regions - ] + merged_chain = _slice_af3_chain_to_regions( + sanitized_chains[0], + regions=regions, + json_path=json_path, + ) logging.info( - "Expanded AF3 JSON input %s into %d chain(s) for regions %s", + "Collapsed AF3 JSON input %s into one gapped chain for regions %s", json_path, - len(expanded_chains), regions, ) - return expanded_chains + return [merged_chain] def get_chain_id(index: int) -> str: if index < 26: @@ -885,8 +1591,9 @@ def _monomeric_to_chain( mono_obj: Union[MonomericObject, ChoppedObject], chain_id: str ) -> folding_input.ProteinChain: - sequence = mono_obj.sequence - feature_dict = mono_obj.feature_dict + sequence, feature_dict, residue_ids, _ = _normalise_monomeric_inputs_for_af3( + mono_obj + ) msa_array = feature_dict.get('msa') deletion_matrix = feature_dict.get('deletion_matrix_int') if deletion_matrix is None: @@ -997,6 +1704,7 @@ def _monomeric_to_chain( id=chain_id, sequence=sequence, ptms=[], + residue_ids=residue_ids, description=mono_obj.description, unpaired_msa=unpaired_msa, paired_msa=paired_msa, @@ -1007,12 +1715,10 @@ def _monomeric_to_chain( def _expand_monomeric_object( mono_obj: Union[MonomericObject, ChoppedObject], ) -> list[Union[MonomericObject, ChoppedObject]]: - if isinstance(mono_obj, ChoppedObject): - return mono_obj.split_into_individual_region_objects() return [mono_obj] def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set): - nonlocal all_chains, job_name + nonlocal all_chains if isinstance(obj, dict) and 'json_input' in obj: json_path = obj['json_input'] json_regions = obj.get('regions') @@ -1049,53 +1755,9 @@ def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set): used_chain_ids.add(new_id) logging.info(f"Added chain ID '{new_id}' from JSON file {json_path}") - # Create a new chain with the modified ID - if isinstance(chain, folding_input.ProteinChain): - modified_chain = _construct_chain( - folding_input.ProteinChain, - id=new_id, - sequence=chain.sequence, - ptms=chain.ptms, - description=getattr(chain, "description", None), - paired_msa=chain.paired_msa, - unpaired_msa=chain.unpaired_msa, - templates=chain.templates, - ) - elif isinstance(chain, folding_input.RnaChain): - modified_chain = _construct_chain( - folding_input.RnaChain, - id=new_id, - sequence=chain.sequence, - modifications=chain.modifications, - description=getattr(chain, "description", None), - unpaired_msa=chain.unpaired_msa, - ) - elif isinstance(chain, folding_input.DnaChain): - modified_chain = _construct_chain( - folding_input.DnaChain, - id=new_id, - sequence=chain.sequence, - description=getattr(chain, "description", None), - modifications=chain.modifications(), - ) - elif isinstance(chain, folding_input.Ligand): - modified_chain = _construct_chain( - folding_input.Ligand, - id=new_id, - ccd_ids=chain.ccd_ids, - smiles=chain.smiles, - description=getattr(chain, "description", None), - ) - else: - raise TypeError(f"Unsupported chain type: {type(chain)}") - - modified_chains.append(modified_chain) + modified_chains.append(_clone_chain_with_id(chain, new_id)) all_chains.extend(modified_chains) - if len(all_chains) == len(modified_chains): - job_name = input_obj.name - else: - job_name = f"{job_name}_and_{input_obj.name}" except Exception as e: logging.error(f"Failed to parse JSON file {json_path}: {e}") raise @@ -1112,55 +1774,60 @@ def 
_process_single_object(obj, chain_id_counter_ref, used_chain_ids: set): chains = [] translated_result = None combined_msa = obj.feature_dict.get('msa') - expanded_interactors = [ - expanded_interactor - for interactor in obj.interactors - for expanded_interactor in _expand_monomeric_object(interactor) + expanded_interactors = list(obj.interactors) + normalized_interactor_inputs = [ + _normalise_monomeric_inputs_for_af3(interactor) + for interactor in expanded_interactors ] - expanded_discontinuous_regions = len(expanded_interactors) != len(obj.interactors) + has_adjacent_duplicate_residue_ids = any( + normalized_input[3] for normalized_input in normalized_interactor_inputs + ) translated_result = ( translate_af2_individual_chain_features_to_af3_msas_with_stats( chain_feature_dicts=[ - interactor.feature_dict for interactor in expanded_interactors + normalized_input[1] + for normalized_input in normalized_interactor_inputs ], chain_sequences=[ - interactor.sequence for interactor in expanded_interactors + normalized_input[0] + for normalized_input in normalized_interactor_inputs ], ) ) - if not expanded_discontinuous_regions: - num_pairable_chains = sum( - chain_stats.paired_species_identifier_count > 0 - for chain_stats in translated_result.chain_stats - ) - if num_pairable_chains < 2 and combined_msa is not None: - # Fall back to the merged AF2 multimer MSA transport path when the - # individual `_all_seq` features do not carry usable species IDs. - translated_result = ( - translate_af2_complex_msa_to_af3_unpaired_chain_msas_with_stats( - merged_msa=np.asarray(combined_msa), - chain_sequences=[ - interactor.sequence for interactor in obj.interactors - ], - num_alignments=obj.feature_dict.get('num_alignments'), - deletion_matrix=( - obj.feature_dict.get('deletion_matrix_int') - if obj.feature_dict.get('deletion_matrix_int') is not None - else obj.feature_dict.get('deletion_matrix') - ), - asym_id=obj.feature_dict.get('asym_id'), - ) - ) - expanded_interactors = list(obj.interactors) - else: + num_pairable_chains = sum( + chain_stats.paired_species_identifier_count > 0 + for chain_stats in translated_result.chain_stats + ) + if ( + has_adjacent_duplicate_residue_ids + and num_pairable_chains < 2 + and combined_msa is not None + ): logging.info( - "Expanded discontinuous chopped interactors into %d AF3 chains " - "for job %s; keeping per-chain AF2-derived MSAs because AF3 " - "cannot encode polymer chain breaks within one protein chain.", - len(expanded_interactors), + "Skipping AF2 merged-MSA fallback for multimeric object %s " + "because adjacent duplicate residue IDs were collapsed for AF3 input compatibility.", obj.description, ) + elif num_pairable_chains < 2 and combined_msa is not None: + # Fall back to the merged AF2 multimer MSA transport path when the + # individual `_all_seq` features do not carry usable species IDs. 
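+            # The fallback below rebuilds per-chain unpaired MSAs from the
+            # merged AF2 multimer MSA, using the chain sequences and the
+            # asym_id mapping carried in the combined feature dict.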
+ translated_result = ( + translate_af2_complex_msa_to_af3_unpaired_chain_msas_with_stats( + merged_msa=np.asarray(combined_msa), + chain_sequences=[ + interactor.sequence for interactor in obj.interactors + ], + num_alignments=obj.feature_dict.get('num_alignments'), + deletion_matrix=( + obj.feature_dict.get('deletion_matrix_int') + if obj.feature_dict.get('deletion_matrix_int') is not None + else obj.feature_dict.get('deletion_matrix') + ), + asym_id=obj.feature_dict.get('asym_id'), + ) + ) + expanded_interactors = list(obj.interactors) for chain_index, interactor in enumerate(expanded_interactors): chain_id = get_next_available_chain_id(used_chain_ids, chain_id_counter_ref) @@ -1173,8 +1840,9 @@ def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set): base_chain = _construct_chain( folding_input.ProteinChain, id=base_chain.id, - sequence=base_chain.sequence, + sequence=_chain_input_sequence(base_chain), ptms=base_chain.ptms, + residue_ids=getattr(base_chain, "residue_ids", None), description=interactor.description, unpaired_msa=chain_msas.unpaired_msa, paired_msa=chain_msas.paired_msa, @@ -1213,7 +1881,6 @@ def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set): chain_id_counter = [0] # Use a list to allow pass-by-reference used_chain_ids = set() # Track used chain IDs all_chains = [] - job_name = "ranked_0" # Track chains whose MSAs were translated from AF2 features; they must # not be rewritten by the promotion heuristic below. af2_translated_msa_chain_ids: set[str] = set() @@ -1253,8 +1920,9 @@ def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set): new_chain = _construct_chain( folding_input.ProteinChain, id=ch.id, - sequence=ch.sequence, + sequence=_chain_input_sequence(ch), ptms=ch.ptms, + residue_ids=getattr(ch, "residue_ids", None), description=getattr(ch, "description", None), paired_msa=ch.unpaired_msa, unpaired_msa='', @@ -1267,6 +1935,7 @@ def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set): promoted_chains.append(ch) all_chains = promoted_chains + job_name = _build_output_job_name(objects_to_model) combined_input = folding_input.Input( name=job_name, rng_seeds=[random_seed], diff --git a/alphapulldown/objects.py b/alphapulldown/objects.py index 76d6bf34..9da07e86 100644 --- a/alphapulldown/objects.py +++ b/alphapulldown/objects.py @@ -477,9 +477,9 @@ def prepare_final_sliced_feature_dict(self) -> None: def split_into_individual_region_objects(self) -> List["ChoppedObject"]: """Return one chopped object per requested region. - This is used by backends such as AlphaFold 3 that cannot encode - discontinuous regions as a single polymer chain without introducing a - peptide bond between adjacent modeled residues. + This helper is retained for backends or experiments that need each + requested region as an independent object instead of one discontinuous + polymer chain. 
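+        The AlphaFold 3 backend now keeps multi-region objects as one gapped
+        chain by default and no longer calls this method.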
""" if len(self.regions) <= 1: return [self] diff --git a/conftest.py b/conftest.py index eec62fb4..01c1e302 100755 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,15 @@ import pytest + +def pytest_addoption(parser): + parser.addoption( + "--use-temp-dir", + action="store_true", + default=False, + help="Run functional test suites with isolated temporary output directories.", + ) + + @pytest.hookimpl(tryfirst=True) def pytest_itemcollected(item): try: diff --git a/test/check_alphafold2_predictions.py b/test/check_alphafold2_predictions.py index 171ac52b..90c0b696 100755 --- a/test/check_alphafold2_predictions.py +++ b/test/check_alphafold2_predictions.py @@ -13,6 +13,7 @@ import sys import tempfile import logging +import unittest from pathlib import Path from absl.testing import absltest, parameterized @@ -38,6 +39,37 @@ # from alphafold.model import config # config.CONFIG_MULTIMER.model.embeddings_and_evoformer.evoformer_num_block = 1 + +def _has_nvidia_gpu() -> bool: + nvidia_smi = shutil.which("nvidia-smi") + if not nvidia_smi: + return False + try: + result = subprocess.run( + [nvidia_smi, "-L"], + capture_output=True, + text=True, + check=False, + ) + except OSError: + return False + return result.returncode == 0 and bool(result.stdout.strip()) + + +def _gpu_functional_test_skip_reason() -> str | None: + if os.getenv("RUN_GPU_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"): + return None + if os.getenv("CI", "").lower() in ("1", "true", "yes") or os.getenv( + "GITHUB_ACTIONS", "" + ).lower() == "true": + return ( + "GPU functional tests are disabled on CI/CD. " + "Set RUN_GPU_FUNCTIONAL_TESTS=1 to override." + ) + if not _has_nvidia_gpu(): + return "GPU functional tests require an NVIDIA GPU and nvidia-smi." + return None + # --------------------------------------------------------------------------- # # common helper mix-in / assertions # # --------------------------------------------------------------------------- # @@ -47,6 +79,9 @@ class _TestBase(parameterized.TestCase): @classmethod def setUpClass(cls): super().setUpClass() + skip_reason = _gpu_functional_test_skip_reason() + if skip_reason: + raise unittest.SkipTest(skip_reason) # do the skip here so import-time doesn't abort discovery #if not DATA_DIR.is_dir(): # cls.skipTest(f"set $ALPHAFOLD_DATA_DIR to run Alphafold functional tests (tried {DATA_DIR!r})") @@ -80,7 +115,8 @@ def setUp(self): self.test_features_dir = self.test_data_dir / "features" self.test_protein_lists_dir = self.test_data_dir / "protein_lists" self.test_modelling_dir = self.test_data_dir / "predictions" - self.af2_backend_dir = self.test_modelling_dir / "af2_backend" + # setUpClass already resolved this to either a temp root or the legacy shared root + self.af2_backend_dir = self.base_output_dir test_name = self._testMethodName self.output_dir = self.af2_backend_dir / test_name @@ -99,12 +135,12 @@ def _runCommonTests(self, res: subprocess.CompletedProcess, multimer: bool, dirn f"STDERR:\n{res.stderr}" ) - dirs = [dirname] if dirname else [ - d for d in self.output_dir.iterdir() if d.is_dir() - ] + if dirname is not None: + folders = [self.output_dir / dirname] + else: + folders = [d for d in self.output_dir.iterdir() if d.is_dir()] - for d in dirs: - folder = self.output_dir / d + for folder in folders: files = list(folder.iterdir()) self.assertEqual( @@ -210,10 +246,9 @@ class TestResume(_TestBase): def setUp(self): super().setUp() self.protein_lists = self.test_protein_lists_dir / "test_dimer.txt" - self.af2_backend_dir.mkdir(parents=True, 
exist_ok=True) - - source = self.test_modelling_dir / "TEST_and_TEST" - target = self.af2_backend_dir / "TEST_and_TEST" + # Resume tests need a pre-populated per-test output tree to continue from. + source = self.test_modelling_dir / "TEST_homo_2er" + target = self.output_dir / "TEST_homo_2er" shutil.copytree(source, target, dirs_exist_ok=True) self.base_args = [ @@ -405,4 +440,4 @@ def test_dropout_increases_diversity(self): # The test passes if calculations succeed - the diversity check is informational if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/test/check_alphafold3_predictions.py b/test/check_alphafold3_predictions.py index 9eaf8120..b38daeea 100755 --- a/test/check_alphafold3_predictions.py +++ b/test/check_alphafold3_predictions.py @@ -11,17 +11,20 @@ import time import sys import tempfile +import hashlib from pathlib import Path import shutil import pickle import json import numpy as np import re +import unittest from typing import Dict, List, Tuple, Any from absl.testing import absltest, parameterized import alphapulldown +from alphafold3.constants import residue_names as af3_residue_names from alphapulldown.objects import MultimericObject from alphapulldown.utils.modelling_setup import ( create_custom_info, @@ -43,6 +46,37 @@ absltest.skip("set $ALPHAFOLD_DATA_DIR to run Alphafold functional tests") +def _has_nvidia_gpu() -> bool: + nvidia_smi = shutil.which("nvidia-smi") + if not nvidia_smi: + return False + try: + result = subprocess.run( + [nvidia_smi, "-L"], + capture_output=True, + text=True, + check=False, + ) + except OSError: + return False + return result.returncode == 0 and bool(result.stdout.strip()) + + +def _gpu_functional_test_skip_reason() -> str | None: + if os.getenv("RUN_GPU_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"): + return None + if os.getenv("CI", "").lower() in ("1", "true", "yes") or os.getenv( + "GITHUB_ACTIONS", "" + ).lower() == "true": + return ( + "GPU functional tests are disabled on CI/CD. " + "Set RUN_GPU_FUNCTIONAL_TESTS=1 to override." + ) + if not _has_nvidia_gpu(): + return "GPU functional tests require an NVIDIA GPU and nvidia-smi." + return None + + # --------------------------------------------------------------------------- # # common helper mix-in / assertions # # --------------------------------------------------------------------------- # @@ -52,6 +86,9 @@ class _TestBase(parameterized.TestCase): @classmethod def setUpClass(cls): super().setUpClass() + skip_reason = _gpu_functional_test_skip_reason() + if skip_reason: + raise unittest.SkipTest(skip_reason) # Create a base directory for all test outputs if cls.use_temp_dir: cls.base_output_dir = Path(tempfile.mkdtemp(prefix="af3_test_")) @@ -186,7 +223,7 @@ def _get_sequence_from_json(self, json_file: str) -> List[Tuple[str, str]]: def _apply_ptms_to_sequence(self, sequence: str, modifications: List[Dict]) -> str: """ - Apply post-translational modifications to a protein sequence. + Apply PTMs to the expected structure-side sequence representation. 
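+        For example, HYS is rendered as H and 2MG as G; any other CCD code is
+        mapped through af3_residue_names.letters_three_to_one with "X" as the
+        default.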
Args: sequence: Original protein sequence @@ -204,17 +241,14 @@ def _apply_ptms_to_sequence(self, sequence: str, modifications: List[Dict]) -> s if ptm_position < len(seq_list): if ptm_type == "HYS": - # N-terminal histidine modification - replace N-terminal methionine with HYS - if ptm_position == 0 and seq_list[0] == 'M': - # Replace M with H (histidine) - HYS is the CCD code, but we use H for sequence - seq_list[0] = 'H' + seq_list[ptm_position] = "H" elif ptm_type == "2MG": - # 2-methylguanosine modification - replace G with modified G - # For simplicity, we'll keep it as G since the exact representation may vary - pass - # Add more PTM types as needed + seq_list[ptm_position] = "G" else: - print(f"Warning: Unknown PTM type '{ptm_type}' at position {ptm_position + 1}") + seq_list[ptm_position] = af3_residue_names.letters_three_to_one( + ptm_type, + default='X', + ) return ''.join(seq_list) @@ -345,18 +379,26 @@ def _process_homo_oligomer_chopped_line(self, line: str) -> List[Tuple[str, str] if "-" in region_str: s, e = region_str.split("-") regions.append((int(s), int(e))) - - region_sequences = self._get_region_sequences(protein_name, regions) + + # AF3 cannot represent immediately repeated author residue IDs at a + # region boundary (e.g. 6-7 followed by 7-8). Collapse only that shared + # boundary residue while keeping the explicit region naming unchanged. + normalized_regions = [] + for start, end in regions: + if normalized_regions and start == normalized_regions[-1][1]: + start += 1 + if start <= end: + normalized_regions.append((start, end)) + + region_sequences = self._get_region_sequences(protein_name, normalized_regions) if not region_sequences: return [] - # AF3 cannot encode polymer chain breaks inside one protein chain, so - # discontinuous regions are modeled as separate protein chains. 
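+        # Worked example: regions [(6, 7), (7, 8)] normalize to
+        # [(6, 7), (8, 8)], so the shared boundary residue 7 contributes a
+        # single position to the expected concatenated sequence.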
+ concatenated_sequence = "".join(region_sequences) sequences = [] - for _ in range(num_copies): - for region_sequence in region_sequences: - chain_id = self._chain_id_from_index(len(sequences)) - sequences.append((chain_id, region_sequence)) + for copy_index in range(num_copies): + chain_id = self._chain_id_from_index(copy_index) + sequences.append((chain_id, concatenated_sequence)) return sequences @@ -445,9 +487,10 @@ def parse_protein_and_regions(part: str): if "," in part: protein_name, regions = parse_protein_and_regions(part) region_sequences = self._get_region_sequences(protein_name, regions) - for region_sequence in region_sequences: - chain_id = self._chain_id_from_index(len(sequences)) - sequences.append((chain_id, region_sequence)) + if not region_sequences: + continue + chain_id = self._chain_id_from_index(len(sequences)) + sequences.append((chain_id, "".join(region_sequences))) else: protein_name = part sequence = self._get_sequence_for_protein(protein_name) @@ -461,12 +504,9 @@ def parse_protein_and_regions(part: str): part = line.strip() if "," in part: protein_name, regions = parse_protein_and_regions(part) - return [ - (self._chain_id_from_index(i), region_sequence) - for i, region_sequence in enumerate( - self._get_region_sequences(protein_name, regions) - ) - ] + region_sequences = self._get_region_sequences(protein_name, regions) + if region_sequences: + return [('A', "".join(region_sequences))] else: protein_name = part sequence = self._get_sequence_for_protein(protein_name) @@ -476,7 +516,7 @@ def parse_protein_and_regions(part: str): def _extract_cif_chains_and_sequences(self, cif_path: Path) -> List[Tuple[str, str]]: """ - Extract chain IDs and sequences from a CIF file using Biopython. + Extract chain IDs and sequences from a CIF file. 
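+        Tries AF3's cif_dict parser first and falls back to Biopython's
+        MMCIFParser when the AF3 extension is unavailable.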
Args: cif_path: Path to the CIF file @@ -486,6 +526,69 @@ def _extract_cif_chains_and_sequences(self, cif_path: Path) -> List[Tuple[str, s """ chains_and_sequences = [] + try: + from alphafold3.cpp import cif_dict + + with open(cif_path, "rt") as handle: + cif = cif_dict.from_string(handle.read()) + + sequences_by_chain = {} + + if "_pdbx_poly_seq_scheme.asym_id" in cif: + asym_ids = cif.get_array("_pdbx_poly_seq_scheme.asym_id", dtype=object) + mon_ids = cif.get_array("_pdbx_poly_seq_scheme.mon_id", dtype=object) + + for chain_id, mon_id in zip(asym_ids, mon_ids, strict=True): + sequence = sequences_by_chain.setdefault(chain_id, "") + if mon_id in self._protein_letters_3to1: + sequence += self._protein_letters_3to1[mon_id] + elif mon_id in self._dna_letters_3to1: + sequence += self._dna_letters_3to1[mon_id] + elif mon_id in self._rna_letters_3to1: + sequence += self._rna_letters_3to1[mon_id] + elif mon_id + " " in self._rna_letters_3to1: + sequence += self._rna_letters_3to1[mon_id + " "] + elif mon_id + " " in self._dna_letters_3to1: + sequence += self._dna_letters_3to1[mon_id + " "] + elif mon_id == "HYS": + sequence += "H" + elif mon_id == "2MG": + sequence += "G" + else: + sequence += "X" + sequences_by_chain[chain_id] = sequence + + for scheme_prefix in ("_pdbx_nonpoly_scheme", "_pdbx_branch_scheme"): + asym_key = f"{scheme_prefix}.asym_id" + mon_key = f"{scheme_prefix}.mon_id" + if asym_key not in cif or mon_key not in cif: + continue + asym_ids = cif.get_array(asym_key, dtype=object) + mon_ids = cif.get_array(mon_key, dtype=object) + for chain_id, mon_id in zip(asym_ids, mon_ids, strict=True): + if mon_id in {"HOH", "DOD"}: + continue + sequence = sequences_by_chain.setdefault(chain_id, "") + ligand_codes = [] if not sequence else sequence.split("+") + ligand_codes.append(mon_id if mon_id in self._ligand_ccd_codes else "UNKNOWN") + sequences_by_chain[chain_id] = "+".join(ligand_codes) + + chain_order = ( + list(cif.get_array("_struct_asym.id", dtype=object)) + if "_struct_asym.id" in cif + else list(sequences_by_chain.keys()) + ) + for chain_id in chain_order: + sequence = sequences_by_chain.get(chain_id) + if sequence: + chains_and_sequences.append((chain_id, sequence)) + if chains_and_sequences: + return chains_and_sequences + except ImportError: + pass + except Exception as e: + print(f"Error parsing CIF with AF3 cif_dict: {e}") + try: from Bio.PDB import MMCIFParser @@ -500,9 +603,9 @@ def _extract_cif_chains_and_sequences(self, cif_path: Path) -> List[Tuple[str, s for chain in model: chain_id = chain.id - # Get residues in order + # Keep the residue order from the file instead of sorting by + # residue number so discontinuous numbering remains testable. 
residues = list(chain.get_residues()) - residues.sort(key=lambda r: r.id[1]) # Sort by residue number # Separate standard residues from HETATM records standard_residues = [] @@ -583,6 +686,45 @@ def _extract_cif_chains_and_sequences(self, cif_path: Path) -> List[Tuple[str, s return chains_and_sequences + def _extract_cif_chain_residue_numbers(self, cif_path: Path) -> List[Tuple[str, List[Union[int, str]]]]: + """Extract author-facing residue numbers for each polymer chain from a CIF file.""" + try: + from alphafold3.cpp import cif_dict + + with open(cif_path, "rt") as handle: + cif = cif_dict.from_string(handle.read()) + + asym_ids = cif.get_array("_pdbx_poly_seq_scheme.asym_id", dtype=object) + auth_seq_nums = cif.get_array( + "_pdbx_poly_seq_scheme.auth_seq_num", dtype=object + ) + ins_codes = cif.get_array( + "_pdbx_poly_seq_scheme.pdb_ins_code", dtype=object + ) + + chain_residue_numbers = [] + chain_to_numbers = {} + for chain_id, auth_seq_num, ins_code in zip( + asym_ids, + auth_seq_nums, + ins_codes, + strict=True, + ): + residue_numbers = chain_to_numbers.setdefault(chain_id, []) + ins_code = str(ins_code) + auth_seq_num = int(auth_seq_num) + if ins_code in {".", "?"}: + residue_numbers.append(auth_seq_num) + else: + residue_numbers.append(f"{auth_seq_num}{ins_code}") + + for chain_id, residue_numbers in chain_to_numbers.items(): + if residue_numbers: + chain_residue_numbers.append((chain_id, residue_numbers)) + return chain_residue_numbers + except Exception as exc: + self.fail(f"Failed to extract CIF residue numbers from {cif_path}: {exc}") + def _apply_ptms_from_hetatm(self, sequence: str, hetatm_residues: List[Tuple[int, str]]) -> str: """ Apply PTMs from HETATM records to the protein sequence. @@ -753,41 +895,45 @@ def _extract_cif_chains_and_sequences_regex(self, cif_path: Path) -> List[Tuple[ return chains_and_sequences - def _get_ptm_positions(self, protein_list: str) -> List[int]: - """ - Extract PTM positions from JSON files for a given protein list. 
- - Args: - protein_list: Name of the protein list file - - Returns: - List of PTM positions (1-based) - """ - ptm_positions = [] - - # Read the protein list file - protein_list_path = self.test_protein_lists_dir / protein_list - with open(protein_list_path, 'r') as f: - lines = [line.strip() for line in f.readlines() if line.strip()] - - for line in lines: - if line.endswith('.json'): - json_path = self.test_features_dir / line - if json_path.exists(): - with open(json_path, 'r') as f: - json_data = json.load(f) - - json_sequences = json_data.get('sequences', []) - for seq_data in json_sequences: - if 'protein' in seq_data: - protein_seq = seq_data['protein'] - modifications = protein_seq.get('modifications', []) - for ptm in modifications: - ptm_position = ptm.get('ptmPosition') - if ptm_position: - ptm_positions.append(ptm_position) - - return ptm_positions + def _assert_exact_chain_mapping( + self, + expected_sequences: List[Tuple[str, str]], + actual_chains_and_sequences: List[Tuple[str, str]], + *, + context: str, + ) -> None: + """Assert an exact chain-id to sequence mapping, independent of file order.""" + expected_dict = dict(expected_sequences) + actual_dict = dict(actual_chains_and_sequences) + + self.assertLen( + expected_dict, + len(expected_sequences), + f"{context}: expected chain IDs must be unique", + ) + self.assertLen( + actual_dict, + len(actual_chains_and_sequences), + f"{context}: actual chain IDs must be unique", + ) + + print(f"Expected exact chain mapping for {context}: {expected_dict}") + print(f"Actual exact chain mapping for {context}: {actual_dict}") + + self.assertEqual( + actual_dict, + expected_dict, + f"{context}: exact chain mapping mismatch", + ) + + def _requires_exact_chain_mapping(self, protein_list: str) -> bool: + """Cases where inference must preserve the explicit input chain IDs.""" + return protein_list in { + "test_monomer_with_rna.txt", + "test_monomer_with_dna.txt", + "test_monomer_with_ligand.txt", + "test_protein_with_ptms.txt", + } def _check_chain_counts_and_sequences(self, protein_list: str): """ @@ -823,83 +969,27 @@ def _check_chain_counts_and_sequences(self, protein_list: str): len(expected_sequences), f"Expected {len(expected_sequences)} chains, but found {len(actual_chains_and_sequences)}" ) - - # Check if this is a PTM case - ptm_positions = self._get_ptm_positions(protein_list) - is_ptm_case = len(ptm_positions) > 0 - - if is_ptm_case: - # For PTM cases, check that sequences are reasonable for PTM cases - print(f"PTM case detected. PTM positions: {ptm_positions}") - self._check_sequences_with_ptms(expected_sequences, actual_chains_and_sequences, ptm_positions) - else: - # For non-PTM cases, check exact sequence matches - actual_sequences = [seq for _, seq in actual_chains_and_sequences] - expected_sequences_only = [seq for _, seq in expected_sequences] - - # Sort sequences for comparison (since chain order might vary) - actual_sequences.sort() - expected_sequences_only.sort() - - self.assertEqual( - actual_sequences, - expected_sequences_only, - f"Sequences don't match. 
Expected: {expected_sequences_only}, Actual: {actual_sequences}" + + if self._requires_exact_chain_mapping(protein_list): + self._assert_exact_chain_mapping( + expected_sequences, + actual_chains_and_sequences, + context=protein_list, ) + return + + actual_sequences = [seq for _, seq in actual_chains_and_sequences] + expected_sequences_only = [seq for _, seq in expected_sequences] + + # Sort sequences for comparison (since chain order might vary) + actual_sequences.sort() + expected_sequences_only.sort() - def _check_sequences_with_ptms(self, expected_sequences: List[Tuple[str, str]], - actual_chains_and_sequences: List[Tuple[str, str]], - ptm_positions: List[int]): - """ - Check that sequences are reasonable for PTM cases. - - Args: - expected_sequences: List of (chain_id, sequence) tuples for expected chains - actual_chains_and_sequences: List of (chain_id, sequence) tuples for actual chains - ptm_positions: List of PTM positions (1-based) - """ - # Create dictionaries for easier lookup - expected_dict = dict(expected_sequences) - actual_dict = dict(actual_chains_and_sequences) - - # Check that all chain IDs match self.assertEqual( - set(expected_dict.keys()), - set(actual_dict.keys()), - f"Chain IDs don't match. Expected: {set(expected_dict.keys())}, Actual: {set(actual_dict.keys())}" + actual_sequences, + expected_sequences_only, + f"Sequences don't match. Expected: {expected_sequences_only}, Actual: {actual_sequences}" ) - - # Check each chain - for chain_id in expected_dict.keys(): - expected_seq = expected_dict[chain_id] - actual_seq = actual_dict[chain_id] - - print(f"Chain {chain_id}:") - print(f" Expected: {expected_seq}") - print(f" Actual: {actual_seq}") - print(f" PTM positions: {ptm_positions}") - - # For PTM cases, we'll be very lenient and just check that: - # 1. Chain IDs match (already checked above) - # 2. Sequences are not empty - # 3. 
Sequences contain only valid amino acid characters - - self.assertGreater( - len(actual_seq), - 0, - f"Sequence for chain {chain_id} is empty" - ) - - # Check that sequence contains only valid amino acid characters - valid_aa = set('ACDEFGHIKLMNPQRSTVWY') - invalid_chars = set(actual_seq) - valid_aa - self.assertEqual( - len(invalid_chars), - 0, - f"Sequence for chain {chain_id} contains invalid amino acid characters: {invalid_chars}" - ) - - print(f" ✓ Chain {chain_id}: Valid sequence with correct chain ID") def _make_af3_test_env(self) -> Dict[str, str]: flash_impl = self._af3_flash_attention_impl() @@ -1149,15 +1239,220 @@ def _paired_empty(d): print("✓ Combined AF3 input JSON created; per-chain MSAs present for backend pairing") - def test_af3_splits_discontinuous_chopped_regions_into_separate_chains(self): - """AF3 must encode multi-region chopped inputs as separate protein chains.""" + def test_af3_custom_residue_ids_round_trip_through_json_and_structure(self): + """Custom AF3 residue IDs must survive JSON and structure conversion.""" + from alphafold3.common import folding_input + from alphafold3.constants import chemical_components + + expected_residue_ids = [2, 3, 4, 5, 8, 9, 10] + chain = folding_input.ProteinChain( + id="A", + sequence="SSHEKKK", + ptms=[], + residue_ids=expected_residue_ids, + unpaired_msa="", + paired_msa="", + templates=[], + ) + fold_input = folding_input.Input( + name="gap_test", + chains=[chain], + rng_seeds=[1], + ) + + round_tripped = folding_input.Input.from_json(fold_input.to_json()) + self.assertEqual( + list(round_tripped.protein_chains[0].residue_ids), + expected_residue_ids, + ) + + struc = round_tripped.to_structure(ccd=chemical_components.Ccd()) + self.assertEqual(struc.present_residues.id.tolist(), expected_residue_ids) + + def test_af3_custom_residue_ids_propagate_to_token_features(self): + """AF3 token features must retain custom gapped residue numbering.""" + from alphafold3.common import folding_input + from alphafold3.constants import chemical_components + from alphafold3.model import features as af3_features + from alphafold3.model.atom_layout import atom_layout + from alphafold3.model.network import featurization as af3_featurization + + expected_residue_ids = [1, 2, 3, 4, 8, 9, 10] + chain = folding_input.ProteinChain( + id="A", + sequence="ACDEFGH", + ptms=[], + residue_ids=expected_residue_ids, + unpaired_msa="", + paired_msa="", + templates=[], + ) + fold_input = folding_input.Input( + name="gap_token_test", + chains=[chain], + rng_seeds=[1], + ) + ccd = chemical_components.Ccd() + struc = fold_input.to_structure(ccd=ccd) + flat_layout = atom_layout.atom_layout_from_structure(struc) + all_tokens, _, _ = af3_features.tokenizer( + flat_layout, + ccd=ccd, + max_atoms_per_token=24, + flatten_non_standard_residues=False, + logging_name="gap_token_test", + ) + padding_shapes = af3_features.PaddingShapes( + num_tokens=len(all_tokens.atom_name), + msa_size=1, + num_chains=1, + num_templates=0, + num_atoms=24 * len(all_tokens.atom_name), + ) + token_features = af3_features.TokenFeatures.compute_features( + all_tokens=all_tokens, + padding_shapes=padding_shapes, + ) + + self.assertEqual( + token_features.residue_index[:len(expected_residue_ids)].tolist(), + expected_residue_ids, + ) + self.assertEqual( + sorted(set(token_features.asym_id[:len(expected_residue_ids)].tolist())), + [1], + ) + + relative_encoding = np.asarray( + af3_featurization.create_relative_encoding( + token_features, + max_relative_idx=4, + max_relative_chain=2, + ) + ) + 
inter_chain_bin = 2 * 4 + 1 + self.assertEqual(relative_encoding[3, 4, inter_chain_bin], 0) + self.assertEqual(np.argmax(relative_encoding[3, 4, : 2 * 4 + 2]), 0) + + def test_af3_duplicate_residue_ids_survive_empty_structure_round_trip(self): + """AF3 must preserve duplicate residue IDs when rebuilding empty structures.""" + from alphafold3.common import folding_input + from alphafold3.constants import chemical_components + from alphafold3.model.atom_layout import atom_layout + + expected_residue_ids = list(range(1, 11)) + list(range(2, 6)) + list(range(12, 16)) + chain = folding_input.ProteinChain( + id="A", + sequence="ACDEFGHIKLCDEFMNPQ", + ptms=[], + residue_ids=expected_residue_ids, + unpaired_msa="", + paired_msa="", + templates=[], + ) + fold_input = folding_input.Input( + name="duplicate_residue_ids_test", + chains=[chain], + rng_seeds=[1], + ) + ccd = chemical_components.Ccd() + struc = fold_input.to_structure(ccd=ccd) + flat_layout = atom_layout.atom_layout_from_structure(struc) + all_physical_residues = atom_layout.residues_from_structure(struc) + rebuilt = atom_layout.make_structure( + flat_layout, + atom_coords=np.zeros((flat_layout.atom_name.shape[0], 3), dtype=np.float32), + name="duplicate_residue_ids_test", + all_physical_residues=all_physical_residues, + ) + + self.assertEqual(rebuilt.present_residues.id.tolist(), expected_residue_ids) + + def test_af3_output_job_name_compacts_long_homomer_names(self): + """AF3 job names should stay readable and below common filename limits.""" + from alphapulldown.folding_backend.alphafold3_backend import AlphaFold3Backend + + parsed = parse_fold( + ["A0A075B6L2:10:1-3:4-5:6-7:7-8"], + [str(self.test_features_dir)], + "+", + ) + data = create_custom_info(parsed) + all_interactors = create_interactors(data, [str(self.test_features_dir)]) + self.assertLen(all_interactors, 1) + self.assertLen(all_interactors[0], 10) + + object_to_model = MultimericObject(interactors=all_interactors[0], pair_msa=True) + mappings = AlphaFold3Backend.prepare_input( + objects_to_model=[{"object": object_to_model, "output_dir": str(self.output_dir)}], + random_seed=42, + ) + self.assertLen(mappings, 1) + fold_input_obj, _ = next(iter(mappings[0].items())) + + self.assertEqual( + fold_input_obj.sanitised_name(), + "A0A075B6L2_1-3_4-5_6-7_7-8__x10", + ) + self.assertLessEqual(len(fold_input_obj.sanitised_name()), 200) + expected_sequence = "".join( + self._get_region_sequences( + "A0A075B6L2", + [(1, 3), (4, 5), (6, 7), (8, 8)], + ) + ) + self.assertTrue( + all(chain.sequence == expected_sequence for chain in fold_input_obj.chains) + ) + self.assertTrue( + all(list(chain.residue_ids) == [1, 2, 3, 4, 5, 6, 7, 8] for chain in fold_input_obj.chains) + ) + + def test_af3_output_job_name_hashes_overlong_unique_compound_names(self): + """AF3 job names should fall back to a deterministic hash suffix when needed.""" + from alphapulldown.folding_backend.alphafold3_backend import ( + _build_output_job_name, + ) + + fragments = [ + f"protein_{index:02d}_{'verylongsegment' * 4}" + for index in range(12) + ] + objects_to_model = [ + { + "object": { + "json_input": str( + Path("/tmp") / f"{fragment}_af3_input.json" + ) + }, + "output_dir": str(self.output_dir), + } + for fragment in fragments + ] + + readable_name = "_and_".join(fragments) + self.assertGreater(len(readable_name), 200) + + job_name = _build_output_job_name(objects_to_model) + expected_digest = hashlib.sha1( + readable_name.encode("utf-8") + ).hexdigest()[:12] + + self.assertLessEqual(len(job_name), 200) + 
self.assertTrue(job_name.endswith(f"__{expected_digest}")) + self.assertRegex(job_name, r"__[0-9a-f]{12}$") + self.assertEqual(job_name, _build_output_job_name(objects_to_model)) + + def test_af3_prepare_input_accepts_monomer_plus_ligand_json(self): + """AF3 mixed protein+ligand JSON inputs must survive prepare_input cloning.""" + from alphafold3.common import folding_input from alphapulldown.folding_backend.alphafold3_backend import ( AlphaFold3Backend, process_fold_input, ) parsed = parse_fold( - ["TEST+A0A075B6L2:1-10:2-5:12-15"], + ["A0A024R1R8+ligand.json"], [str(self.test_features_dir)], "+", ) @@ -1166,9 +1461,10 @@ def test_af3_splits_discontinuous_chopped_regions_into_separate_chains(self): self.assertLen(all_interactors, 1) self.assertLen(all_interactors[0], 2) - object_to_model = MultimericObject(interactors=all_interactors[0], pair_msa=True) - objects_to_model = [{"object": object_to_model, "output_dir": str(self.output_dir)}] - + objects_to_model = [ + {"object": obj, "output_dir": str(self.output_dir)} + for obj in all_interactors[0] + ] mappings = AlphaFold3Backend.prepare_input( objects_to_model=objects_to_model, random_seed=42, @@ -1176,18 +1472,10 @@ def test_af3_splits_discontinuous_chopped_regions_into_separate_chains(self): self.assertLen(mappings, 1) fold_input_obj, _ = next(iter(mappings[0].items())) - expected_sequences = [ - self._get_sequence_for_protein("TEST"), - *self._get_region_sequences( - "A0A075B6L2", - [(1, 10), (2, 5), (12, 15)], - ), - ] - concatenated_chopped_sequence = "".join(expected_sequences[1:]) - actual_sequences = [chain.sequence for chain in fold_input_obj.chains] - self.assertCountEqual(actual_sequences, expected_sequences) - self.assertLen(actual_sequences, 4) - self.assertNotIn(concatenated_chopped_sequence, actual_sequences) + self.assertEqual([chain.id for chain in fold_input_obj.chains], ["A", "L"]) + self.assertIsInstance(fold_input_obj.chains[0], folding_input.ProteinChain) + self.assertIsInstance(fold_input_obj.chains[1], folding_input.Ligand) + self.assertEqual(list(fold_input_obj.chains[1].ccd_ids), ["ATP"]) process_fold_input( fold_input=fold_input_obj, @@ -1197,72 +1485,52 @@ def test_af3_splits_discontinuous_chopped_regions_into_separate_chains(self): ) input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json" with open(input_json, "rt") as handle: - data = json.load(handle) + written = json.load(handle) protein_entries = [ sequence_entry["protein"] - for sequence_entry in data.get("sequences", []) + for sequence_entry in written.get("sequences", []) if "protein" in sequence_entry ] - self.assertLen(protein_entries, 4) - self.assertCountEqual( - [entry["sequence"] for entry in protein_entries], - expected_sequences, - ) - - print("✓ AF3 input expands discontinuous chopped regions into separate chains") - - def test_af3_json_feature_ranges_expand_into_separate_chains(self): - """AF3 JSON feature files with ranges must expand into separate protein chains.""" + ligand_entries = [ + sequence_entry["ligand"] + for sequence_entry in written.get("sequences", []) + if "ligand" in sequence_entry + ] + self.assertLen(protein_entries, 1) + self.assertLen(ligand_entries, 1) + self.assertEqual(ligand_entries[0]["id"], "L") + self.assertEqual(ligand_entries[0]["ccdCodes"], ["ATP"]) + + def test_af3_prepare_input_skips_invalid_json_templates_for_ptm_input(self): + """Malformed inline JSON templates should be dropped instead of crashing AF3.""" + from alphafold3.common import folding_input from 
alphapulldown.folding_backend.alphafold3_backend import ( AlphaFold3Backend, process_fold_input, ) - feature_dir = self.test_features_dir / "af3_features" / "protein" - json_filename = "A0A024R1R8_af3_input.json" - parsed = parse_fold( - [f"{json_filename}:2-5:8-10"], - [str(feature_dir)], - "+", - ) - self.assertEqual( - parsed, - [[ - { - "json_input": str(feature_dir / json_filename), - "regions": [(2, 5), (8, 10)], - } - ]], - ) - - data = create_custom_info(parsed) - all_interactors = create_interactors(data, [str(feature_dir)]) - self.assertLen(all_interactors, 1) - self.assertLen(all_interactors[0], 1) - self.assertIsInstance(all_interactors[0][0], dict) + json_input = self.test_features_dir / "protein_with_ptms.json" + raw_payload = json.loads(json_input.read_text()) + expected_protein = raw_payload["sequences"][0]["protein"] - objects_to_model = [{"object": all_interactors[0][0], "output_dir": str(self.output_dir)}] mappings = AlphaFold3Backend.prepare_input( - objects_to_model=objects_to_model, + objects_to_model=[ + { + "object": {"json_input": str(json_input)}, + "output_dir": str(self.output_dir), + } + ], random_seed=42, ) self.assertLen(mappings, 1) fold_input_obj, _ = next(iter(mappings[0].items())) - json_sequences = self._get_sequence_from_json( - "af3_features/protein/A0A024R1R8_af3_input.json" - ) - self.assertLen(json_sequences, 1) - full_sequence = json_sequences[0][1] - expected_sequences = [ - full_sequence[1:5], - full_sequence[7:10], - ] - self.assertCountEqual( - [chain.sequence for chain in fold_input_obj.chains], - expected_sequences, - ) + self.assertEqual([chain.id for chain in fold_input_obj.chains], ["P"]) + self.assertLen(fold_input_obj.chains, 1) + self.assertIsInstance(fold_input_obj.chains[0], folding_input.ProteinChain) + self.assertEqual(list(fold_input_obj.chains[0].ptms), [("HYS", 1), ("2MG", 15)]) + self.assertEqual(list(fold_input_obj.chains[0].templates), []) process_fold_input( fold_input=fold_input_obj, @@ -1279,49 +1547,541 @@ def test_af3_json_feature_ranges_expand_into_separate_chains(self): for sequence_entry in written.get("sequences", []) if "protein" in sequence_entry ] - self.assertLen(protein_entries, 2) - self.assertCountEqual( - [entry["sequence"] for entry in protein_entries], - expected_sequences, + self.assertLen(protein_entries, 1) + self.assertEqual(protein_entries[0]["id"], "P") + self.assertEqual(protein_entries[0]["sequence"], expected_protein["sequence"]) + self.assertEqual( + protein_entries[0]["modifications"], + expected_protein["modifications"], ) + self.assertEqual(protein_entries[0]["templates"], []) - print("✓ AF3 JSON feature ranges expand into separate chains") + def test_af3_prepare_input_keeps_valid_json_templates(self): + """Valid inline JSON templates should survive prepare_input and JSON write-out.""" + from alphafold3.common import folding_input + from alphapulldown.folding_backend.alphafold3_backend import ( + AlphaFold3Backend, + process_fold_input, + ) - def test_af3_predicts_json_feature_ranges_as_separate_chains(self): - """Run AF3 on a Snakefile-style AF3 JSON feature input with explicit ranges.""" - self._require_af3_functional_environment() - env = self._make_af3_test_env() - flash_impl = self._af3_flash_attention_impl() - feature_dir = self.test_features_dir / "af3_features" / "protein" + json_input = ( + self.test_features_dir + / "af3_features" + / "protein" + / "P61626_af3_input.json" + ) + raw_payload = json.loads(json_input.read_text()) + expected_protein = raw_payload["sequences"][0]["protein"] 
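+        # Fixture assumption: P61626_af3_input.json is expected to ship with at
+        # least one inline mmCIF template; the assertGreater below guards this.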
+ expected_template_count = len(expected_protein["templates"]) + self.assertGreater(expected_template_count, 0) - res = subprocess.run( - [ - sys.executable, - str(self.script_single), - "--input=A0A024R1R8_af3_input.json:2-5:8-10", - f"--output_directory={self.output_dir}", - f"--data_directory={DATA_DIR}", - f"--features_directory={feature_dir}", - "--fold_backend=alphafold3", - f"--flash_attention_implementation={flash_impl}", - "--num_diffusion_samples=1", + mappings = AlphaFold3Backend.prepare_input( + objects_to_model=[ + { + "object": {"json_input": str(json_input)}, + "output_dir": str(self.output_dir), + } ], - capture_output=True, - text=True, - env=env, + random_seed=42, ) - self._runCommonTests(res) + self.assertLen(mappings, 1) + fold_input_obj, _ = next(iter(mappings[0].items())) + + self.assertEqual([chain.id for chain in fold_input_obj.chains], ["A"]) + self.assertLen(fold_input_obj.chains, 1) + self.assertIsInstance(fold_input_obj.chains[0], folding_input.ProteinChain) + self.assertLen(fold_input_obj.chains[0].templates, expected_template_count) + + process_fold_input( + fold_input=fold_input_obj, + model_runner=None, + output_dir=str(self.output_dir), + buckets=(512,), + ) + input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json" + with open(input_json, "rt") as handle: + written = json.load(handle) + + protein_entries = [ + sequence_entry["protein"] + for sequence_entry in written.get("sequences", []) + if "protein" in sequence_entry + ] + self.assertLen(protein_entries, 1) + self.assertEqual(protein_entries[0]["id"], "A") + self.assertEqual( + len(protein_entries[0]["templates"]), + expected_template_count, + ) + self.assertTrue( + all(template["mmcif"] for template in protein_entries[0]["templates"]) + ) + self.assertTrue( + all(template["queryIndices"] for template in protein_entries[0]["templates"]) + ) + self.assertTrue( + all(template["templateIndices"] for template in protein_entries[0]["templates"]) + ) + + def test_af3_viewer_output_renumbers_gapped_residue_ids_for_viewers(self): + """Viewer-safe AF3 output must use sequential label IDs for gapped chains.""" + from alphafold3.common import folding_input + from alphafold3.constants import chemical_components + from alphafold3.model import model as af3_model + from alphapulldown.folding_backend.alphafold3_backend import ( + _make_viewer_compatible_inference_result, + ) + + original_residue_ids = [2, 3, 4, 5, 8, 9, 10] + chain = folding_input.ProteinChain( + id="A", + sequence="ACDEFGH", + ptms=[], + residue_ids=original_residue_ids, + unpaired_msa="", + paired_msa="", + templates=[], + ) + fold_input = folding_input.Input( + name="gapped_residue_ids_for_viewers", + chains=[chain], + rng_seeds=[1], + ) + struc = fold_input.to_structure(ccd=chemical_components.Ccd()) + inference_result = af3_model.InferenceResult( + predicted_structure=struc, + metadata={ + "token_chain_ids": ["A"] * len(original_residue_ids), + "token_res_ids": original_residue_ids, + }, + ) + + viewer_result = _make_viewer_compatible_inference_result(inference_result) + + self.assertEqual( + viewer_result.predicted_structure.present_residues.id.tolist(), + list(range(1, len(original_residue_ids) + 1)), + ) + self.assertEqual( + viewer_result.metadata["token_res_ids"], + list(range(1, len(original_residue_ids) + 1)), + ) + self.assertEqual( + viewer_result.predicted_structure.residues_table.auth_seq_id.tolist(), + [str(residue_id) for residue_id in original_residue_ids], + ) + self.assertEqual( + 
viewer_result.predicted_structure.residues_table.insertion_code.tolist(), + ["."] * len(original_residue_ids), + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_ids"], + [str(residue_id) for residue_id in original_residue_ids], + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_labels"], + [str(residue_id) for residue_id in original_residue_ids], + ) + + def test_af3_viewer_output_uses_insertion_codes_for_duplicate_residue_ids(self): + """Viewer-safe AF3 output must preserve IDs and disambiguate with insertions.""" + from alphafold3.common import folding_input + from alphafold3.constants import chemical_components + from alphafold3.model import model as af3_model + from alphapulldown.folding_backend.alphafold3_backend import ( + _make_viewer_compatible_inference_result, + ) + + original_residue_ids = ( + list(range(1, 11)) + list(range(2, 6)) + list(range(12, 16)) + ) + chain = folding_input.ProteinChain( + id="A", + sequence="ACDEFGHIKLCDEFMNPQ", + ptms=[], + residue_ids=original_residue_ids, + unpaired_msa="", + paired_msa="", + templates=[], + ) + fold_input = folding_input.Input( + name="duplicate_residue_ids_for_chimerax", + chains=[chain], + rng_seeds=[1], + ) + struc = fold_input.to_structure(ccd=chemical_components.Ccd()) + inference_result = af3_model.InferenceResult( + predicted_structure=struc, + metadata={ + "token_chain_ids": ["A"] * len(original_residue_ids), + "token_res_ids": original_residue_ids, + }, + ) + + viewer_result = _make_viewer_compatible_inference_result( + inference_result + ) + + self.assertEqual( + viewer_result.predicted_structure.present_residues.id.tolist(), + list(range(1, len(original_residue_ids) + 1)), + ) + self.assertEqual( + viewer_result.metadata["token_res_ids"], + list(range(1, len(original_residue_ids) + 1)), + ) + self.assertEqual( + viewer_result.predicted_structure.residues_table.auth_seq_id.tolist(), + [str(residue_id) for residue_id in original_residue_ids], + ) + self.assertEqual( + viewer_result.predicted_structure.residues_table.insertion_code.tolist(), + ['.'] * 10 + ['A'] * 4 + ['.'] * 4, + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_ids"], + [str(residue_id) for residue_id in original_residue_ids], + ) + self.assertEqual( + viewer_result.metadata["token_pdb_ins_codes"], + ['.'] * 10 + ['A'] * 4 + ['.'] * 4, + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_labels"], + [str(i) for i in range(1, 11)] + + [f"{i}A" for i in range(2, 6)] + + [str(i) for i in range(12, 16)], + ) + + def test_af3_viewer_output_handles_many_tokens_for_one_residue(self): + """Viewer metadata must not crash when many tokens map to one residue.""" + from alphafold3.common import folding_input + from alphafold3.constants import chemical_components + from alphafold3.model import model as af3_model + from alphapulldown.folding_backend.alphafold3_backend import ( + _make_viewer_compatible_inference_result, + ) + + chain = folding_input.ProteinChain( + id="L", + sequence="A", + ptms=[], + residue_ids=[1], + unpaired_msa="", + paired_msa="", + templates=[], + ) + fold_input = folding_input.Input( + name="many_tokens_one_residue", + chains=[chain], + rng_seeds=[1], + ) + struc = fold_input.to_structure(ccd=chemical_components.Ccd()) + token_count = 40 + inference_result = af3_model.InferenceResult( + predicted_structure=struc, + metadata={ + "token_chain_ids": ["L"] * token_count, + "token_res_ids": [1] * token_count, + }, + ) + + viewer_result = _make_viewer_compatible_inference_result(inference_result) + 
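+        # Expected viewer mapping: label residue IDs are renumbered 1..N, the
+        # original duplicate IDs move to auth_seq_id, and the second pass over
+        # residues 2-5 is disambiguated with insertion code 'A'.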
+ self.assertEqual( + viewer_result.metadata["token_res_ids"], + list(range(1, token_count + 1)), + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_ids"], + ["1"] * token_count, + ) + self.assertEqual( + viewer_result.metadata["token_pdb_ins_codes"][:27], + ["."] + [chr(ord("A") + index) for index in range(26)], + ) + self.assertEqual( + viewer_result.metadata["token_pdb_ins_codes"][27:], + ["."] * (token_count - 27), + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_labels"][:27], + ["1"] + [f"1{chr(ord('A') + index)}" for index in range(26)], + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_labels"][27], + "1[28]", + ) + self.assertEqual( + viewer_result.metadata["token_auth_res_labels"][-1], + "1[40]", + ) + + def test_af3_keeps_discontinuous_chopped_regions_in_one_gapped_chain(self): + """AF3 must keep multi-region chopped inputs as one gapped protein chain.""" + from alphapulldown.folding_backend.alphafold3_backend import ( + AlphaFold3Backend, + process_fold_input, + ) + + parsed = parse_fold( + ["TEST+A0A075B6L2:1-10:2-5:12-15"], + [str(self.test_features_dir)], + "+", + ) + data = create_custom_info(parsed) + all_interactors = create_interactors(data, [str(self.test_features_dir)]) + self.assertLen(all_interactors, 1) + self.assertLen(all_interactors[0], 2) + + object_to_model = MultimericObject(interactors=all_interactors[0], pair_msa=True) + objects_to_model = [{"object": object_to_model, "output_dir": str(self.output_dir)}] + + mappings = AlphaFold3Backend.prepare_input( + objects_to_model=objects_to_model, + random_seed=42, + ) + self.assertLen(mappings, 1) + fold_input_obj, _ = next(iter(mappings[0].items())) + + chopped_region_sequences = self._get_region_sequences( + "A0A075B6L2", + [(1, 10), (2, 5), (12, 15)], + ) + concatenated_chopped_sequence = "".join(chopped_region_sequences) + expected_sequences = [ + self._get_sequence_for_protein("TEST"), + concatenated_chopped_sequence, + ] + expected_chopped_residue_ids = ( + list(range(1, 11)) + + [2, 3, 4, 5] + + list(range(12, 16)) + ) + actual_sequences = [chain.sequence for chain in fold_input_obj.chains] + self.assertCountEqual(actual_sequences, expected_sequences) + self.assertLen(actual_sequences, 2) + + chopped_chains = [ + chain for chain in fold_input_obj.chains + if chain.sequence == concatenated_chopped_sequence + ] + self.assertLen(chopped_chains, 1) + self.assertEqual( + list(chopped_chains[0].residue_ids), + expected_chopped_residue_ids, + ) + + process_fold_input( + fold_input=fold_input_obj, + model_runner=None, + output_dir=str(self.output_dir), + buckets=(512,), + ) + input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json" + with open(input_json, "rt") as handle: + data = json.load(handle) + + protein_entries = [ + sequence_entry["protein"] + for sequence_entry in data.get("sequences", []) + if "protein" in sequence_entry + ] + self.assertLen(protein_entries, 2) + self.assertCountEqual( + [entry["sequence"] for entry in protein_entries], + expected_sequences, + ) + chopped_entries = [ + entry for entry in protein_entries + if entry["sequence"] == concatenated_chopped_sequence + ] + self.assertLen(chopped_entries, 1) + self.assertEqual( + chopped_entries[0]["residueIds"], + expected_chopped_residue_ids, + ) + + print("✓ AF3 input keeps discontinuous chopped regions as one gapped chain") + + def test_af3_keeps_two_out_of_order_gapped_copies_as_two_chains(self): + """AF3 must keep two copied out-of-order gapped regions as two chains.""" + from 
alphapulldown.folding_backend.alphafold3_backend import ( + AlphaFold3Backend, + process_fold_input, + ) + + parsed = parse_fold( + ["A0A075B6L2:2:8-10:2-5"], + [str(self.test_features_dir)], + "+", + ) + + data = create_custom_info(parsed) + all_interactors = create_interactors(data, [str(self.test_features_dir)]) + self.assertLen(all_interactors, 1) + self.assertLen(all_interactors[0], 2) + + objects_to_model = [{"object": all_interactors[0], "output_dir": str(self.output_dir)}] + mappings = AlphaFold3Backend.prepare_input( + objects_to_model=objects_to_model, + random_seed=42, + ) + self.assertLen(mappings, 1) + fold_input_obj, _ = next(iter(mappings[0].items())) + + expected_regions = [(8, 10), (2, 5)] + expected_sequence = "".join( + self._get_region_sequences("A0A075B6L2", expected_regions) + ) + expected_residue_ids = [8, 9, 10, 2, 3, 4, 5] + + self.assertEqual( + [chain.id for chain in fold_input_obj.chains], + ["A", "B"], + ) + self.assertEqual( + [chain.sequence for chain in fold_input_obj.chains], + [expected_sequence, expected_sequence], + ) + self.assertEqual( + [list(chain.residue_ids) for chain in fold_input_obj.chains], + [expected_residue_ids, expected_residue_ids], + ) + + process_fold_input( + fold_input=fold_input_obj, + model_runner=None, + output_dir=str(self.output_dir), + buckets=(512,), + ) + input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json" + with open(input_json, "rt") as handle: + written = json.load(handle) + + protein_entries = [ + sequence_entry["protein"] + for sequence_entry in written.get("sequences", []) + if "protein" in sequence_entry + ] + self.assertLen(protein_entries, 1) + self.assertEqual(protein_entries[0]["id"], ["A", "B"]) + self.assertEqual(protein_entries[0]["sequence"], expected_sequence) + self.assertEqual(protein_entries[0]["residueIds"], expected_residue_ids) + + print("✓ AF3 input keeps two copied out-of-order gapped regions as two chains") + + def test_af3_json_feature_ranges_collapse_into_one_gapped_chain(self): + """AF3 JSON feature files with ranges must collapse into one gapped chain.""" + from alphapulldown.folding_backend.alphafold3_backend import ( + AlphaFold3Backend, + process_fold_input, + ) + + feature_dir = self.test_features_dir / "af3_features" / "protein" + json_filename = "A0A024R1R8_af3_input.json" + parsed = parse_fold( + [f"{json_filename}:2-5:8-10"], + [str(feature_dir)], + "+", + ) + self.assertEqual( + parsed, + [[ + { + "json_input": str(feature_dir / json_filename), + "regions": [(2, 5), (8, 10)], + } + ]], + ) + + data = create_custom_info(parsed) + all_interactors = create_interactors(data, [str(feature_dir)]) + self.assertLen(all_interactors, 1) + self.assertLen(all_interactors[0], 1) + self.assertIsInstance(all_interactors[0][0], dict) + + objects_to_model = [{"object": all_interactors[0][0], "output_dir": str(self.output_dir)}] + mappings = AlphaFold3Backend.prepare_input( + objects_to_model=objects_to_model, + random_seed=42, + ) + self.assertLen(mappings, 1) + fold_input_obj, _ = next(iter(mappings[0].items())) json_sequences = self._get_sequence_from_json( "af3_features/protein/A0A024R1R8_af3_input.json" ) self.assertLen(json_sequences, 1) full_sequence = json_sequences[0][1] - expected_sequences = [ - full_sequence[1:5], - full_sequence[7:10], + expected_sequence = full_sequence[1:5] + full_sequence[7:10] + expected_residue_ids = [2, 3, 4, 5, 8, 9, 10] + self.assertEqual( + [chain.sequence for chain in fold_input_obj.chains], + [expected_sequence], + ) + self.assertEqual( + 
fold_input_obj.sanitised_name(), + "A0A024R1R8__2-5_8-10", + ) + self.assertEqual( + [list(chain.residue_ids) for chain in fold_input_obj.chains], + [expected_residue_ids], + ) + + process_fold_input( + fold_input=fold_input_obj, + model_runner=None, + output_dir=str(self.output_dir), + buckets=(512,), + ) + input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json" + with open(input_json, "rt") as handle: + written = json.load(handle) + + protein_entries = [ + sequence_entry["protein"] + for sequence_entry in written.get("sequences", []) + if "protein" in sequence_entry ] - concatenated_sequence = "".join(expected_sequences) + self.assertLen(protein_entries, 1) + self.assertEqual(protein_entries[0]["sequence"], expected_sequence) + self.assertEqual(protein_entries[0]["residueIds"], expected_residue_ids) + + print("✓ AF3 JSON feature ranges collapse into one gapped chain") + + def test_af3_predicts_json_feature_ranges_as_one_gapped_chain(self): + """Run AF3 on a Snakefile-style AF3 JSON feature input with explicit ranges.""" + self._require_af3_functional_environment() + env = self._make_af3_test_env() + flash_impl = self._af3_flash_attention_impl() + feature_dir = self.test_features_dir / "af3_features" / "protein" + + res = subprocess.run( + [ + sys.executable, + str(self.script_single), + "--input=A0A024R1R8_af3_input.json:2-5:8-10", + f"--output_directory={self.output_dir}", + f"--data_directory={DATA_DIR}", + f"--features_directory={feature_dir}", + "--fold_backend=alphafold3", + f"--flash_attention_implementation={flash_impl}", + "--num_diffusion_samples=1", + ], + capture_output=True, + text=True, + env=env, + ) + self._runCommonTests(res) + + json_sequences = self._get_sequence_from_json( + "af3_features/protein/A0A024R1R8_af3_input.json" + ) + self.assertLen(json_sequences, 1) + full_sequence = json_sequences[0][1] + expected_sequence = full_sequence[1:5] + full_sequence[7:10] + expected_residue_ids = [2, 3, 4, 5, 8, 9, 10] result_dir = self._resolve_single_af3_result_dir() cif_files = list(result_dir.glob("*_model.cif")) @@ -1329,15 +2089,15 @@ def test_af3_predicts_json_feature_ranges_as_separate_chains(self): actual_chains_and_sequences = self._extract_cif_chains_and_sequences(cif_files[0]) actual_sequences = [sequence for _, sequence in actual_chains_and_sequences] + actual_residue_numbers = self._extract_cif_chain_residue_numbers(cif_files[0]) - self.assertLen(actual_sequences, 2) - self.assertCountEqual(actual_sequences, expected_sequences) - self.assertNotIn(concatenated_sequence, actual_sequences) + self.assertEqual(actual_sequences, [expected_sequence]) + self.assertEqual(actual_residue_numbers, [("A", expected_residue_ids)]) - print("✓ AF3 prediction keeps AF3 JSON feature ranges as separate chains") + print("✓ AF3 prediction keeps AF3 JSON feature ranges as one gapped chain") - def test_af3_predicts_discontinuous_chopped_regions_as_separate_chains(self): - """Run AF3 inference and ensure discontinuous chopped regions remain separate chains.""" + def test_af3_predicts_discontinuous_chopped_regions_as_one_gapped_chain(self): + """Run AF3 inference and ensure discontinuous chopped regions remain one chain.""" self._require_af3_functional_environment() env = self._make_af3_test_env() flash_impl = self._af3_flash_attention_impl() @@ -1360,14 +2120,20 @@ def test_af3_predicts_discontinuous_chopped_regions_as_separate_chains(self): ) self._runCommonTests(res) + chopped_region_sequences = self._get_region_sequences( + "A0A075B6L2", + [(1, 10), (2, 5), (12, 
15)], + ) + concatenated_chopped_sequence = "".join(chopped_region_sequences) expected_sequences = [ self._get_sequence_for_protein("TEST"), - *self._get_region_sequences( - "A0A075B6L2", - [(1, 10), (2, 5), (12, 15)], - ), + concatenated_chopped_sequence, ] - concatenated_chopped_sequence = "".join(expected_sequences[1:]) + expected_chopped_residue_ids = ( + list(range(1, 11)) + + ["2A", "3A", "4A", "5A"] + + list(range(12, 16)) + ) result_dir = self._resolve_single_af3_result_dir() cif_files = list(result_dir.glob("*_model.cif")) @@ -1375,27 +2141,91 @@ def test_af3_predicts_discontinuous_chopped_regions_as_separate_chains(self): actual_chains_and_sequences = self._extract_cif_chains_and_sequences(cif_files[0]) actual_sequences = [sequence for _, sequence in actual_chains_and_sequences] + residue_numbers_by_chain = dict(self._extract_cif_chain_residue_numbers(cif_files[0])) + sequences_by_chain = dict(actual_chains_and_sequences) - self.assertLen(actual_sequences, 4) + self.assertLen(actual_sequences, 2) self.assertCountEqual(actual_sequences, expected_sequences) - self.assertNotIn(concatenated_chopped_sequence, actual_sequences) + chopped_chain_ids = [ + chain_id + for chain_id, sequence in sequences_by_chain.items() + if sequence == concatenated_chopped_sequence + ] + self.assertLen(chopped_chain_ids, 1) + self.assertEqual( + residue_numbers_by_chain[chopped_chain_ids[0]], + expected_chopped_residue_ids, + ) - print("✓ AF3 prediction keeps discontinuous chopped regions as separate chains") + print("✓ AF3 prediction keeps discontinuous chopped regions as one gapped chain") - def test_dimer_chopped_expected_sequences_are_split_by_region(self): - """Sequence expectations for AF3 chopped inputs must reflect chain splitting.""" + def test_af3_predicts_two_out_of_order_gapped_copies_as_two_chains(self): + """Run AF3 inference and ensure copied out-of-order gapped regions remain two chains.""" + self._require_af3_functional_environment() + env = self._make_af3_test_env() + flash_impl = self._af3_flash_attention_impl() + + res = subprocess.run( + [ + sys.executable, + str(self.script_single), + "--input=A0A075B6L2:2:8-10:2-5", + f"--output_directory={self.output_dir}", + f"--data_directory={DATA_DIR}", + f"--features_directory={self.test_features_dir}", + "--fold_backend=alphafold3", + f"--flash_attention_implementation={flash_impl}", + "--num_diffusion_samples=1", + ], + capture_output=True, + text=True, + env=env, + ) + self._runCommonTests(res) + + expected_regions = [(8, 10), (2, 5)] + expected_sequence = "".join( + self._get_region_sequences("A0A075B6L2", expected_regions) + ) + expected_residue_ids = [8, 9, 10, 2, 3, 4, 5] + + result_dir = self._resolve_single_af3_result_dir() + cif_files = list(result_dir.glob("*_model.cif")) + self.assertTrue(cif_files, f"No predicted CIF files found in {result_dir}") + + actual_chains_and_sequences = self._extract_cif_chains_and_sequences(cif_files[0]) + residue_numbers_by_chain = dict(self._extract_cif_chain_residue_numbers(cif_files[0])) + + self.assertEqual( + [sequence for _, sequence in actual_chains_and_sequences], + [expected_sequence, expected_sequence], + ) + self.assertEqual( + [chain_id for chain_id, _ in actual_chains_and_sequences], + ["A", "B"], + ) + self.assertEqual(residue_numbers_by_chain["A"], expected_residue_ids) + self.assertEqual(residue_numbers_by_chain["B"], expected_residue_ids) + + print("✓ AF3 prediction keeps two copied out-of-order gapped regions as two chains") + + def 
test_dimer_chopped_expected_sequences_are_concatenated_per_chain(self): + """Sequence expectations for AF3 chopped inputs must reflect one gapped chain.""" expected_sequences = self._extract_expected_sequences("test_dimer_chopped.txt") + chopped_sequence = "".join( + self._get_region_sequences( + "A0A075B6L2", + [(1, 10), (2, 5), (12, 15)], + ) + ) self.assertCountEqual( [sequence for _, sequence in expected_sequences], [ self._get_sequence_for_protein("TEST"), - *self._get_region_sequences( - "A0A075B6L2", - [(1, 10), (2, 5), (12, 15)], - ), + chopped_sequence, ], ) - self.assertLen(expected_sequences, 4) + self.assertLen(expected_sequences, 2) def test_multi_seeds_samples_sequence_extraction(self): """Test that sequence extraction works correctly for multi_seeds_samples.""" @@ -1408,77 +2238,79 @@ def test_multi_seeds_samples_sequence_extraction(self): def test_multi_seeds_samples_output_validation(self): """Test that the multi_seeds_samples output files are correct.""" - # Set up the test to use the existing output directory - test_name = "test__multi_seeds_samples" - output_dir = Path("test/test_data/predictions/af3_backend") / test_name - - if not output_dir.exists(): - self.skipTest(f"Output directory {output_dir} does not exist. Run the full test first.") - - # Temporarily set the output directory - original_output_dir = self.output_dir - self.output_dir = output_dir - - try: - # Check that all expected files exist - files = list(self.output_dir.iterdir()) - - # Check for main output files - self.assertIn("TERMS_OF_USE.md", {f.name for f in files}) - self.assertIn("ranking_scores.csv", {f.name for f in files}) - - # Check for AlphaFold3 output files - conf_files = [f for f in files if f.name.endswith("_confidences.json")] - summary_conf_files = [f for f in files if f.name.endswith("_summary_confidences.json")] - model_files = [f for f in files if f.name.endswith("_model.cif")] - - self.assertTrue(len(conf_files) > 0, "No confidences.json files found") - self.assertTrue(len(summary_conf_files) > 0, "No summary_confidences.json files found") - self.assertTrue(len(model_files) > 0, "No model.cif files found") - - # Check sample directories (should be 12: 3 seeds × 4 samples) - sample_dirs = [f for f in files if f.is_dir() and f.name.startswith("seed-")] - self.assertEqual(len(sample_dirs), 12, - f"Expected 12 sample directories, found {len(sample_dirs)}") - - # Check each sample directory has the required files - for sample_dir in sample_dirs: - sample_files = list(sample_dir.iterdir()) - self.assertIn("confidences.json", {f.name for f in sample_files}) - self.assertIn("model.cif", {f.name for f in sample_files}) - self.assertIn("summary_confidences.json", {f.name for f in sample_files}) - - # Verify ranking scores - with open(self.output_dir / "ranking_scores.csv") as f: - lines = f.readlines() - self.assertTrue(len(lines) > 1, "ranking_scores.csv should have header and data") - self.assertEqual(len(lines[0].strip().split(",")), 3, "ranking_scores.csv should have 3 columns") - - # Should have 12 entries + 1 header = 13 lines - expected_lines = 13 - self.assertEqual(len(lines), expected_lines, - f"Expected {expected_lines} lines in ranking_scores.csv, found {len(lines)}") - - # Verify CSV format for all data lines - for i, line in enumerate(lines[1:], 1): # Skip header - parts = line.strip().split(",") - self.assertEqual(len(parts), 3, f"Line {i+1} should have 3 columns: seed,sample,ranking_score") - # Verify that seed, sample are integers and ranking_score is a float - try: - 
int(parts[0]) # seed - int(parts[1]) # sample - float(parts[2]) # ranking_score - except ValueError: - self.fail(f"Line {i+1} has invalid format: {line.strip()}") - - # Check chain counts and sequences - self._check_chain_counts_and_sequences("test_multi_seeds_samples.txt") - - print(f"✓ Verified multi_seeds_samples output with {len(sample_dirs)} sample directories and {len(lines)-1} ranking score entries") - - finally: - # Restore original output directory - self.output_dir = original_output_dir + if not (self.output_dir / "ranking_scores.csv").exists(): + # Keep this validation test independently runnable under isolated temp dirs. + env = self._make_af3_test_env() + res = subprocess.run( + self._args( + plist="test_multi_seeds_samples.txt", + script="run_structure_prediction.py", + ), + capture_output=True, + text=True, + env=env, + ) + self._runCommonTests(res) + + result_dir = self._resolve_single_af3_result_dir() + files = list(result_dir.iterdir()) + + self.assertIn("TERMS_OF_USE.md", {f.name for f in files}) + self.assertIn("ranking_scores.csv", {f.name for f in files}) + + conf_files = [f for f in files if f.name.endswith("_confidences.json")] + summary_conf_files = [f for f in files if f.name.endswith("_summary_confidences.json")] + model_files = [f for f in files if f.name.endswith("_model.cif")] + + self.assertTrue(len(conf_files) > 0, "No confidences.json files found") + self.assertTrue(len(summary_conf_files) > 0, "No summary_confidences.json files found") + self.assertTrue(len(model_files) > 0, "No model.cif files found") + + sample_dirs = [f for f in files if f.is_dir() and f.name.startswith("seed-")] + self.assertEqual( + len(sample_dirs), + 12, + f"Expected 12 sample directories, found {len(sample_dirs)}", + ) + + for sample_dir in sample_dirs: + sample_files = list(sample_dir.iterdir()) + self.assertIn("confidences.json", {f.name for f in sample_files}) + self.assertIn("model.cif", {f.name for f in sample_files}) + self.assertIn("summary_confidences.json", {f.name for f in sample_files}) + + with open(result_dir / "ranking_scores.csv") as f: + lines = f.readlines() + self.assertTrue(len(lines) > 1, "ranking_scores.csv should have header and data") + self.assertEqual(len(lines[0].strip().split(",")), 3, "ranking_scores.csv should have 3 columns") + + expected_lines = 13 + self.assertEqual( + len(lines), + expected_lines, + f"Expected {expected_lines} lines in ranking_scores.csv, found {len(lines)}", + ) + + for i, line in enumerate(lines[1:], 1): + parts = line.strip().split(",") + self.assertEqual( + len(parts), + 3, + f"Line {i+1} should have 3 columns: seed,sample,ranking_score", + ) + try: + int(parts[0]) + int(parts[1]) + float(parts[2]) + except ValueError: + self.fail(f"Line {i+1} has invalid format: {line.strip()}") + + self._check_chain_counts_and_sequences("test_multi_seeds_samples.txt") + + print( + f"✓ Verified multi_seeds_samples output with {len(sample_dirs)} sample " + f"directories and {len(lines)-1} ranking score entries" + ) def test_af3_run_structure_prediction_keeps_single_explicit_output_dir_flat_for_json(self): """A single explicit output dir must remain flat even with --use_ap_style.""" @@ -1690,6 +2522,12 @@ def test_af3_writes_embeddings_and_distogram(self): def test_af3_num_recycles_affects_runtime(self): """num_recycles=1 should be faster than default (keeping other knobs same).""" + if os.getenv("AF3_RUN_PERF_TESTS", "").lower() not in ("1", "true", "yes"): + self.skipTest( + "Set AF3_RUN_PERF_TESTS=1 to run AF3 runtime benchmarks." 
+ ) + + self._require_af3_functional_environment() env = self._make_af3_test_env() flash_impl = self._af3_flash_attention_impl() diff --git a/test/test_alphafold2_predictions.py b/test/test_alphafold2_predictions.py new file mode 100644 index 00000000..d20d0e55 --- /dev/null +++ b/test/test_alphafold2_predictions.py @@ -0,0 +1,685 @@ +#!/usr/bin/env python3 +"""Submit AlphaFold2 functional tests to Slurm and summarize results. + +This is a standalone wrapper for `test/check_alphafold2_predictions.py`. +It is intentionally not a pytest test module, despite the filename. + +Typical usage from a login node: + + python test/test_alphafold2_predictions.py + +Run only selected tests: + + python test/test_alphafold2_predictions.py -k dimer +""" + +from __future__ import annotations + +__test__ = False + +import argparse +import dataclasses +import datetime as dt +import importlib.util +import inspect +import json +import re +import shlex +import subprocess +import sys +import time +import unittest +from pathlib import Path +from typing import Iterable + +from _pytest.mark.expression import Expression + + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_TEST_FILE = REPO_ROOT / "test" / "check_alphafold2_predictions.py" +DEFAULT_LOG_ROOT = REPO_ROOT / "test_logs" + +PASS_STATES = {"COMPLETED"} +FAIL_STATES = { + "BOOT_FAIL", + "CANCELLED", + "DEADLINE", + "FAILED", + "NODE_FAIL", + "OUT_OF_MEMORY", + "PREEMPTED", + "REVOKED", + "TIMEOUT", +} + + +@dataclasses.dataclass(slots=True) +class JobSpec: + index: int + nodeid: str + slug: str + stdout_path: Path + stderr_path: Path + script_path: Path + rerun_command: str + job_id: str | None = None + slurm_state: str | None = None + exit_code: str | None = None + outcome: str | None = None + reason: str | None = None + + +def _has_cmd(cmd: str) -> bool: + try: + subprocess.run( + [cmd, "--help"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) + return True + except FileNotFoundError: + return False + + +def _run( + cmd: list[str], + *, + cwd: Path = REPO_ROOT, + check: bool = True, +) -> subprocess.CompletedProcess[str]: + return subprocess.run( + cmd, + cwd=cwd, + text=True, + capture_output=True, + check=check, + ) + + +def _normalize_state(state: str | None) -> str | None: + if not state: + return None + return state.split()[0].rstrip("+") + + +def _slugify(value: str, *, max_len: int = 120) -> str: + slug = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") + if not slug: + slug = "test" + if len(slug) > max_len: + slug = slug[:max_len].rstrip("._") + return slug + + +def _quote(value: str) -> str: + return shlex.quote(value) + + +def _timestamp() -> str: + return dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + + +def _relative_nodeid_prefix(test_file: Path) -> str: + return str(test_file.resolve().relative_to(REPO_ROOT)) + + +def _matches_k_expression(nodeid: str, k_expr: str | None) -> bool: + if not k_expr: + return True + expression = Expression.compile(k_expr) + lowered = nodeid.lower() + return expression.evaluate(lambda token: token.lower() in lowered) + + +def _collect_nodeids_from_module_import(test_file: Path, k_expr: str | None) -> list[str]: + module_name = f"_codex_collect_{test_file.stem}" + spec = importlib.util.spec_from_file_location(module_name, test_file) + if spec is None or spec.loader is None: + raise RuntimeError(f"Failed to create import spec for {test_file}") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) 
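+    # exec_module may import heavy test dependencies; the finally block below
+    # unregisters the throwaway module so collection leaves sys.modules clean.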
+ finally: + sys.modules.pop(module_name, None) + + prefix = _relative_nodeid_prefix(test_file) + nodeids: list[str] = [] + for _, cls in inspect.getmembers(module, inspect.isclass): + if cls.__module__ != module.__name__: + continue + if not issubclass(cls, unittest.TestCase): + continue + if not cls.__name__.startswith("Test"): + continue + + for method_name in sorted(name for name in dir(cls) if name.startswith("test")): + nodeid = f"{prefix}::{cls.__name__}::{method_name}" + if _matches_k_expression(nodeid, k_expr): + nodeids.append(nodeid) + return nodeids + + +def collect_nodeids( + *, + python_executable: str, + test_file: Path, + k_expr: str | None, +) -> list[str]: + cmd = [ + python_executable, + "-m", + "pytest", + "--collect-only", + "-q", + str(test_file), + ] + if k_expr: + cmd.extend(["-k", k_expr]) + result = _run(cmd, check=False) + if result.returncode != 0: + raise RuntimeError( + "pytest collection failed.\n" + f"Command: {' '.join(_quote(part) for part in cmd)}\n\n" + f"STDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}" + ) + + nodeids: list[str] = [] + for raw_line in result.stdout.splitlines(): + line = raw_line.strip() + if not line: + continue + if ".py::" not in line: + continue + if line.startswith("ERROR ") or line.startswith("SKIPPED "): + continue + nodeids.append(line) + if nodeids: + return nodeids + + return _collect_nodeids_from_module_import(test_file, k_expr) + + +def write_job_script( + *, + job: JobSpec, + python_executable: str, + use_temp_dir: bool, +) -> None: + pytest_cmd = [ + python_executable, + "-m", + "pytest", + "-vv", + "-s", + job.nodeid, + ] + if use_temp_dir: + pytest_cmd.append("--use-temp-dir") + + script = "\n".join( + [ + "#!/bin/bash", + "set -euo pipefail", + f"cd {_quote(str(REPO_ROOT))}", + "export PYTHONUNBUFFERED=1", + "echo \"[$(date)] Running test node:\"", + f"echo {_quote(job.nodeid)}", + "echo \"[$(date)] Host: $(hostname)\"", + "echo \"[$(date)] Python: $(which python || true)\"", + " ".join(_quote(part) for part in pytest_cmd), + "", + ] + ) + job.script_path.write_text(script, encoding="utf-8") + job.script_path.chmod(0o755) + + +def submit_job(job: JobSpec, args: argparse.Namespace) -> str: + cmd = [ + "sbatch", + "--parsable", + "--export=ALL", + f"--job-name={args.job_name_prefix}_{job.index:03d}", + f"--chdir={REPO_ROOT}", + f"--output={job.stdout_path}", + f"--error={job.stderr_path}", + f"--time={args.time}", + "--ntasks=1", + f"--cpus-per-task={args.cpus_per_task}", + f"--mem={args.mem}", + ] + if args.partition: + cmd.append(f"--partition={args.partition}") + if args.qos: + cmd.append(f"--qos={args.qos}") + if args.constraint: + cmd.append(f"--constraint={args.constraint}") + if args.account: + cmd.append(f"--account={args.account}") + if args.gres: + cmd.append(f"--gres={args.gres}") + for extra_arg in args.extra_sbatch_arg: + cmd.append(extra_arg) + cmd.append(str(job.script_path)) + + result = _run(cmd) + raw_job_id = result.stdout.strip().splitlines()[-1] + return raw_job_id.split(";", 1)[0] + + +def active_job_ids(job_ids: Iterable[str]) -> set[str]: + job_ids = [job_id for job_id in job_ids if job_id] + if not job_ids: + return set() + + result = _run( + [ + "squeue", + "-h", + "-j", + ",".join(job_ids), + "-o", + "%A", + ], + check=False, + ) + if result.returncode != 0: + return set() + return {line.strip() for line in result.stdout.splitlines() if line.strip()} + + +def query_sacct(job_id: str) -> tuple[str | None, str | None]: + if not _has_cmd("sacct"): + return None, None + + result = _run( + [ + 
"sacct", + "-X", + "-n", + "-P", + "-j", + job_id, + "-o", + "JobIDRaw,State,ExitCode", + ], + check=False, + ) + if result.returncode != 0: + return None, None + + for line in result.stdout.splitlines(): + parts = line.strip().split("|") + if len(parts) < 3: + continue + job_id_raw, state, exit_code = parts[:3] + if job_id_raw == job_id: + return _normalize_state(state), exit_code + return None, None + + +def wait_for_jobs(jobs: list[JobSpec], *, poll_interval: int, timeout_seconds: int | None) -> None: + outstanding = {job.job_id for job in jobs if job.job_id} + start = time.monotonic() + previous_remaining = len(outstanding) + + while outstanding: + if timeout_seconds is not None and (time.monotonic() - start) > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for {len(outstanding)} Slurm job(s): " + + ", ".join(sorted(outstanding)) + ) + + active = active_job_ids(outstanding) + finished = outstanding - active + if finished: + outstanding = active + + remaining = len(outstanding) + if remaining != previous_remaining or finished: + done = len(jobs) - remaining + print(f"[wait] {done}/{len(jobs)} jobs finished, {remaining} remaining", flush=True) + previous_remaining = remaining + + if outstanding: + time.sleep(poll_interval) + + +def _combined_log_text(job: JobSpec) -> str: + parts: list[str] = [] + if job.stdout_path.exists(): + parts.append(job.stdout_path.read_text(encoding="utf-8", errors="replace")) + if job.stderr_path.exists(): + stderr_text = job.stderr_path.read_text(encoding="utf-8", errors="replace") + if stderr_text: + parts.append(stderr_text) + return "\n".join(parts) + + +def _extract_reason_from_log(text: str) -> str: + patterns = [ + r"short test summary info[\s\S]*$", + r"=+ FAILURES =+[\s\S]*$", + r"Traceback[\s\S]*$", + r"(?m)^E\s+.*$", + r"(?m)^FAILED .*$", + r"(?m)^ERROR .*$", + r"(?m)^.*Killed.*$", + r"(?m)^.*PASSED.*$", + r"(?m)^.*SKIPPED.*$", + ] + for pattern in patterns: + match = re.search(pattern, text) + if match: + snippet = match.group(0).strip() + if len(snippet) > 1200: + snippet = snippet[-1200:] + return snippet + + non_empty_lines = [line.rstrip() for line in text.splitlines() if line.strip()] + if not non_empty_lines: + return "No log output captured." 
+ return "\n".join(non_empty_lines[-20:]) + + +def classify_job(job: JobSpec) -> None: + job.slurm_state, job.exit_code = query_sacct(job.job_id or "") + text = _combined_log_text(job) + state = job.slurm_state + + if state in FAIL_STATES: + job.outcome = "FAILED" + job.reason = f"Slurm state: {state}\n{_extract_reason_from_log(text)}" + return + + if re.search(r"(?im)\bkilled\b", text): + job.outcome = "FAILED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"(?m)^FAILED ", text) or re.search(r"(?m)^ERROR ", text): + job.outcome = "FAILED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"=+ FAILURES =+", text) or "Traceback" in text: + job.outcome = "FAILED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"(?i)\b\d+\s+skipped\b", text) or " SKIPPED" in text: + job.outcome = "SKIPPED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"(?i)\b\d+\s+passed\b", text) or " PASSED" in text: + job.outcome = "PASSED" + job.reason = _extract_reason_from_log(text) + return + + if state in PASS_STATES: + job.outcome = "PASSED" + job.reason = _extract_reason_from_log(text) + return + + job.outcome = "UNKNOWN" + job.reason = _extract_reason_from_log(text) + + +def write_summary(log_dir: Path, jobs: list[JobSpec]) -> Path: + payload = { + "generated_at": dt.datetime.now().isoformat(), + "repo_root": str(REPO_ROOT), + "jobs": [ + { + "index": job.index, + "nodeid": job.nodeid, + "job_id": job.job_id, + "slurm_state": job.slurm_state, + "exit_code": job.exit_code, + "outcome": job.outcome, + "stdout_log": str(job.stdout_path), + "stderr_log": str(job.stderr_path), + "rerun_command": job.rerun_command, + "reason": job.reason, + } + for job in jobs + ], + } + summary_path = log_dir / "summary.json" + summary_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + return summary_path + + +def print_summary(jobs: list[JobSpec], summary_path: Path) -> int: + counts: dict[str, int] = {} + for job in jobs: + counts[job.outcome or "UNKNOWN"] = counts.get(job.outcome or "UNKNOWN", 0) + 1 + + print("\nSummary") + for outcome in sorted(counts): + print(f" {outcome}: {counts[outcome]}") + print(f" summary_json: {summary_path}") + + problem_jobs = [job for job in jobs if job.outcome not in {"PASSED", "SKIPPED"}] + if problem_jobs: + print("\nProblems") + for job in problem_jobs: + print(f" {job.nodeid}") + print(f" slurm_job: {job.job_id}") + print(f" state: {job.slurm_state or 'unknown'}") + print(f" stdout: {job.stdout_path}") + print(f" stderr: {job.stderr_path}") + print(f" rerun: {job.rerun_command}") + if job.reason: + for line in job.reason.splitlines()[:20]: + print(f" {line}") + return 1 if problem_jobs else 0 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Submit AlphaFold2 functional tests to Slurm in parallel, wait for completion, " + "and summarize the logs." + ) + ) + parser.add_argument( + "nodeid", + nargs="*", + help=( + "Optional exact pytest node IDs to submit. If omitted, tests are collected " + f"from {DEFAULT_TEST_FILE.relative_to(REPO_ROOT)}." + ), + ) + parser.add_argument( + "--test-file", + default=str(DEFAULT_TEST_FILE), + help="Pytest file to collect from. 
Defaults to test/check_alphafold2_predictions.py",
+    )
+    parser.add_argument(
+        "-k",
+        dest="k_expr",
+        default=None,
+        help="Optional pytest -k expression applied during collection.",
+    )
+    parser.add_argument(
+        "--max-tests",
+        type=int,
+        default=None,
+        help="Submit at most this many collected tests.",
+    )
+    parser.add_argument(
+        "--list",
+        action="store_true",
+        help="List collected node IDs and exit without submitting jobs.",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Collect tests and write job scripts, but do not call sbatch.",
+    )
+    parser.add_argument(
+        "--use-temp-dir",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help=(
+            "Run target tests with isolated temporary output directories. "
+            "Use --no-use-temp-dir to keep the shared repo output tree."
+        ),
+    )
+    parser.add_argument("--partition", default="gpu-el8", help="Slurm partition/queue.")
+    parser.add_argument("--qos", default="normal", help="Slurm QoS.")
+    parser.add_argument("--constraint", default="gaming", help="Optional Slurm constraint.")
+    parser.add_argument("--account", default=None, help="Optional Slurm account.")
+    parser.add_argument("--gres", default="gpu:1", help="Slurm gres request, for example gpu:1.")
+    parser.add_argument("--time", default="12:00:00", help="Per-job walltime.")
+    parser.add_argument("--cpus-per-task", type=int, default=8, help="CPUs per Slurm task.")
+    parser.add_argument("--mem", default="16G", help="Per-job memory request.")
+    parser.add_argument(
+        "--extra-sbatch-arg",
+        action="append",
+        default=[],
+        help="Additional raw sbatch argument. Can be passed multiple times.",
+    )
+    parser.add_argument(
+        "--job-name-prefix",
+        default="af2test",
+        help="Prefix for Slurm job names.",
+    )
+    parser.add_argument(
+        "--poll-interval",
+        type=int,
+        default=30,
+        help="Seconds between Slurm polling cycles.",
+    )
+    parser.add_argument(
+        "--wait-timeout-hours",
+        type=float,
+        default=24.0,
+        help="Maximum hours to wait for all submitted jobs. Use 0 to disable.",
+    )
+    parser.add_argument(
+        "--log-dir",
+        default=None,
+        help="Directory to write job scripts and logs into. Defaults to test_logs/alphafold2_<timestamp>.",
+    )
+    parser.add_argument(
+        "--python",
+        default=sys.executable,
+        help="Python executable used both for collection and inside Slurm jobs.",
+    )
+    return parser.parse_args()
+
+
+def main() -> int:
+    args = parse_args()
+    if not _has_cmd("sbatch") and not args.list and not args.dry_run:
+        raise SystemExit("sbatch is not available in PATH.")
+    if not _has_cmd("squeue") and not args.list and not args.dry_run:
+        raise SystemExit("squeue is not available in PATH.")
+
+    test_file = Path(args.test_file).resolve()
+    if not test_file.exists():
+        raise SystemExit(f"Test file does not exist: {test_file}")
+
+    if args.log_dir:
+        log_dir = Path(args.log_dir).resolve()
+    else:
+        log_dir = (DEFAULT_LOG_ROOT / f"alphafold2_{_timestamp()}").resolve()
+    log_dir.mkdir(parents=True, exist_ok=True)
+
+    if args.nodeid:
+        nodeids = list(args.nodeid)
+    else:
+        nodeids = collect_nodeids(
+            python_executable=args.python,
+            test_file=test_file,
+            k_expr=args.k_expr,
+        )
+
+    if args.max_tests is not None:
+        nodeids = nodeids[: args.max_tests]
+
+    if not nodeids:
+        print("No tests matched the requested selection.")
+        return 0
+
+    if not args.use_temp_dir and len(nodeids) > 1:
+        raise SystemExit(
+            "--no-use-temp-dir is not safe for parallel AF2 wrapper runs because "
+            "the tests share and clean common output roots. 
Re-run with the default " + "--use-temp-dir, or submit a single nodeid at a time." + ) + + if args.list: + for nodeid in nodeids: + print(nodeid) + return 0 + + print(f"Collected {len(nodeids)} test node(s).") + print(f"Log directory: {log_dir}") + + jobs: list[JobSpec] = [] + for index, nodeid in enumerate(nodeids, start=1): + slug = _slugify(nodeid) + stdout_path = log_dir / f"{index:03d}_{slug}.out" + stderr_path = log_dir / f"{index:03d}_{slug}.err" + script_path = log_dir / f"{index:03d}_{slug}.sbatch.sh" + rerun_command = ( + f"{_quote(args.python)} -m pytest -vv -s {_quote(nodeid)}" + + (" --use-temp-dir" if args.use_temp_dir else "") + ) + job = JobSpec( + index=index, + nodeid=nodeid, + slug=slug, + stdout_path=stdout_path, + stderr_path=stderr_path, + script_path=script_path, + rerun_command=rerun_command, + ) + write_job_script( + job=job, + python_executable=args.python, + use_temp_dir=args.use_temp_dir, + ) + jobs.append(job) + + if args.dry_run: + print("Dry run only. Prepared job scripts:") + for job in jobs: + print(f" {job.nodeid}") + print(f" script: {job.script_path}") + print(f" stdout: {job.stdout_path}") + print(f" stderr: {job.stderr_path}") + return 0 + + for job in jobs: + job.job_id = submit_job(job, args) + print(f"[submit] {job.job_id} {job.nodeid}") + + timeout_seconds: int | None + if args.wait_timeout_hours <= 0: + timeout_seconds = None + else: + timeout_seconds = int(args.wait_timeout_hours * 3600) + + wait_for_jobs( + jobs, + poll_interval=args.poll_interval, + timeout_seconds=timeout_seconds, + ) + + for job in jobs: + classify_job(job) + + summary_path = write_summary(log_dir, jobs) + return print_summary(jobs, summary_path) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/test/test_alphafold3_predictions.py b/test/test_alphafold3_predictions.py new file mode 100644 index 00000000..b524fa2e --- /dev/null +++ b/test/test_alphafold3_predictions.py @@ -0,0 +1,703 @@ +#!/usr/bin/env python3 +"""Submit AlphaFold3 functional tests to Slurm and summarize results. + +This is a standalone wrapper for `test/check_alphafold3_predictions.py`. +It is intentionally not a pytest test module, despite the filename. 
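+Pytest skips collecting it because the module sets __test__ = False.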
+ +Typical usage from a login node: + + python test/test_alphafold3_predictions.py + +Run only selected tests: + + python test/test_alphafold3_predictions.py -k chopped + +Enable the runtime benchmark test as well: + + python test/test_alphafold3_predictions.py --include-perf +""" + +from __future__ import annotations + +__test__ = False + +import argparse +import dataclasses +import datetime as dt +import importlib.util +import inspect +import json +import os +import re +import shlex +import subprocess +import sys +import time +import unittest +from pathlib import Path +from typing import Iterable + +from _pytest.mark.expression import Expression + + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_TEST_FILE = REPO_ROOT / "test" / "check_alphafold3_predictions.py" +DEFAULT_LOG_ROOT = REPO_ROOT / "test_logs" + +PASS_STATES = {"COMPLETED"} +FAIL_STATES = { + "BOOT_FAIL", + "CANCELLED", + "DEADLINE", + "FAILED", + "NODE_FAIL", + "OUT_OF_MEMORY", + "PREEMPTED", + "REVOKED", + "TIMEOUT", +} + + +@dataclasses.dataclass(slots=True) +class JobSpec: + index: int + nodeid: str + slug: str + stdout_path: Path + stderr_path: Path + script_path: Path + rerun_command: str + job_id: str | None = None + slurm_state: str | None = None + exit_code: str | None = None + outcome: str | None = None + reason: str | None = None + + +def _has_cmd(cmd: str) -> bool: + try: + subprocess.run( + [cmd, "--help"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) + return True + except FileNotFoundError: + return False + + +def _run( + cmd: list[str], + *, + cwd: Path = REPO_ROOT, + check: bool = True, +) -> subprocess.CompletedProcess[str]: + return subprocess.run( + cmd, + cwd=cwd, + text=True, + capture_output=True, + check=check, + ) + + +def _normalize_state(state: str | None) -> str | None: + if not state: + return None + return state.split()[0].rstrip("+") + + +def _slugify(value: str, *, max_len: int = 120) -> str: + slug = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") + if not slug: + slug = "test" + if len(slug) > max_len: + slug = slug[:max_len].rstrip("._") + return slug + + +def _quote(value: str) -> str: + return shlex.quote(value) + + +def _timestamp() -> str: + return dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + + +def _relative_nodeid_prefix(test_file: Path) -> str: + return str(test_file.resolve().relative_to(REPO_ROOT)) + + +def _matches_k_expression(nodeid: str, k_expr: str | None) -> bool: + if not k_expr: + return True + expression = Expression.compile(k_expr) + lowered = nodeid.lower() + return expression.evaluate(lambda token: token.lower() in lowered) + + +def _collect_nodeids_from_module_import(test_file: Path, k_expr: str | None) -> list[str]: + module_name = f"_codex_collect_{test_file.stem}" + spec = importlib.util.spec_from_file_location(module_name, test_file) + if spec is None or spec.loader is None: + raise RuntimeError(f"Failed to create import spec for {test_file}") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + finally: + sys.modules.pop(module_name, None) + + prefix = _relative_nodeid_prefix(test_file) + nodeids: list[str] = [] + for _, cls in inspect.getmembers(module, inspect.isclass): + if cls.__module__ != module.__name__: + continue + if not issubclass(cls, unittest.TestCase): + continue + if not cls.__name__.startswith("Test"): + continue + + for method_name in sorted(name for name in dir(cls) if name.startswith("test")): + nodeid = 
f"{prefix}::{cls.__name__}::{method_name}" + if _matches_k_expression(nodeid, k_expr): + nodeids.append(nodeid) + return nodeids + + +def collect_nodeids( + *, + python_executable: str, + test_file: Path, + k_expr: str | None, +) -> list[str]: + cmd = [ + python_executable, + "-m", + "pytest", + "--collect-only", + "-q", + str(test_file), + ] + if k_expr: + cmd.extend(["-k", k_expr]) + result = _run(cmd, check=False) + if result.returncode != 0: + raise RuntimeError( + "pytest collection failed.\n" + f"Command: {' '.join(_quote(part) for part in cmd)}\n\n" + f"STDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}" + ) + + nodeids: list[str] = [] + for raw_line in result.stdout.splitlines(): + line = raw_line.strip() + if not line: + continue + if ".py::" not in line: + continue + if line.startswith("ERROR ") or line.startswith("SKIPPED "): + continue + nodeids.append(line) + if nodeids: + return nodeids + + return _collect_nodeids_from_module_import(test_file, k_expr) + + +def write_job_script( + *, + job: JobSpec, + python_executable: str, + use_temp_dir: bool, + include_perf: bool, +) -> None: + pytest_cmd = [ + python_executable, + "-m", + "pytest", + "-vv", + "-s", + job.nodeid, + ] + if use_temp_dir: + pytest_cmd.append("--use-temp-dir") + + env_lines = [ + "export PYTHONUNBUFFERED=1", + ] + if include_perf: + env_lines.append("export AF3_RUN_PERF_TESTS=1") + + script = "\n".join( + [ + "#!/bin/bash", + "set -euo pipefail", + f"cd {_quote(str(REPO_ROOT))}", + *env_lines, + "echo \"[$(date)] Running test node:\"", + f"echo {_quote(job.nodeid)}", + "echo \"[$(date)] Host: $(hostname)\"", + "echo \"[$(date)] Python: $(which python || true)\"", + " ".join(_quote(part) for part in pytest_cmd), + "", + ] + ) + job.script_path.write_text(script, encoding="utf-8") + job.script_path.chmod(0o755) + + +def submit_job(job: JobSpec, args: argparse.Namespace) -> str: + cmd = [ + "sbatch", + "--parsable", + "--export=ALL", + f"--job-name={args.job_name_prefix}_{job.index:03d}", + f"--chdir={REPO_ROOT}", + f"--output={job.stdout_path}", + f"--error={job.stderr_path}", + f"--time={args.time}", + "--ntasks=1", + f"--cpus-per-task={args.cpus_per_task}", + f"--mem={args.mem}", + ] + if args.partition: + cmd.append(f"--partition={args.partition}") + if args.qos: + cmd.append(f"--qos={args.qos}") + if args.constraint: + cmd.append(f"--constraint={args.constraint}") + if args.account: + cmd.append(f"--account={args.account}") + if args.gres: + cmd.append(f"--gres={args.gres}") + for extra_arg in args.extra_sbatch_arg: + cmd.append(extra_arg) + cmd.append(str(job.script_path)) + + result = _run(cmd) + raw_job_id = result.stdout.strip().splitlines()[-1] + return raw_job_id.split(";", 1)[0] + + +def active_job_ids(job_ids: Iterable[str]) -> set[str]: + job_ids = [job_id for job_id in job_ids if job_id] + if not job_ids: + return set() + + result = _run( + [ + "squeue", + "-h", + "-j", + ",".join(job_ids), + "-o", + "%A", + ], + check=False, + ) + if result.returncode != 0: + return set() + return {line.strip() for line in result.stdout.splitlines() if line.strip()} + + +def query_sacct(job_id: str) -> tuple[str | None, str | None]: + if not _has_cmd("sacct"): + return None, None + + result = _run( + [ + "sacct", + "-X", + "-n", + "-P", + "-j", + job_id, + "-o", + "JobIDRaw,State,ExitCode", + ], + check=False, + ) + if result.returncode != 0: + return None, None + + for line in result.stdout.splitlines(): + parts = line.strip().split("|") + if len(parts) < 3: + continue + job_id_raw, state, exit_code = parts[:3] 
+ if job_id_raw == job_id: + return _normalize_state(state), exit_code + return None, None + + +def wait_for_jobs(jobs: list[JobSpec], *, poll_interval: int, timeout_seconds: int | None) -> None: + outstanding = {job.job_id for job in jobs if job.job_id} + start = time.monotonic() + previous_remaining = len(outstanding) + + while outstanding: + if timeout_seconds is not None and (time.monotonic() - start) > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for {len(outstanding)} Slurm job(s): " + + ", ".join(sorted(outstanding)) + ) + + active = active_job_ids(outstanding) + finished = outstanding - active + if finished: + outstanding = active + + remaining = len(outstanding) + if remaining != previous_remaining or finished: + done = len(jobs) - remaining + print(f"[wait] {done}/{len(jobs)} jobs finished, {remaining} remaining", flush=True) + previous_remaining = remaining + + if outstanding: + time.sleep(poll_interval) + + +def _combined_log_text(job: JobSpec) -> str: + parts: list[str] = [] + if job.stdout_path.exists(): + parts.append(job.stdout_path.read_text(encoding="utf-8", errors="replace")) + if job.stderr_path.exists(): + stderr_text = job.stderr_path.read_text(encoding="utf-8", errors="replace") + if stderr_text: + parts.append(stderr_text) + return "\n".join(parts) + + +def _extract_reason_from_log(text: str) -> str: + patterns = [ + r"short test summary info[\s\S]*$", + r"=+ FAILURES =+[\s\S]*$", + r"Traceback[\s\S]*$", + r"(?m)^E\s+.*$", + r"(?m)^FAILED .*$", + r"(?m)^ERROR .*$", + r"(?m)^.*Killed.*$", + r"(?m)^.*PASSED.*$", + r"(?m)^.*SKIPPED.*$", + ] + for pattern in patterns: + match = re.search(pattern, text) + if match: + snippet = match.group(0).strip() + if len(snippet) > 1200: + snippet = snippet[-1200:] + return snippet + + non_empty_lines = [line.rstrip() for line in text.splitlines() if line.strip()] + if not non_empty_lines: + return "No log output captured." 
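+ # No known pytest/Slurm pattern matched; fall back to the last 20 non-empty lines.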
+ return "\n".join(non_empty_lines[-20:]) + + +def classify_job(job: JobSpec) -> None: + job.slurm_state, job.exit_code = query_sacct(job.job_id or "") + text = _combined_log_text(job) + state = job.slurm_state + + if state in FAIL_STATES: + job.outcome = "FAILED" + job.reason = f"Slurm state: {state}\n{_extract_reason_from_log(text)}" + return + + if re.search(r"(?im)\bkilled\b", text): + job.outcome = "FAILED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"(?m)^FAILED ", text) or re.search(r"(?m)^ERROR ", text): + job.outcome = "FAILED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"=+ FAILURES =+", text) or "Traceback" in text: + job.outcome = "FAILED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"(?i)\b\d+\s+skipped\b", text) or " SKIPPED" in text: + job.outcome = "SKIPPED" + job.reason = _extract_reason_from_log(text) + return + + if re.search(r"(?i)\b\d+\s+passed\b", text) or " PASSED" in text: + job.outcome = "PASSED" + job.reason = _extract_reason_from_log(text) + return + + if state in PASS_STATES: + job.outcome = "PASSED" + job.reason = _extract_reason_from_log(text) + return + + job.outcome = "UNKNOWN" + job.reason = _extract_reason_from_log(text) + + +def write_summary(log_dir: Path, jobs: list[JobSpec]) -> Path: + payload = { + "generated_at": dt.datetime.now().isoformat(), + "repo_root": str(REPO_ROOT), + "jobs": [ + { + "index": job.index, + "nodeid": job.nodeid, + "job_id": job.job_id, + "slurm_state": job.slurm_state, + "exit_code": job.exit_code, + "outcome": job.outcome, + "stdout_log": str(job.stdout_path), + "stderr_log": str(job.stderr_path), + "rerun_command": job.rerun_command, + "reason": job.reason, + } + for job in jobs + ], + } + summary_path = log_dir / "summary.json" + summary_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + return summary_path + + +def print_summary(jobs: list[JobSpec], summary_path: Path) -> int: + counts: dict[str, int] = {} + for job in jobs: + counts[job.outcome or "UNKNOWN"] = counts.get(job.outcome or "UNKNOWN", 0) + 1 + + print("\nSummary") + for outcome in sorted(counts): + print(f" {outcome}: {counts[outcome]}") + print(f" summary_json: {summary_path}") + + problem_jobs = [job for job in jobs if job.outcome not in {"PASSED", "SKIPPED"}] + if problem_jobs: + print("\nProblems") + for job in problem_jobs: + print(f" {job.nodeid}") + print(f" slurm_job: {job.job_id}") + print(f" state: {job.slurm_state or 'unknown'}") + print(f" stdout: {job.stdout_path}") + print(f" stderr: {job.stderr_path}") + print(f" rerun: {job.rerun_command}") + if job.reason: + for line in job.reason.splitlines()[:20]: + print(f" {line}") + return 1 if problem_jobs else 0 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Submit AlphaFold3 functional tests to Slurm in parallel, wait for completion, " + "and summarize the logs." + ) + ) + parser.add_argument( + "nodeid", + nargs="*", + help=( + "Optional exact pytest node IDs to submit. If omitted, tests are collected " + f"from {DEFAULT_TEST_FILE.relative_to(REPO_ROOT)}." + ), + ) + parser.add_argument( + "--test-file", + default=str(DEFAULT_TEST_FILE), + help="Pytest file to collect from. 
Defaults to test/check_alphafold3_predictions.py", + ) + parser.add_argument( + "-k", + dest="k_expr", + default=None, + help="Optional pytest -k expression applied during collection.", + ) + parser.add_argument( + "--max-tests", + type=int, + default=None, + help="Submit at most this many collected tests.", + ) + parser.add_argument( + "--list", + action="store_true", + help="List collected node IDs and exit without submitting jobs.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Collect tests and write job scripts, but do not call sbatch.", + ) + parser.add_argument( + "--include-perf", + action="store_true", + help="Set AF3_RUN_PERF_TESTS=1 inside jobs so the runtime benchmark is included.", + ) + parser.add_argument( + "--use-temp-dir", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Run target tests with isolated temporary output directories. " + "Use --no-use-temp-dir to keep the shared repo output tree." + ), + ) + parser.add_argument("--partition", default="gpu-el8", help="Slurm partition/queue.") + parser.add_argument("--qos", default="normal", help="Slurm QoS.") + parser.add_argument("--constraint", default="gaming", help="Optional Slurm constraint.") + parser.add_argument("--account", default=None, help="Optional Slurm account.") + parser.add_argument("--gres", default="gpu:1", help="Slurm gres request, for example gpu:1.") + parser.add_argument("--time", default="12:00:00", help="Per-job walltime.") + parser.add_argument("--cpus-per-task", type=int, default=8, help="CPUs per Slurm task.") + parser.add_argument("--mem", default="16G", help="Per-job memory request.") + parser.add_argument( + "--extra-sbatch-arg", + action="append", + default=[], + help="Additional raw sbatch argument. Can be passed multiple times.", + ) + parser.add_argument( + "--job-name-prefix", + default="af3test", + help="Prefix for Slurm job names.", + ) + parser.add_argument( + "--poll-interval", + type=int, + default=30, + help="Seconds between Slurm polling cycles.", + ) + parser.add_argument( + "--wait-timeout-hours", + type=float, + default=24.0, + help="Maximum hours to wait for all submitted jobs. Use 0 to disable.", + ) + parser.add_argument( + "--log-dir", + default=None, + help="Directory to write job scripts and logs into. 
Defaults to test_logs/alphafold3_<timestamp>.",
+ )
+ parser.add_argument(
+ "--python",
+ default=sys.executable,
+ help="Python executable used both for collection and inside Slurm jobs.",
+ )
+ return parser.parse_args()
+
+
+def main() -> int:
+ args = parse_args()
+ if not _has_cmd("sbatch") and not args.list and not args.dry_run:
+ raise SystemExit("sbatch is not available in PATH.")
+ if not _has_cmd("squeue") and not args.list and not args.dry_run:
+ raise SystemExit("squeue is not available in PATH.")
+
+ test_file = Path(args.test_file).resolve()
+ if not test_file.exists():
+ raise SystemExit(f"Test file does not exist: {test_file}")
+
+ if args.log_dir:
+ log_dir = Path(args.log_dir).resolve()
+ else:
+ log_dir = (DEFAULT_LOG_ROOT / f"alphafold3_{_timestamp()}").resolve()
+ log_dir.mkdir(parents=True, exist_ok=True)
+
+ if args.nodeid:
+ nodeids = list(args.nodeid)
+ else:
+ nodeids = collect_nodeids(
+ python_executable=args.python,
+ test_file=test_file,
+ k_expr=args.k_expr,
+ )
+
+ if args.max_tests is not None:
+ nodeids = nodeids[: args.max_tests]
+
+ if not nodeids:
+ print("No tests matched the requested selection.")
+ return 0
+
+ if not args.use_temp_dir and len(nodeids) > 1:
+ raise SystemExit(
+ "--no-use-temp-dir is not safe for parallel AF3 wrapper runs because "
+ "the tests share and clean common output roots. Re-run with the default "
+ "--use-temp-dir, or submit a single nodeid at a time."
+ )
+
+ if args.list:
+ for nodeid in nodeids:
+ print(nodeid)
+ return 0
+
+ print(f"Collected {len(nodeids)} test node(s).")
+ print(f"Log directory: {log_dir}")
+
+ jobs: list[JobSpec] = []
+ for index, nodeid in enumerate(nodeids, start=1):
+ slug = _slugify(nodeid)
+ stdout_path = log_dir / f"{index:03d}_{slug}.out"
+ stderr_path = log_dir / f"{index:03d}_{slug}.err"
+ script_path = log_dir / f"{index:03d}_{slug}.sbatch.sh"
+ rerun_command = (
+ f"{_quote(args.python)} -m pytest -vv -s {_quote(nodeid)}"
+ + (" --use-temp-dir" if args.use_temp_dir else "")
+ )
+ job = JobSpec(
+ index=index,
+ nodeid=nodeid,
+ slug=slug,
+ stdout_path=stdout_path,
+ stderr_path=stderr_path,
+ script_path=script_path,
+ rerun_command=rerun_command,
+ )
+ write_job_script(
+ job=job,
+ python_executable=args.python,
+ use_temp_dir=args.use_temp_dir,
+ include_perf=args.include_perf,
+ )
+ jobs.append(job)
+
+ if args.dry_run:
+ print("Dry run only. Prepared job scripts:")
+ for job in jobs:
+ print(f" {job.nodeid}")
+ print(f" script: {job.script_path}")
+ print(f" stdout: {job.stdout_path}")
+ print(f" stderr: {job.stderr_path}")
+ return 0
+
+ for job in jobs:
+ job.job_id = submit_job(job, args)
+ print(f"[submit] {job.job_id} {job.nodeid}")
+
+ timeout_seconds: int | None
+ if args.wait_timeout_hours <= 0:
+ timeout_seconds = None
+ else:
+ timeout_seconds = int(args.wait_timeout_hours * 3600)
+
+ wait_for_jobs(
+ jobs,
+ poll_interval=args.poll_interval,
+ timeout_seconds=timeout_seconds,
+ )
+
+ for job in jobs:
+ classify_job(job)
+
+ summary_path = write_summary(log_dir, jobs)
+ return print_summary(jobs, summary_path)
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())