# Source code for neurosnap.algos.lddt

"""Code for LDDT (Local Distance Difference Test) calculation."""

from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np

from neurosnap.constants.structure import (
  BACKBONE_ATOMS_AA,
  BACKBONE_ATOMS_DNA,
  BACKBONE_ATOMS_RNA,
)
from neurosnap.structure import Atom, Residue, Structure
from neurosnap.structure._common import classify_polymer_residue

# Protein backbone atoms tried (in this order) when a residue has no usable CB;
# restricted to names that BACKBONE_ATOMS_AA actually lists.
_PROTEIN_BACKBONE_FALLBACK = tuple(name for name in ("CA", "N", "C") if name in BACKBONE_ATOMS_AA)

# Union of the DNA and RNA backbone atom names.
_NUCLEOTIDE_BACKBONE_ATOMS = set(BACKBONE_ATOMS_DNA) | set(BACKBONE_ATOMS_RNA)
# Hand-picked preference order for the nucleotide representative atom.
_NUCLEOTIDE_PREF_BASES = ("C4'", "C3'", "C1'", "C2'", "P", "O4'", "O3'", "O5'")
# Each entry is (name, name-with-'*'-instead-of-"'"), de-duplicated, so both
# prime spellings of the same atom are tried together.
_NUCLEOTIDE_ATOM_PRIORITY: List[Tuple[str, ...]] = []
_seen_nuc_bases = set()
# Preferred names come first; any remaining backbone atoms follow in sorted
# order so the final priority list is deterministic.
for base in (*_NUCLEOTIDE_PREF_BASES, *sorted(_NUCLEOTIDE_BACKBONE_ATOMS)):
  if base not in _NUCLEOTIDE_BACKBONE_ATOMS or base in _seen_nuc_bases:
    continue
  names = tuple(dict.fromkeys((base, base.replace("'", "*"))))
  _NUCLEOTIDE_ATOM_PRIORITY.append(names)
  _seen_nuc_bases.add(base)

# Residue names treated as water and skipped during site extraction.
_WATER_RESIDUE_NAMES = {"HOH", "WAT"}
# Upper-cased protein backbone atom names, used to spot protein-like residues.
_PROTEIN_BACKBONE_NAMES = {name.upper() for name in BACKBONE_ATOMS_AA}
# Site identifier: (chain_id, res_id, ins_code, kind, res_name, atom_name).
SiteKey = Tuple[str, int, str, str, str, str]


# Helpers for metrics calculated using numpy scheme
def _get_flattened(dmap):
  """Return the pairwise values of a distance map as a 1D array.

  Args:
    dmap: Array exposing ``ndim``. A 1D array is treated as already flattened
      and returned unchanged (the same object). A 2D array is reduced to its
      strict upper triangle (diagonal excluded).

  Returns:
    1D array of pairwise values.

  Raises:
    ValueError: If ``dmap`` is neither 1D nor 2D.
  """
  if dmap.ndim == 1:
    return dmap
  if dmap.ndim == 2:
    return dmap[np.triu_indices_from(dmap, k=1)]
  # Raise explicitly rather than ``assert False``: assertions are stripped under
  # ``python -O``, which would let this function fall through and return None.
  raise ValueError("ERROR: the passed array has dimension not equal to 2 or 1!")


def _get_separations(dmap):
  """Return the index separation ``|i - j|`` of every strict-upper-triangle pair.

  The output is ordered like ``np.triu_indices_from(dmap, k=1)``, i.e. it lines
  up element-for-element with the flattened upper triangle of ``dmap``.
  """
  rows, cols = np.triu_indices_from(dmap, k=1)
  return np.abs(rows - cols)


def _get_sep_thresh_b_indices(dmap, thresh, comparator):
  """Return a 1D boolean mask over the upper triangle based on index separation.

  Args:
    dmap: 2D map whose strict upper triangle defines the pairs.
    thresh: Separation value to compare against.
    comparator: One of ``"gt"``, ``"lt"``, ``"ge"``, ``"le"``; the mask is
      ``separation <comparator> thresh``.

  Returns:
    1D boolean array, one entry per upper-triangle pair.

  Raises:
    ValueError: If ``comparator`` is not one of the supported names.
  """
  compare = {
    "gt": np.greater,
    "lt": np.less,
    "ge": np.greater_equal,
    "le": np.less_equal,
  }.get(comparator)
  # Explicit raise instead of ``assert``: under ``python -O`` the assert is
  # stripped and an unknown comparator would surface as an UnboundLocalError.
  if compare is None:
    raise ValueError("ERROR: Unknown comparator for thresholding!")
  return compare(_get_separations(dmap), thresh)


