# Source code for neurosnap.algos.LDDT

"""
Code for LDDT (Local Distance Difference Test) calculation, adapted from https://github.com/ba-lab/disteval/blob/main/LDDT.ipynb
"""

from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

from neurosnap.constants import (
  AA_RECORDS,
  BACKBONE_ATOMS_AA,
  BACKBONE_ATOMS_DNA,
  BACKBONE_ATOMS_RNA,
  STANDARD_NUCLEOTIDES,
)
from neurosnap.protein import Protein

# Backbone atom names tried, in this order, when an amino-acid residue has no
# usable CB atom. Names that are not listed in BACKBONE_ATOMS_AA are dropped.
_PROTEIN_BACKBONE_FALLBACK = tuple(atom for atom in ("CA", "N", "C") if atom in BACKBONE_ATOMS_AA)
# CB first, then the backbone fallbacks.
# NOTE(review): not referenced anywhere else in the visible part of this module — confirm it is still needed.
_PROTEIN_ATOM_PRIORITY = ("CB",) + _PROTEIN_BACKBONE_FALLBACK

# Maps every entry of STANDARD_NUCLEOTIDES to a simplified code: a two-character
# name starting with "D" is reduced to its second character (a name such as "DA"
# would become "A"); every other name maps to itself.
_NUCLEOTIDE_CODE_MAP = {code: (code[1] if len(code) == 2 and code[0] == "D" else code) for code in STANDARD_NUCLEOTIDES}
# Union of the DNA and RNA backbone atom names.
_NUCLEOTIDE_BACKBONE_ATOMS = BACKBONE_ATOMS_DNA.union(BACKBONE_ATOMS_RNA)
# Preferred order in which nucleotide atoms are tried as the residue's representative atom.
_NUCLEOTIDE_PREF_BASES = ("C4'", "C3'", "C1'", "C2'", "P", "O4'", "O3'", "O5'")
# Ordered list of name tuples; each tuple holds an atom name and, when it differs,
# the same name with "'" replaced by "*" (the older prime notation).
_NUCLEOTIDE_ATOM_PRIORITY: List[Tuple[str, ...]] = []
# Tracks which backbone atom names have already been appended to the priority list.
_seen_nuc_bases = set()
# Pass 1: preferred atoms, in preference order, restricted to known backbone atoms.
for base in _NUCLEOTIDE_PREF_BASES:
  if base in _NUCLEOTIDE_BACKBONE_ATOMS and base not in _seen_nuc_bases:
    # dict.fromkeys de-duplicates while keeping order, so names without a prime yield a 1-tuple.
    names = tuple(dict.fromkeys((base, base.replace("'", "*"))))
    _NUCLEOTIDE_ATOM_PRIORITY.append(names)
    _seen_nuc_bases.add(base)
# Pass 2: every remaining backbone atom, in sorted order so the result is deterministic.
for base in sorted(_NUCLEOTIDE_BACKBONE_ATOMS):
  if base not in _seen_nuc_bases:
    names = tuple(dict.fromkeys((base, base.replace("'", "*"))))
    _NUCLEOTIDE_ATOM_PRIORITY.append(names)
    _seen_nuc_bases.add(base)


# Helpers for metrics calculated using numpy scheme
def _get_flattened(dmap):
  """Return the distance values of `dmap` as a 1-D array.

  Args:
    dmap: A numpy array. A 1-D array is treated as already flattened and is
      returned unchanged (same object). A 2-D array is reduced to the entries
      strictly above its main diagonal, in row-major order.

  Returns:
    1-D numpy array of distance values.

  Raises:
    AssertionError: If `dmap` is neither 1-D nor 2-D.
  """
  if dmap.ndim == 1:
    return dmap
  if dmap.ndim == 2:
    return dmap[np.triu_indices_from(dmap, k=1)]
  # Raised explicitly instead of `assert False`: assert statements are removed
  # under `python -O`, which would make this function silently return None.
  raise AssertionError("ERROR: the passes array has dimension not equal to 2 or 1!")


def _get_separations(dmap):
  """Return |i - j| for every (i, j) entry strictly above the diagonal of `dmap`.

  The values come out in the same row-major order that `np.triu_indices_from`
  uses, so they line up element-for-element with an upper-triangle flattening
  of the same 2-D array.
  """
  row_idx, col_idx = np.triu_indices_from(dmap, k=1)
  return np.abs(row_idx - col_idx)


