# Source code for neurosnap.structure.compare

"""Pairwise comparison and alignment functions for Neurosnap structures."""

from typing import Optional, Sequence

import numpy as np

from ._common import backbone_atom_order, coord_matrix
from .structure import Structure


def _matched_backbone_coords(
  reference: Structure,
  mobile: Structure,
  chains1: Optional[Sequence[str]] = None,
  chains2: Optional[Sequence[str]] = None,
) -> tuple[np.ndarray, np.ndarray]:
  """Collect paired backbone coordinates from a reference and a mobile structure.

  Atoms are paired through a key built from the chain (or chain-pair index),
  the residue id, the insertion code and the backbone atom name, so the order
  in which atoms are stored in either structure does not matter.

  Parameters:
    reference: Structure supplying the first coordinate array.
    mobile: Structure supplying the second coordinate array.
    chains1: Optional chain IDs to use from ``reference``. Empty or ``None``
      selects every chain reported by ``reference.chain_ids()``.
    chains2: Optional chain IDs to use from ``mobile``. Empty or ``None``
      selects every chain reported by ``mobile.chain_ids()``.

  Returns:
    A ``(reference_coords, mobile_coords)`` pair of ``float32`` arrays whose
    rows correspond to the same sorted backbone-atom keys.

  Raises:
    ValueError: If a requested chain is missing, if both chain lists are given
      with different lengths, if either structure yields no backbone atoms,
      or if the two backbone atom sets are not identical.
  """
  ref_chain_ids = list(chains1 or [])
  mob_chain_ids = list(chains2 or [])
  # Positional pairing only applies when the caller named chains on both sides.
  paired_by_position = bool(ref_chain_ids) and bool(mob_chain_ids)

  ref_available = reference.chain_ids()
  mob_available = mobile.chain_ids()

  # Validate explicit selections (the loops are no-ops for empty selections),
  # then fall back to every available chain.
  for chain_id in ref_chain_ids:
    if chain_id not in ref_available:
      raise ValueError(f"Chain {chain_id} was not found in the reference structure.")
  if not ref_chain_ids:
    ref_chain_ids = ref_available

  for chain_id in mob_chain_ids:
    if chain_id not in mob_available:
      raise ValueError(f"Chain {chain_id} was not found in the mobile structure.")
  if not mob_chain_ids:
    mob_chain_ids = mob_available

  if paired_by_position:
    if len(ref_chain_ids) != len(mob_chain_ids):
      raise ValueError("chains1 and chains2 must contain the same number of chains for pairwise mapping.")
    # Key each chain by its position so the i-th reference chain pairs with
    # the i-th mobile chain regardless of their identifiers.
    ref_specs = list(enumerate(ref_chain_ids))
    mob_specs = list(enumerate(mob_chain_ids))
  else:
    # Without an explicit pairing, chains are keyed by their own identifier,
    # so only chains sharing an ID in both structures can match.
    ref_specs = [(chain_id, chain_id) for chain_id in ref_chain_ids]
    mob_specs = [(chain_id, chain_id) for chain_id in mob_chain_ids]

  def collect_backbone(structure_model, chain_specs):
    """Map (chain key, res_id, ins_code, atom name) to that atom's coordinate."""
    chains_by_id = {chain.chain_id: chain for chain in structure_model.chains()}
    coords_by_key = {}
    for map_key, chain_id in chain_specs:
      chain = chains_by_id.get(chain_id)
      if chain is None:
        continue
      for residue in chain.residues():
        wanted_names = backbone_atom_order(residue, include_nucleotides=True)
        if not wanted_names:
          continue
        # Normalise atom names so lookups ignore padding and letter case.
        coords_by_name = {atom.atom_name.strip().upper(): atom.coord for atom in residue.atoms()}
        coords_by_key.update(
          ((map_key, residue.res_id, residue.ins_code, name), coords_by_name[name])
          for name in wanted_names
          if name in coords_by_name
        )
    return coords_by_key

  reference_atom_map = collect_backbone(reference, ref_specs)
  mobile_atom_map = collect_backbone(mobile, mob_specs)
  if not reference_atom_map:
    raise ValueError("Reference structure does not contain any backbone atoms to align.")
  if not mobile_atom_map:
    raise ValueError("Mobile structure does not contain any backbone atoms to align.")

  shared_keys = sorted(set(reference_atom_map).intersection(mobile_atom_map))
  if not shared_keys:
    raise ValueError("Structures do not share common backbone atoms to align.")
  # Every backbone atom must have a partner on the other side.
  if not (len(shared_keys) == len(reference_atom_map) == len(mobile_atom_map)):
    raise ValueError("Backbone atom mismatch between structures.")

  def stack(atom_map):
    return np.asarray([atom_map[key] for key in shared_keys], dtype=np.float32)

  return stack(reference_atom_map), stack(mobile_atom_map)