def _get_dist_thresh_b_indices(dmap, thresh, comparator):
  """Return a 1D boolean mask over the upper triangle based on distance.

  Args:
    dmap: Distance map; it is flattened with ``_get_flattened`` first, so an
      already-flattened 1D array is accepted as well as a 2D map.
    thresh: Distance value to compare against.
    comparator: One of ``"gt"``, ``"lt"``, ``"ge"``, ``"le"``; the mask is
      ``distance <comparator> thresh``.

  Returns:
    1D boolean array, one entry per flattened distance.

  Raises:
    ValueError: If ``comparator`` is not one of the supported names.
  """
  compare = {
    "gt": np.greater,
    "lt": np.less,
    "ge": np.greater_equal,
    "le": np.less_equal,
  }.get(comparator)
  # Explicit raise instead of ``assert``: under ``python -O`` the assert is
  # stripped and an unknown comparator would surface as an UnboundLocalError.
  if compare is None:
    raise ValueError("ERROR: Unknown comparator for thresholding!")
  return compare(_get_flattened(dmap), thresh)


def _is_hydrogen_atom(atom: Atom) -> bool:
  """Return True when ``atom`` should be treated as a hydrogen.

  The element field is authoritative whenever it looks like a real element
  symbol. Only when it is missing or a placeholder does the check fall back to
  the "atom name starts with H" heuristic. Without that guard, heavy atoms whose
  names start with "H" (for example mercury ``HG``, holmium ``HO``, hafnium
  ``HF``) would be discarded as hydrogens.

  Args:
    atom: Object exposing ``element`` and ``atom_name``.

  Returns:
    True for hydrogen (including deuterium, element ``D``), False otherwise.
  """
  element = str(atom.element).strip().upper()
  if element in ("H", "D"):
    return True
  # A 1-2 letter alphabetic symbol is taken as a known element; "X" is the
  # usual placeholder for "unknown", and str(None) -> "NONE" is too long to
  # match, so both still reach the name-based fallback below.
  if 1 <= len(element) <= 2 and element.isalpha() and element != "X":
    return False
  return atom.atom_name.strip().upper().startswith("H")


def _classify_residue_for_lddt(residue: Residue) -> Optional[Literal["protein", "dna", "rna"]]:
  """Decide which polymer type, if any, a residue counts as for lDDT sites.

  The verdict of ``classify_polymer_residue`` is used whenever it is not None.
  Otherwise a residue carrying at least two protein backbone atom names is
  treated as protein, which lets modified amino acids through. Anything else
  yields None (non-polymer).
  """
  known_type = classify_polymer_residue(residue)
  if known_type is not None:
    return known_type

  present = {member.atom_name.strip().upper() for member in residue.atoms()}
  backbone_hits = present & _PROTEIN_BACKBONE_NAMES
  return "protein" if len(backbone_hits) >= 2 else None


def _get_atom(residue: Residue, name: str) -> Optional[Atom]:
  """Look up an atom of ``residue`` by name, ignoring case and padding.

  Both prime spellings are accepted: a request containing ``'`` also matches the
  ``*`` form, and a request without ``'`` also matches the name with ``*``
  rewritten to ``'``. If several atoms normalise to the same name, the last one
  yielded by ``residue.atoms()`` wins.

  Returns:
    The matching atom, or None when neither spelling is present.
  """
  by_name = {}
  for candidate in residue.atoms():
    by_name[candidate.atom_name.strip().upper()] = candidate

  wanted = name.strip().upper()
  if "'" in wanted:
    variant = wanted.replace("'", "*")
  else:
    variant = wanted.replace("*", "'")

  # Exact spelling first, then the alternate prime marker.
  for key in (wanted, variant):
    if key in by_name:
      return by_name[key]
  return None


def _select_polymer_representative_atom(residue: Residue, polymer_type: Literal["protein", "dna", "rna"]) -> Optional[Atom]:
  """Pick the single atom that stands in for a polymer residue.

  Proteins: ``CB`` is tried first unless the residue is named GLY, then the
  names in ``_PROTEIN_BACKBONE_FALLBACK`` in order. Any other polymer type walks
  ``_NUCLEOTIDE_ATOM_PRIORITY`` (each entry's spellings in turn).

  Returns:
    The first atom found, or None if the residue has none of the candidates.
  """
  if polymer_type == "protein":
    candidates = list(_PROTEIN_BACKBONE_FALLBACK)
    if residue.res_name.strip().upper() != "GLY":
      candidates.insert(0, "CB")
  else:
    candidates = [alias for aliases in _NUCLEOTIDE_ATOM_PRIORITY for alias in aliases]

  for candidate in candidates:
    found = _get_atom(residue, candidate)
    if found is not None:
      return found
  return None