def _get_sep_thresh_b_indices(dmap, thresh, comparator):
  """Return a 1-D boolean mask over the upper triangle of `dmap` by index separation.

  Args:
    dmap: 2-D array whose upper triangle (excluding the diagonal) defines the pairs.
    thresh: Separation threshold compared against |i - j| of each pair.
    comparator: One of "gt", "lt", "ge", "le".

  Returns:
    Boolean array, True where the pair's separation satisfies the comparison.

  Raises:
    AssertionError: If `comparator` is not one of the supported names.
  """
  compare = {
    "gt": np.greater,
    "lt": np.less,
    "ge": np.greater_equal,
    "le": np.less_equal,
  }.get(comparator)
  if compare is None:
    # Raised explicitly rather than via `assert`: under `python -O` an assert is
    # stripped and an unknown comparator would surface as an UnboundLocalError.
    raise AssertionError("ERROR: Unknown comparator for thresholding!")
  return compare(_get_separations(dmap), thresh)


def _get_dist_thresh_b_indices(dmap, thresh, comparator):
  """Return a 1-D boolean mask over the distances of `dmap` by distance value.

  Args:
    dmap: Distance array; a 2-D array is first reduced to its upper triangle
      (excluding the diagonal), a 1-D array is used as given.
    thresh: Distance threshold each value is compared against.
    comparator: One of "gt", "lt", "ge", "le".

  Returns:
    Boolean array, True where the distance satisfies the comparison.

  Raises:
    AssertionError: If `comparator` is not one of the supported names, or if
      `dmap` is neither 1-D nor 2-D.
  """
  compare = {
    "gt": np.greater,
    "lt": np.less,
    "ge": np.greater_equal,
    "le": np.less_equal,
  }.get(comparator)
  if compare is None:
    # Raised explicitly rather than via `assert`: under `python -O` an assert is
    # stripped and an unknown comparator would surface as an UnboundLocalError.
    raise AssertionError("ERROR: Unknown comparator for thresholding!")
  return compare(_get_flattened(dmap), thresh)


def _aa3_to_aa1(resname: str) -> Optional[str]:
  """Translate a residue name into a one-letter amino-acid code.

  The residue's own record in AA_RECORDS is consulted first. If that record has
  no code, the record named by its `standard_equiv_abr` field is consulted
  instead.

  Returns:
    The one-letter code, or None when the residue is unknown or neither record
    carries a code.
  """
  entry = AA_RECORDS.get(resname)
  if entry is None:
    return None

  own_code = entry.code
  if own_code is not None:
    return own_code

  # Fall back to the standard residue this one is declared equivalent to.
  equivalent_name = entry.standard_equiv_abr
  if not equivalent_name:
    return None
  equivalent = AA_RECORDS.get(equivalent_name)
  if not equivalent:
    return None
  return equivalent.code


def _nucleotide_to_code(resname: str) -> Optional[str]:
  """Look up the simplified code of a standard DNA/RNA residue name.

  Returns:
    The entry of `_NUCLEOTIDE_CODE_MAP` for `resname`, or None when the name is
    not in the map.
  """
  try:
    return _NUCLEOTIDE_CODE_MAP[resname]
  except KeyError:
    return None


def _is_nucleotide(resname: str) -> bool:
  """Tell whether `resname` is listed in STANDARD_NUCLEOTIDES."""
  listed = resname in STANDARD_NUCLEOTIDES
  return listed


def _get_atom(res, name: str):
  """Look up an atom in `res`, tolerating both prime spellings (' and *).

  The exact `name` is tried first. If it is absent, the alternate spelling is
  tried: every "'" becomes "*" when the name contains "'", otherwise every "*"
  becomes "'".

  Args:
    res: Container supporting `in` and item access by atom name.
    name: Atom name to fetch.

  Returns:
    The stored atom, or None when neither spelling is present.
  """
  if "'" in name:
    alternate = name.replace("'", "*")
  else:
    alternate = name.replace("*", "'")

  spellings = [name]
  if alternate != name:
    spellings.append(alternate)

  return next((res[spelling] for spelling in spellings if spelling in res), None)