def align(
  reference: Structure,
  mobile: Structure,
  chains1: Optional[Sequence[str]] = None,
  chains2: Optional[Sequence[str]] = None,
):
  """Align a mobile structure onto a reference structure using polymer backbone atoms.

  When both ``chains1`` and ``chains2`` are provided, they are interpreted as
  explicit pairwise chain mappings in matching order.

  Parameters:
    reference: Reference single-model :class:`Structure`.
    mobile: Mobile single-model :class:`Structure` to transform in-place.
    chains1: Optional reference chain IDs to include in the alignment.
    chains2: Optional mobile chain IDs to include in the alignment.

  Returns:
    ``None``. The mobile structure is transformed in-place.

  Raises:
    TypeError: If ``reference`` or ``mobile`` is not a :class:`Structure`.
    ValueError: Propagated from the backbone matching step when the requested
      chains or backbone atoms of the two structures cannot be matched.
  """
  if not isinstance(reference, Structure):
    raise TypeError(f"align() expects reference to be a Structure, found {type(reference).__name__}.")
  if not isinstance(mobile, Structure):
    raise TypeError(f"align() expects mobile to be a Structure, found {type(mobile).__name__}.")

  reference_coords, mobile_coords = _matched_backbone_coords(reference, mobile, chains1=chains1, chains2=chains2)

  # The matched coordinates arrive as float32. Fit the transform in float64 so
  # the centroid sums, the covariance product and the SVD do not accumulate
  # single-precision rounding error.
  reference_coords = np.asarray(reference_coords, dtype=np.float64)
  mobile_coords = np.asarray(mobile_coords, dtype=np.float64)

  # Standard Kabsch alignment on the matched backbone coordinates.
  reference_center = reference_coords.mean(axis=0)
  mobile_center = mobile_coords.mean(axis=0)
  centered_reference = reference_coords - reference_center
  centered_mobile = mobile_coords - mobile_center

  covariance = centered_mobile.T @ centered_reference
  u_matrix, _, vt_matrix = np.linalg.svd(covariance)
  # Coordinates are row vectors, so the rotation is applied on the right.
  rotation = u_matrix @ vt_matrix
  if np.linalg.det(rotation) < 0:
    # A negative determinant means an improper rotation (reflection); flip the
    # last singular direction to get a proper rotation instead.
    u_matrix[:, -1] *= -1
    rotation = u_matrix @ vt_matrix

  translation = reference_center - (mobile_center @ rotation)

  # Apply the transform to every atom of the mobile structure, not only the
  # matched backbone atoms. The float32 casts are kept from the original
  # expression (presumably to match the stored coordinate dtype — TODO confirm).
  all_mobile_coords = coord_matrix(mobile)
  aligned_coords = all_mobile_coords @ rotation.astype(np.float32) + translation.astype(np.float32)
  mobile.atoms["x"] = aligned_coords[:, 0]
  mobile.atoms["y"] = aligned_coords[:, 1]
  mobile.atoms["z"] = aligned_coords[:, 2]


def calculate_rmsd(
  reference: Structure,
  mobile: Structure,
  chains1: Optional[Sequence[str]] = None,
  chains2: Optional[Sequence[str]] = None,
  align_structures: bool = True,
) -> float:
  """Calculate backbone RMSD between two structures.

  Parameters:
    reference: Reference single-model :class:`Structure`.
    mobile: Mobile single-model :class:`Structure`.
    chains1: Optional reference chain IDs to include.
    chains2: Optional mobile chain IDs to include.
    align_structures: If ``True``, align the mobile structure before computing
      RMSD. The alignment is done by :func:`align`, which moves ``mobile``
      in-place; the structure stays transformed after this function returns.

  Returns:
    Backbone RMSD in Å using the same residue/atom correspondence as :func:`align`.

  Raises:
    TypeError: If ``reference`` or ``mobile`` is not a :class:`Structure`.
    ValueError: Propagated from the backbone matching step when the requested
      chains or backbone atoms of the two structures cannot be matched.
  """
  if not isinstance(reference, Structure):
    raise TypeError(f"calculate_rmsd() expects reference to be a Structure, found {type(reference).__name__}.")
  if not isinstance(mobile, Structure):
    raise TypeError(f"calculate_rmsd() expects mobile to be a Structure, found {type(mobile).__name__}.")

  if align_structures:
    align(reference, mobile, chains1=chains1, chains2=chains2)

  # Re-match after the optional alignment so the mobile coordinates reflect
  # the transformed structure.
  reference_coords, mobile_coords = _matched_backbone_coords(reference, mobile, chains1=chains1, chains2=chains2)
  # Defensive guard: avoids a division by zero if no atoms were matched.
  if reference_coords.size == 0:
    return 0.0

  # Accumulate the squared deviations in float64 rather than the float32 the
  # matching step returns, to limit rounding error on large structures.
  diff = np.asarray(reference_coords, dtype=np.float64) - np.asarray(mobile_coords, dtype=np.float64)
  # Per-atom RMSD: total squared deviation divided by the number of atoms.
  return float(np.sqrt(np.sum(diff**2) / diff.shape[0]))