def _extract_cb_coords_from_structure(
  structure: Structure,
  *,
  chains: Optional[List[str]] = None,
) -> Dict[SiteKey, Tuple[float, float, float]]:
  """Gather one coordinate per analysis site of a structure.

  Residues named in ``_WATER_RESIDUE_NAMES`` are skipped. A residue that
  ``_classify_residue_for_lddt`` recognises as polymer contributes one site, the
  atom chosen by ``_select_polymer_representative_atom``; it is dropped if no
  such atom exists. Every other residue contributes one site per atom that
  ``_is_hydrogen_atom`` does not flag.

  Args:
    structure: Structure whose chains, residues and atoms are read.
    chains: Chain IDs to visit, in the given order. ``None`` or an empty list
      selects every chain; IDs not present in the structure are ignored.

  Returns:
    Mapping from site key ``(chain_id, res_id, ins_code, kind, res_name,
    atom_name)`` to ``(x, y, z)``. ``kind`` is ``"polymer"`` (with empty
    ``res_name``/``atom_name``) or ``"nonpolymer"``.
  """

  def _xyz(site_atom):
    # Plain Python floats, independent of how the coordinate is stored.
    return (float(site_atom.coord[0]), float(site_atom.coord[1]), float(site_atom.coord[2]))

  site_coords: Dict[SiteKey, Tuple[float, float, float]] = {}
  chains_by_id = {chain.chain_id: chain for chain in structure.chains()}
  if chains:
    selected_ids = list(chains)
  else:
    selected_ids = [chain.chain_id for chain in structure.chains()]

  for cid in selected_ids:
    chain = chains_by_id.get(cid)
    if chain is None:
      continue

    for residue in chain.residues():
      resname = residue.res_name.strip().upper()
      if resname in _WATER_RESIDUE_NAMES:
        continue

      polymer_type = _classify_residue_for_lddt(residue)
      if polymer_type is None:
        # Non-polymer (e.g. ligand): every non-hydrogen atom is its own site.
        for atom in residue.atoms():
          if not _is_hydrogen_atom(atom):
            site = (cid, int(residue.res_id), residue.ins_code, "nonpolymer", resname, atom.atom_name.strip().upper())
            site_coords[site] = _xyz(atom)
      else:
        # Polymer residue: a single representative atom, if one can be found.
        representative = _select_polymer_representative_atom(residue, polymer_type)
        if representative is not None:
          site = (cid, int(residue.res_id), residue.ins_code, "polymer", "", "")
          site_coords[site] = _xyz(representative)
  return site_coords


def _coords_to_distmat(ordered_keys: List[SiteKey], coord_map: Dict[SiteKey, Tuple[float, float, float]]) -> np.ndarray:
  """Return the NxN Euclidean distance matrix of the sites in ``ordered_keys``.

  Row/column ``i`` corresponds to ``ordered_keys[i]``; coordinates are read from
  ``coord_map``. An empty key list yields an array of shape ``(0, 0)``.
  """
  if not ordered_keys:
    return np.empty((0, 0))
  points = np.array([coord_map[key] for key in ordered_keys], dtype=float)  # (N, 3)
  # Broadcast (N, 1, 3) against (1, N, 3) to get every pairwise offset.
  delta = points[:, None, :] - points[None, :, :]
  return np.sqrt(np.square(delta).sum(axis=-1))


def _calc_lddt_from_maps(
  true_map: np.ndarray,
  pred_map: np.ndarray,
  *,
  R: float = 15.0,
  sep_thresh: int = -1,
  T_set: Sequence[float] = (0.5, 1.0, 2.0, 4.0),
) -> float:
  """Compute lDDT from a reference and a predicted distance map.

  The pair set L contains every upper-triangle pair whose index separation is
  strictly greater than ``sep_thresh`` and whose reference distance is strictly
  below ``R``. For each tolerance in ``T_set`` the fraction of pairs in L with
  ``|d_ref - d_pred| < tolerance`` is taken; the score is the mean of those
  fractions.

  Returns:
    The lDDT score, or NaN when L is empty.

  Reference:
    Mariani V, Biasini M, Barbato A, Schwede T.
    lDDT: a local superposition-free score for comparing protein structures and models using distance difference tests.
    Bioinformatics. 2013 Nov 1;29(21):2722-8.
    doi: 10.1093/bioinformatics/btt473.
    Epub 2013 Aug 27.
    PMID: 23986568; PMCID: PMC3799472.
  """
  # Pairwise distances as flat vectors (strict upper triangle).
  ref_dists = _get_flattened(true_map)
  model_dists = _get_flattened(pred_map)

  # Pair set L: far enough apart in index, close enough in the reference.
  far_enough = _get_sep_thresh_b_indices(true_map, sep_thresh, "gt")
  close_enough = _get_dist_thresh_b_indices(ref_dists, R, "lt")
  in_scope = far_enough & close_enough

  n_pairs = in_scope.sum()
  if n_pairs == 0:
    return float("nan")

  # Absolute distance error for the pairs in L, shared by every tolerance.
  abs_err = np.abs(ref_dists[in_scope] - model_dists[in_scope])
  fractions = [(abs_err < tolerance).sum() / n_pairs for tolerance in T_set]
  return np.mean(fractions)