def _extract_cb_coords_from_protein(
  prot: Protein,
  *,
  model: Optional[int] = None,
  chains: Optional[List[str]] = None,
  require_standard_aa: bool = True,
) -> Dict[Tuple[str, int], Tuple[float, float, float]]:
  """Pick one representative atom per residue and return its coordinates.

  Amino acids use CB (skipped for GLY), falling back to the atoms in
  `_PROTEIN_BACKBONE_FALLBACK`. Nucleotides use the first atom found while
  walking `_NUCLEOTIDE_ATOM_PRIORITY`. Residues whose `id[0]` is not a blank,
  residues that are neither in AA_RECORDS nor standard nucleotides, and residues
  with none of the candidate atoms are left out.

  Args:
    prot: Structure to read.
    model: Model to use; defaults to the first entry of `prot.models()`.
    chains: Chain ids to traverse; every chain of the model when empty or None.
      Requested ids that the model does not contain are skipped silently.
    require_standard_aa: When True, also leave out residues for which no
      one-letter amino-acid code / nucleotide code can be resolved.

  Returns:
    dict keyed by (chain_id, res_id) -> (x, y, z)
  """
  if model is None:
    model = prot.models()[0]
  assert model in prot.models(), f"Model {model} not found in protein {prot.title}"

  selected_model = prot.structure[model]
  if chains:
    wanted_chains = chains
  else:
    wanted_chains = [chain.id for chain in selected_model]

  # Nucleotide candidates as one flat sequence, preserving priority order and,
  # within each entry, the order of its alternative spellings.
  nucleotide_candidates = [alias for aliases in _NUCLEOTIDE_ATOM_PRIORITY for alias in aliases]

  representative: Dict[Tuple[str, int], Tuple[float, float, float]] = {}
  for chain_id in wanted_chains:
    if chain_id not in prot.chains(model):
      # A requested chain that is not present is ignored rather than reported.
      continue
    for residue in selected_model[chain_id]:
      if residue.id[0] != " ":
        continue
      resname = getattr(residue, "resname", None)
      if resname is None:
        continue

      amino = resname in AA_RECORDS
      if not amino and not _is_nucleotide(resname):
        continue

      if require_standard_aa:
        code = _aa3_to_aa1(resname) if amino else _nucleotide_to_code(resname)
        if code is None:
          continue

      if not amino:
        candidates = nucleotide_candidates
      elif resname == "GLY":
        # CB is never looked up for glycine; go straight to the backbone atoms.
        candidates = _PROTEIN_BACKBONE_FALLBACK
      else:
        candidates = ("CB",) + _PROTEIN_BACKBONE_FALLBACK

      lookups = (_get_atom(residue, atom_name) for atom_name in candidates)
      chosen = next((hit for hit in lookups if hit is not None), None)
      if chosen is None:
        continue

      position = chosen.coord
      # Key is (chain id, residue sequence number).
      representative[(chain_id, residue.id[1])] = (float(position[0]), float(position[1]), float(position[2]))
  return representative


def _coords_to_distmat(ordered_keys: List[Tuple[str, int]], coord_map: Dict[Tuple[str, int], Tuple[float, float, float]]) -> np.ndarray:
  """Compute all pairwise Euclidean distances between the listed residues.

  Args:
    ordered_keys: Residue keys; row/column i of the result belongs to
      `ordered_keys[i]`.
    coord_map: Coordinates for every key in `ordered_keys`.

  Returns:
    (N, N) array of distances, or an empty (0, 0) array when no keys are given.
  """
  if not ordered_keys:
    return np.empty((0, 0))
  points = np.array([coord_map[key] for key in ordered_keys], dtype=float)  # (N,3)
  # Broadcasting (N,1,3) against (1,N,3) yields every pairwise displacement.
  displacement = points[:, None, :] - points[None, :, :]
  squared_norm = np.square(displacement).sum(axis=-1)
  return np.sqrt(squared_norm)


def _calc_lddt_from_maps(
  true_map: np.ndarray,
  pred_map: np.ndarray,
  *,
  R: float = 15.0,
  sep_thresh: int = -1,
  T_set: Sequence[float] = (0.5, 1.0, 2.0, 4.0),
  precision: int = 4,
) -> float:
  """Score how well `pred_map` preserves the local distances of `true_map`.

  The scored pairs (set L) are the upper-triangle pairs of `true_map` whose
  index separation is greater than `sep_thresh` and whose reference distance is
  below `R`. For each tolerance in `T_set` the fraction of those pairs whose
  absolute distance error is below the tolerance is computed; the score is the
  mean of these fractions.

  Args:
    true_map: Reference distance matrix (2-D).
    pred_map: Predicted distances, 2-D or already flattened to the upper triangle.
    R: Inclusion radius applied to the reference distances.
    sep_thresh: Pairs must have an index separation strictly greater than this.
    T_set: Distance-error tolerances.
    precision: Number of decimals to round to; values <= 0 disable rounding.

  Returns:
    The lDDT score, or NaN when no pair qualifies for set L.

  Reference:
    Mariani V, Biasini M, Barbato A, Schwede T.
    lDDT: a local superposition-free score for comparing protein structures and models using distance difference tests.
    Bioinformatics. 2013 Nov 1;29(21):2722-8.
    doi: 10.1093/bioinformatics/btt473.
    Epub 2013 Aug 27.
    PMID: 23986568; PMCID: PMC3799472.
  """
  # Upper-triangle views of both maps
  reference_flat = _get_flattened(true_map)
  model_flat = _get_flattened(pred_map)

  # Membership mask for set L
  separated_enough = _get_sep_thresh_b_indices(true_map, sep_thresh, "gt")
  within_radius = _get_dist_thresh_b_indices(reference_flat, R, "lt")
  in_set_L = separated_enough & within_radius

  pair_count = in_set_L.sum()
  if pair_count == 0:
    return float("nan")

  # Absolute distance error of every scored pair
  abs_error = np.abs(reference_flat[in_set_L] - model_flat[in_set_L])

  # Fraction of pairs preserved under each tolerance, then their mean
  fractions = [(abs_error < tolerance).sum() / pair_count for tolerance in T_set]
  score = np.mean(fractions)
  return round(score, precision) if precision > 0 else score


def calc_lddt(
  reference: Union[np.ndarray, Protein],
  prediction: Union[np.ndarray, Protein],
  *,
  model_ref: Optional[int] = None,
  model_pred: Optional[int] = None,
  chains_ref: Optional[List[str]] = None,
  chains_pred: Optional[List[str]] = None,
  R: float = 15.0,
  sep_thresh: int = -1,
  T_set: Sequence[float] = (0.5, 1.0, 2.0, 4.0),
  precision: int = 4,
  require_standard_aa: bool = True,
) -> float:
  """Compute lDDT from distance maps or Protein structures.

  When two Protein objects are given, one representative atom per residue is
  extracted from each, the residues present in both (matched by chain id and
  residue number) are kept, and distance maps are built on that common set.

  Args:
    reference: Distance map or Protein used as the ground truth.
    prediction: Distance map or Protein to compare against the reference.
    model_ref: Model index to select when `reference` is a Protein.
    model_pred: Model index to select when `prediction` is a Protein.
    chains_ref: Chain identifiers to include when `reference` is a Protein.
    chains_pred: Chain identifiers to include when `prediction` is a Protein.
    R: Maximum reference distance to consider when defining the L set.
    sep_thresh: Residue pairs are scored only when their separation in the
      distance-map ordering is strictly greater than this value.
    T_set: Error thresholds used to compute preserved distance fractions.
    precision: Decimal precision of the reported score; use 0 or negative to skip rounding.
    require_standard_aa: Skip residues with unknown amino-acid or nucleotide codes when True.

  Returns:
    lDDT score between the reference and prediction; NaN when no residue pair
    satisfies the `R` / `sep_thresh` criteria.

  Raises:
    ValueError: If distance maps do not share the same shape, or if the two
      structures have fewer than two residues in common.
    TypeError: If the inputs are not both arrays or both Protein instances.
  """
  if isinstance(reference, np.ndarray) and isinstance(prediction, np.ndarray):
    if reference.shape != prediction.shape:
      raise ValueError("Distance maps must share the same shape to compute lDDT.")
    return _calc_lddt_from_maps(reference, prediction, R=R, sep_thresh=sep_thresh, T_set=T_set, precision=precision)

  if isinstance(reference, Protein) and isinstance(prediction, Protein):
    # Extract per-residue coordinates for both proteins
    ref_cb = _extract_cb_coords_from_protein(reference, model=model_ref, chains=chains_ref, require_standard_aa=require_standard_aa)
    pred_cb = _extract_cb_coords_from_protein(prediction, model=model_pred, chains=chains_pred, require_standard_aa=require_standard_aa)

    # Intersect keys to ensure 1:1 residue correspondence; the (chain, resnum)
    # tuples sort by chain id first and residue number second.
    common_keys = sorted(ref_cb.keys() & pred_cb.keys())
    if len(common_keys) < 2:
      # ValueError is a subclass of Exception, so callers that caught the
      # previous bare Exception keep working.
      raise ValueError("Not enough residues/pairs to compute meaningful lDDT")

    # Build distance maps on the common residue set
    true_map = _coords_to_distmat(common_keys, ref_cb)
    pred_map = _coords_to_distmat(common_keys, pred_cb)
    return _calc_lddt_from_maps(true_map, pred_map, R=R, sep_thresh=sep_thresh, T_set=T_set, precision=precision)

  raise TypeError("calc_lddt expects both inputs to be either numpy.ndarray distance maps or Protein objects.")