def calc_lddt(
  reference: Union[np.ndarray, Structure],
  prediction: Union[np.ndarray, Structure],
  *,
  chains_ref: Optional[List[str]] = None,
  chains_pred: Optional[List[str]] = None,
  R: float = 15.0,
  sep_thresh: int = -1,
  T_set: Sequence[float] = (0.5, 1.0, 2.0, 4.0),
) -> float:
  """Compute lDDT from distance maps or single-model Neurosnap structures.

  For structure inputs, lDDT is computed over aligned analysis sites rather than
  only canonical amino-acid residues. Protein and nucleotide residues, including
  modified residues recognized from residue chemistry/backbone content,
  contribute one representative site per residue. Proteins use ``CB`` when
  available and fall back to backbone atoms; nucleotides use a deterministic
  sugar/phosphate proxy atom. Non-polymer residues such as ligands contribute
  one site per heavy atom, while waters and hydrogens are ignored.

  Sites are matched between the two structures by their full site key (chain ID,
  residue number, insertion code, site kind, residue name, atom name), so only
  sites present under the same chain ID in both structures are scored.

  Args:
    reference: Distance map or single-model structure used as the ground truth.
    prediction: Distance map or single-model structure to compare against the reference.
    chains_ref: Chain identifiers to include when `reference` is a structure.
      ``None`` or an empty list uses every chain.
    chains_pred: Chain identifiers to include when `prediction` is a structure.
      ``None`` or an empty list uses every chain.
    R: Maximum reference distance to consider when defining the L set
      (pairs with reference distance strictly below `R` are kept).
    sep_thresh: Sequence-separation cutoff; only pairs whose index separation is
      strictly greater than this value are kept. For structure inputs the
      separation is measured in positions of the sorted common-site list.
    T_set: Error thresholds used to compute preserved distance fractions.

  Returns:
    lDDT score between the reference and prediction. Typical range is [0.0, 1.0],
    where 1.0 indicates perfect local distance agreement and 0.0 indicates no
    preserved distances under the selected thresholds. Returns NaN when no
    residue pairs satisfy the L-set criteria (for example, no pairs within `R`
    and above `sep_thresh`).

  Raises:
    ValueError: If distance maps do not share the same shape, or if the two
      structures share fewer than two analysis sites.
    TypeError: If the inputs are not both arrays or both structure containers.
  """
  is_ref_array = isinstance(reference, np.ndarray)
  is_pred_array = isinstance(prediction, np.ndarray)
  is_ref_structure = isinstance(reference, Structure)
  is_pred_structure = isinstance(prediction, Structure)

  if is_ref_array and is_pred_array:
    if reference.shape != prediction.shape:
      raise ValueError("Distance maps must share the same shape to compute lDDT.")
    return _calc_lddt_from_maps(reference, prediction, R=R, sep_thresh=sep_thresh, T_set=T_set)

  if is_ref_structure and is_pred_structure:
    ref_cb = _extract_cb_coords_from_structure(reference, chains=chains_ref)
    pred_cb = _extract_cb_coords_from_structure(prediction, chains=chains_pred)

    def _site_order(key):
      # Order by chain and residue number first (as before), then break ties on
      # the remaining key fields. Sorting on (chain, res_id) alone left sites of
      # the same residue in set-iteration order, which depends on string hashing
      # and made index separations (and thus `sep_thresh` filtering) vary
      # between runs.
      chain_id, res_id, ins_code, kind, res_name, atom_name = key
      return (chain_id, res_id, "" if ins_code is None else str(ins_code), kind, res_name, atom_name)

    # Intersect keys to ensure 1:1 site correspondence
    common_keys = sorted(set(ref_cb.keys()).intersection(pred_cb.keys()), key=_site_order)
    if len(common_keys) < 2:
      # ValueError is still caught by callers handling the former bare Exception.
      raise ValueError("Not enough residues/pairs to compute meaningful lDDT")

    # Build distance maps on the common site set
    true_map = _coords_to_distmat(common_keys, ref_cb)
    pred_map = _coords_to_distmat(common_keys, pred_cb)
    return _calc_lddt_from_maps(true_map, pred_map, R=R, sep_thresh=sep_thresh, T_set=T_set)

  raise TypeError("calc_lddt expects both inputs to be either numpy.ndarray distance maps or single-model Neurosnap Structure objects.")