# Source code for neurosnap.structure.structure

"""Data structures for representing molecular coordinates and annotations.

This module provides a single-model :class:`Structure`, immutable hierarchy
views (:class:`Chain`, :class:`Residue`, and :class:`Atom`), an ordered
multi-model container (:class:`StructureEnsemble`), and a shared-annotation
multi-model fast path (:class:`StructureStack`).

The universal length unit is Å.
"""

from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Dict, Iterator, List, Literal, Mapping, Optional, Tuple

import numpy as np
import pandas as pd

from neurosnap.constants.chemistry import ATOMIC_MASSES
from neurosnap.constants.sequence import AA_RECORDS
from neurosnap.constants.structure import BACKBONE_ATOMS_DNA, BACKBONE_ATOMS_RNA, NUC_DNA_CODES, NUC_RNA_CODES, STANDARD_NUCLEOTIDES
from neurosnap.log import logger

### IMPORTANT NOTES
# Universal unit is Å.
# The Structure object ignores alternate locations (altlocs) and drops them automatically.
# HETATM records are stored together with their bond information.
# In PDB files, repeated bond records are interpreted as a higher bond order. For instance, if atoms i and j have two bond records in a PDB file, they are interpreted as sharing a double bond.
# Each Structure corresponds to a single model ONLY. Use StructureEnsemble for an ordered collection of models, or StructureStack as a shared-annotation multi-model fast path when all models have identical atoms and bonds.

# Column order of the dataframe built by ``Structure.to_dataframe``.
_STRUCTURE_DATAFRAME_COLUMNS = (
  # residue-level identity
  "chain", "res_id", "ins_code", "res_name", "hetero", "res_type",
  # atom-level identity and per-atom annotations
  "atom", "atom_name", "element", "bfactor", "occupancy", "charge", "sym_id",
  # coordinates (Å) and element-derived mass
  "x", "y", "z", "mass",
)


class Structure:
  """Single-model molecular structure container.

  Coordinates are stored separately from per-atom annotations so geometry-heavy
  operations can work on compact numeric arrays while annotation schemas remain
  flexible.

  Parameters:
    remove_annotations: If ``True``, optional annotation columns that only
      contain default values are removed after initialization. A new structure
      has no atoms yet, so this drops every optional column. Optional columns can
      be added back with :meth:`add_annotation`.
  """

  # These fields define atom identity and basic PDB semantics, so they must
  # always be present even if every value is currently the default.
  _MANDATORY_ANNOTATIONS = ("chain_id", "res_id", "ins_code", "res_name", "hetero", "atom_name", "element")
  # Default value of each built-in annotation. It is used to fill a new column,
  # to decide whether a column is "empty", and to export a column that has been
  # removed from the annotation table.
  _ANNOTATION_DEFAULTS = {
    "chain_id": "",
    "res_id": 0,
    "ins_code": "",
    "res_name": "",
    "hetero": False,
    "atom_name": "",
    "element": "",
    "atom_id": 0,
    "b_factor": 0.0,
    "occupancy": 1.0,
    "charge": 0,
    "sym_id": "",
  }

  def __init__(self, remove_annotations: bool = True):
    """Initialize an empty single-model structure.

    Parameters:
      remove_annotations: Whether to drop optional annotation columns that only
        contain default values. For a new, empty structure that is all of them.
    """
    # Keys are metadata field names / titles and values are the corresponding values.
    # TODO: Read metadata from file and add it as needed
    self.metadata: Dict[str, Any] = {}
    # Coordinate table dtype (Å).
    self._dtype_atoms = np.dtype(
      [
        ("x", "f4"),  # x coordinate
        ("y", "f4"),  # y coordinate
        ("z", "f4"),  # z coordinate
      ]
    )
    # Annotation table dtype. add_annotation() and remove_annotation() replace it
    # whenever the schema changes.
    self._dtype_atom_annotations = np.dtype(
      [
        ("chain_id", "U4"),  # chain ID
        ("res_id", "i4"),  # residue number
        ("ins_code", "U1"),  # insertion code
        ("res_name", "U5"),  # residue name
        ("hetero", "?"),  # ATOM vs HETATM
        ("atom_name", "U6"),  # atom name
        ("element", "U2"),  # chemical element
        ("atom_id", "i4"),  # atom serial number
        ("b_factor", "f4"),  # temperature factor
        ("occupancy", "f4"),  # occupancy
        ("charge", "i1"),  # small int (-128 to 127 is enough)
        ("sym_id", "U4"),  # symmetry ID (string, often small)
      ]
    )
    # Bond table dtype: two int32 atom references (presumably atom-table
    # indices) plus an int8 bond type.
    self._dtype_bond = np.dtype(
      [
        ("atom_i", np.int32),
        ("atom_j", np.int32),
        ("bond_type", np.int8),
      ]
    )
    # Coordinates and annotations are stored separately so annotation schema
    # changes only require rebuilding the annotation table.
    self.atoms = np.zeros(0, dtype=self._dtype_atoms)
    self.atom_annotations = np.zeros(0, dtype=self._dtype_atom_annotations)
    self.bonds = np.zeros(0, dtype=self._dtype_bond)
    if remove_annotations:
      self._remove_empty_annotations()

  def __len__(self) -> int:
    """Return the number of atoms in the structure."""
    return len(self.atoms)

  def __repr__(self) -> str:
    """Return a compact string summary of the structure."""
    chains = self.chains()
    chain_ids = [chain.chain_id if chain.chain_id else "<blank>" for chain in chains]
    residue_count = sum(len(chain.residues()) for chain in chains)
    return f"<Structure: Chains=[{','.join(chain_ids)}] Residues={residue_count} Atoms={len(self)}>"

  def __iter__(self) -> Iterator["Chain"]:
    """Iterate over chains in atom-table order."""
    return iter(self.chains())

  def __getitem__(self, chain_id: str) -> "Chain":
    """Return a chain view by chain ID.

    Parameters:
      chain_id: Chain identifier to retrieve.

    Returns:
      The matching :class:`Chain` view.

    Raises:
      TypeError: If ``chain_id`` is not a string.
      KeyError: If the requested chain is not present in the structure.
    """
    if not isinstance(chain_id, str):
      raise TypeError("Structure indices must be chain IDs as strings.")
    for chain in self.chains():
      if chain.chain_id == chain_id:
        return chain
    raise KeyError(f'Chain "{chain_id}" was not found.')

  def to_dataframe(self) -> pd.DataFrame:
    """Export the structure as a pandas dataframe.

    The dataframe is derived on demand from the current atom table and is never
    cached on the structure. Optional annotations that have been removed are
    exported filled with their default values.

    Raises:
      ValueError: If any atom has an element with no known mass (raised while
        computing the ``mass`` column).
    """
    atom_count = len(self)
    data = {
      "chain": self._annotation_export("chain_id"),
      "res_id": self._annotation_export("res_id"),
      "ins_code": self._annotation_export("ins_code"),
      "res_name": self._annotation_export("res_name"),
      "hetero": self._annotation_export("hetero"),
      "res_type": self._residue_types(),
      "atom": self._annotation_export("atom_id"),
      "atom_name": self._annotation_export("atom_name"),
      "element": self._annotation_export("element"),
      "bfactor": self._annotation_export("b_factor"),
      "occupancy": self._annotation_export("occupancy"),
      "charge": self._annotation_export("charge"),
      "sym_id": self._annotation_export("sym_id"),
      "x": self.atoms["x"].copy() if atom_count else np.zeros(0, dtype=np.float32),
      "y": self.atoms["y"].copy() if atom_count else np.zeros(0, dtype=np.float32),
      "z": self.atoms["z"].copy() if atom_count else np.zeros(0, dtype=np.float32),
      "mass": self._atom_masses(),
    }
    return pd.DataFrame(data, columns=_STRUCTURE_DATAFRAME_COLUMNS)

  def chains(self) -> List["Chain"]:
    """Return all chains in the structure as immutable hierarchy views.

    Atoms are grouped into residues by ``(res_id, ins_code, res_name, hetero)``
    within each chain ID. Atoms with the same key are merged into one residue
    even when they are not contiguous in the atom table.

    Returns:
      List of :class:`Chain` objects in atom-table order.
    """
    annotations = self.atom_annotations
    # Column-wise ``tolist()`` yields the same Python scalars as
    # ``_annotation_value`` without a NumPy scalar lookup per atom and field.
    atom_keys = zip(
      annotations["chain_id"].tolist(),
      annotations["res_id"].tolist(),
      annotations["ins_code"].tolist(),
      annotations["res_name"].tolist(),
      annotations["hetero"].tolist(),
    )
    # Chain and residue ordering follow the original atom table order
    # (dicts preserve insertion order). This keeps hierarchy views stable with
    # respect to the parsed file.
    chain_map: Dict[str, Dict[Tuple[int, str, str, bool], List[int]]] = {}
    for atom_index, (chain_id, res_id, ins_code, res_name, hetero) in enumerate(atom_keys):
      residue_map = chain_map.setdefault(chain_id, {})
      residue_map.setdefault((res_id, ins_code, res_name, hetero), []).append(atom_index)

    chains: List[Chain] = []
    for chain_id, residue_map in chain_map.items():
      residues: List[Residue] = []
      for (res_id, ins_code, res_name, hetero), atom_indices in residue_map.items():
        residues.append(
          Residue(
            chain_id=chain_id,
            res_id=res_id,
            ins_code=ins_code,
            res_name=res_name,
            hetero=hetero,
            _atoms=tuple(self._atom_view(atom_index) for atom_index in atom_indices),
            _atom_indices=tuple(atom_indices),
          )
        )
      chains.append(Chain(chain_id=chain_id, _residues=tuple(residues)))
    return chains

  def chain_ids(self) -> List[str]:
    """Return the unique chain IDs found in the structure.

    Returns:
      List of chain ID strings sorted lexicographically (``np.unique`` order).
      This is not necessarily the atom-table order used by :meth:`chains`.
    """
    return [str(x) for x in np.unique(self.atom_annotations["chain_id"])]

  def renumber(self, chain: Optional[str] = None, start: int = 1):
    """Renumber residues in-place.

    Parameters:
      chain: Chain ID to renumber. If ``None``, all chains are renumbered in
        chain order using one continuous counter.
      start: Starting residue number.

    Raises:
      ValueError: If ``chain`` is given but not present in the structure.

    Notes:
      Renumbering treats inserted residues as ordinary sequential residues and
      clears their insertion codes. For example, residues ``10``, ``10A``, and
      ``10B`` become ``1``, ``2``, and ``3`` (with empty insertion codes) when
      renumbered with ``start=1``.
    """
    if chain is not None and chain not in self.chain_ids():
      raise ValueError(f'Chain "{chain}" was not found in the structure.')

    # Map each residue identity (chain_id, res_id, ins_code, res_name, hetero)
    # to its new number, in hierarchy order.
    residue_number = int(start)
    residue_map: Dict[Tuple[str, int, str, str, bool], int] = {}
    for chain_view in self.chains():
      if chain is not None and chain_view.chain_id != chain:
        continue
      for residue in chain_view.residues():
        residue_map[residue.key()] = residue_number
        residue_number += 1

    if not residue_map:
      return

    annotations = self.atom_annotations
    atom_keys = zip(
      annotations["chain_id"].tolist(),
      annotations["res_id"].tolist(),
      annotations["ins_code"].tolist(),
      annotations["res_name"].tolist(),
      annotations["hetero"].tolist(),
    )
    # Collect every affected atom first (keys are read from a snapshot taken
    # before any write), then update both columns in one vectorized assignment.
    target_indices: List[int] = []
    new_res_ids: List[int] = []
    for atom_index, residue_key in enumerate(atom_keys):
      new_res_id = residue_map.get(residue_key)
      if new_res_id is not None:
        target_indices.append(atom_index)
        new_res_ids.append(new_res_id)
    if target_indices:
      annotations["res_id"][target_indices] = new_res_ids
      annotations["ins_code"][target_indices] = ""

  def translate(
    self,
    x: float = 0.0,
    y: float = 0.0,
    z: float = 0.0,
    chains: Optional[List[str]] = None,
  ):
    """Translate selected atoms in-place by a fixed vector.

    Parameters:
      x: Translation along the x-axis.
      y: Translation along the y-axis.
      z: Translation along the z-axis.
      chains: Optional chain IDs to translate. If ``None``, all atoms are translated.

    Raises:
      ValueError: If any requested chain is not present.
    """
    atom_mask = self._atom_mask(chains=chains)
    self.atoms["x"][atom_mask] += float(x)
    self.atoms["y"][atom_mask] += float(y)
    self.atoms["z"][atom_mask] += float(z)

  def center_at(
    self,
    x: float = 0.0,
    y: float = 0.0,
    z: float = 0.0,
    chains: Optional[List[str]] = None,
  ):
    """Translate selected atoms so their center of mass matches a target point.

    Parameters:
      x: Target x-coordinate for the center of mass.
      y: Target y-coordinate for the center of mass.
      z: Target z-coordinate for the center of mass.
      chains: Optional chain IDs to center. If ``None``, all atoms are used.

    Raises:
      ValueError: Propagated from :meth:`calculate_center_of_mass`.
    """
    target = np.array([x, y, z], dtype=np.float32)
    center_of_mass = self.calculate_center_of_mass(chains=chains)
    translation = target - center_of_mass
    self.translate(x=float(translation[0]), y=float(translation[1]), z=float(translation[2]), chains=chains)

  def calculate_center_of_mass(self, chains: Optional[List[str]] = None) -> np.ndarray:
    """Calculate the center of mass for the selected atoms.

    Parameters:
      chains: Optional chain IDs to include. If ``None``, all atoms are used.

    Returns:
      A length-3 NumPy array containing the center of mass in Å.

    Raises:
      ValueError: If no atoms are found in the selected structure, a requested
        chain is missing, or any selected atom has an unknown element mass.
    """
    atom_mask = self._atom_mask(chains=chains)
    if not np.any(atom_mask):
      raise ValueError("No atoms were found in the selected structure.")
    coord = self._coord_matrix(atom_mask=atom_mask)
    masses = self._atom_masses(atom_mask=atom_mask)
    return np.average(coord, axis=0, weights=masses)

  def calculate_geometric_center(self, chains: Optional[List[str]] = None) -> np.ndarray:
    """Calculate the geometric center for the selected atoms.

    Parameters:
      chains: Optional chain IDs to include. If ``None``, all atoms are used.

    Returns:
      A length-3 NumPy array containing the arithmetic mean of the selected atom
      coordinates in Å.

    Raises:
      ValueError: If no atoms are found in the selected structure or a requested
        chain is missing.
    """
    atom_mask = self._atom_mask(chains=chains)
    if not np.any(atom_mask):
      raise ValueError("No atoms were found in the selected structure.")
    coord = self._coord_matrix(atom_mask=atom_mask)
    return coord.mean(axis=0)

  def distances_from(self, point: np.ndarray, chains: Optional[List[str]] = None) -> np.ndarray:
    """Calculate distances from a point for the selected atoms.

    Parameters:
      point: Reference point as an array-like object with shape ``(3,)``.
      chains: Optional chain IDs to include. If ``None``, all atoms are used.

    Returns:
      A 1D NumPy array containing Euclidean distances in atom-table order (empty
      if no atoms are selected).

    Raises:
      ValueError: If ``point`` does not have shape ``(3,)`` or a requested chain
        is missing.
    """
    point = np.asarray(point, dtype=np.float32)
    if point.shape != (3,):
      raise ValueError("Point must be an array-like object with shape (3,).")
    atom_mask = self._atom_mask(chains=chains)
    coord = self._coord_matrix(atom_mask=atom_mask)
    if coord.size == 0:
      return np.zeros(0, dtype=np.float32)
    return np.linalg.norm(coord - point, axis=1)

  def calculate_rog(self, chains: Optional[List[str]] = None, center: Optional[np.ndarray] = None) -> float:
    """Calculate the radius of gyration for the selected atoms.

    The value is the unweighted root-mean-square distance of the selected atoms
    from ``center``.

    Parameters:
      chains: Optional chain IDs to include. If ``None``, all atoms are used.
      center: Optional reference point. If ``None``, the center of mass is used.

    Returns:
      Radius of gyration in Å, or ``0.0`` when no atoms are selected.

    Raises:
      ValueError: If a requested chain is missing, ``center`` has the wrong
        shape, or (when ``center`` is ``None``) an element mass is unknown.
    """
    atom_mask = self._atom_mask(chains=chains)
    if not np.any(atom_mask):
      return 0.0
    if center is None:
      center = self.calculate_center_of_mass(chains=chains)
    distances = self.distances_from(center, chains=chains)
    if distances.size == 0:
      return 0.0
    return float(np.sqrt(np.mean(distances**2)))

  def add_annotation(
    self,
    name: str,
    dtype: Any,
    values: Any = None,
    *,
    fill_value: Any = None,
    overwrite: bool = False,
  ):
    """Add a new per-atom annotation column.

    All validation happens before the annotation table is touched. A failed
    call, including a failed ``overwrite=True`` call, leaves any existing column
    with the same name intact.

    Parameters:
      name: Annotation name to add.
      dtype: NumPy-compatible scalar dtype for the annotation values.
      values: Optional per-atom values for the annotation.
      fill_value: Optional default value used when ``values`` is not supplied.
      overwrite: Whether to replace an existing optional annotation of the same name.

    Raises:
      ValueError: If the name is invalid, reserved, already present, or the
        supplied values do not match the atom count.
      TypeError: If the supplied dtype is not a scalar per-atom dtype, or no
        default fill value exists for it.
    """
    if not isinstance(name, str) or not name:
      raise ValueError("Annotation name must be a non-empty string.")
    if name in self._dtype_atoms.names:
      raise ValueError(f'"{name}" is reserved for coordinate storage.')

    annotation_dtype = np.dtype(dtype)
    if annotation_dtype.names is not None:
      raise TypeError("Annotation dtype must describe a single scalar value per atom.")

    replacing = name in self._dtype_atom_annotations.names
    if replacing:
      if not overwrite:
        raise ValueError(f'Annotation "{name}" already exists.')
      if name in self._MANDATORY_ANNOTATIONS:
        raise ValueError(f'Cannot overwrite mandatory annotation "{name}".')

    atom_count = len(self)
    if values is not None:
      # Annotation columns are always per-atom, so the incoming data must line
      # up exactly with the current atom table.
      values_array = np.asarray(values, dtype=annotation_dtype)
      if values_array.ndim != 1:
        raise ValueError("Annotation values must be one-dimensional.")
      if len(values_array) != atom_count:
        raise ValueError(f'Annotation "{name}" must contain exactly {atom_count} values; found {len(values_array)}.')
    else:
      if fill_value is None:
        fill_value = self._default_fill_value(name, annotation_dtype)
      values_array = np.full(atom_count, fill_value, dtype=annotation_dtype)

    # Only drop the old column once the replacement is known to be valid.
    if replacing:
      self.remove_annotation(name)

    # NumPy structured dtypes are immutable, so adding a field means rebuilding
    # the annotation table with the old columns copied over. The realized array
    # dtype is used so an unsized string dtype (e.g. "U") gets a real width.
    existing_fields = [(field_name, self._dtype_atom_annotations.fields[field_name][0]) for field_name in self._dtype_atom_annotations.names]
    new_dtype = np.dtype(existing_fields + [(name, values_array.dtype)])
    new_annotations = np.empty(atom_count, dtype=new_dtype)
    for field_name in self._dtype_atom_annotations.names:
      new_annotations[field_name] = self.atom_annotations[field_name]
    new_annotations[name] = values_array
    self._dtype_atom_annotations = new_dtype
    self.atom_annotations = new_annotations

  def remove_annotation(self, name: str):
    """Remove a non-mandatory annotation column and return its values.

    Parameters:
      name: Annotation name to remove.

    Returns:
      Copy of the removed annotation values.

    Raises:
      KeyError: If the annotation does not exist.
      ValueError: If the name is invalid or refers to a mandatory annotation.
    """
    if not isinstance(name, str) or not name:
      raise ValueError("Annotation name must be a non-empty string.")
    if name not in self._dtype_atom_annotations.names:
      raise KeyError(f'Annotation "{name}" was not found.')
    if name in self._MANDATORY_ANNOTATIONS:
      raise ValueError(f'Cannot remove mandatory annotation "{name}".')

    removed_values = self.atom_annotations[name].copy()
    # Removing a field uses the same rebuild path as add_annotation(), but with
    # the target column omitted from the new dtype.
    remaining_fields = [
      (field_name, self._dtype_atom_annotations.fields[field_name][0]) for field_name in self._dtype_atom_annotations.names if field_name != name
    ]
    new_dtype = np.dtype(remaining_fields)
    new_annotations = np.empty(len(self), dtype=new_dtype)
    for field_name in new_dtype.names:
      new_annotations[field_name] = self.atom_annotations[field_name]
    self._dtype_atom_annotations = new_dtype
    self.atom_annotations = new_annotations
    return removed_values

  def _remove_empty_annotations(self):
    """Drop optional annotations that do not currently carry information."""
    # This is most useful after parsing/loading when optional columns were kept
    # available during construction but ended up containing only defaults.
    for name in list(self._dtype_atom_annotations.names):
      if name in self._MANDATORY_ANNOTATIONS:
        continue
      if self._annotation_is_empty(name):
        self.remove_annotation(name)

  def _annotation_is_empty(self, name: str) -> bool:
    """Return ``True`` if an annotation has no atoms or contains only default values."""
    values = self.atom_annotations[name]
    if values.size == 0:
      return True
    # A column that is still entirely filled with its default sentinel value
    # does not contain any structure-specific information yet.
    default = self._default_fill_value(name, values.dtype)
    return bool(np.all(values == default))

  def _default_fill_value(self, name: str, dtype: np.dtype):
    """Return the default fill value for an annotation dtype.

    Raises:
      TypeError: If ``name`` has no registered default and ``dtype`` is not a
        string, bool, integer, or float kind.
    """
    if name in self._ANNOTATION_DEFAULTS:
      return self._ANNOTATION_DEFAULTS[name]
    if dtype.kind in {"U", "S"}:
      return ""
    if dtype.kind == "b":
      return False
    if dtype.kind in {"i", "u"}:
      return 0
    if dtype.kind == "f":
      return 0.0
    raise TypeError(f'No default fill value is defined for dtype "{dtype}" of annotation "{name}".')

  def _annotation_value(self, name: str, atom_index: int):
    """Return a Python scalar for a single annotation value."""
    value = self.atom_annotations[name][atom_index]
    if isinstance(value, np.generic):
      return value.item()
    return value

  def _annotation_export(self, name: str) -> np.ndarray:
    """Return a copy of an annotation column, or a default-filled column if it was removed."""
    if name in self._dtype_atom_annotations.names:
      return self.atom_annotations[name].copy()
    default_value = self._ANNOTATION_DEFAULTS[name]
    return np.full(len(self), default_value)

  def _residue_types(self) -> np.ndarray:
    """Return per-atom residue-type labels for dataframe export.

    Hetero atoms are always ``"HETEROGEN"``. Non-hetero atoms are labelled
    ``"NUCLEOTIDE"`` or ``"AMINO_ACID"`` by residue name, falling back to
    ``"HETEROGEN"``.
    """
    residue_types = np.full(len(self), "HETEROGEN", dtype="U12")
    if len(self) == 0:
      return residue_types
    hetero = self._annotation_export("hetero").tolist()
    res_name = self._annotation_export("res_name").tolist()
    for atom_index, (is_hetero, residue_name) in enumerate(zip(hetero, res_name)):
      if is_hetero:
        continue
      if residue_name in STANDARD_NUCLEOTIDES:
        residue_types[atom_index] = "NUCLEOTIDE"
      elif residue_name in AA_RECORDS:
        residue_types[atom_index] = "AMINO_ACID"
    return residue_types

  def _atom_mask(self, chains: Optional[List[str]] = None) -> np.ndarray:
    """Return a boolean mask for atoms in the selected chains.

    Raises:
      ValueError: If any requested chain ID is not present.
    """
    if chains is None:
      return np.ones(len(self), dtype=bool)
    selected_chains = [str(chain_id) for chain_id in chains]
    missing_chains = sorted(set(selected_chains) - set(self.chain_ids()))
    if missing_chains:
      raise ValueError(f'Chain(s) {", ".join(missing_chains)} were not found in the structure.')
    return np.isin(self.atom_annotations["chain_id"], selected_chains)

  def _coord_matrix(self, atom_mask: Optional[np.ndarray] = None) -> np.ndarray:
    """Return atom coordinates as an ``(n_atoms, 3)`` float32 matrix."""
    if atom_mask is None:
      atom_mask = np.ones(len(self), dtype=bool)
    return np.column_stack((self.atoms["x"][atom_mask], self.atoms["y"][atom_mask], self.atoms["z"][atom_mask])).astype(np.float32, copy=False)

  def _atom_masses(self, atom_mask: Optional[np.ndarray] = None) -> np.ndarray:
    """Return per-atom masses derived from element symbols.

    Raises:
      ValueError: If any selected atom's stripped element symbol is not in
        ``ATOMIC_MASSES``. A warning is logged before raising.
    """
    if atom_mask is None:
      atom_mask = np.ones(len(self), dtype=bool)
    selected_indices = np.flatnonzero(atom_mask)
    masses = np.zeros(len(selected_indices), dtype=np.float32)
    unknown_elements = set()
    selected_elements = self.atom_annotations["element"][selected_indices].tolist()
    for index, raw_element in enumerate(selected_elements):
      element = str(raw_element).strip()
      if element not in ATOMIC_MASSES:
        unknown_elements.add(element or "<blank>")
        continue
      masses[index] = float(ATOMIC_MASSES[element])
    if unknown_elements:
      message = "Unknown element mass for atom selection: " + ", ".join(sorted(unknown_elements)) + ". This is likely an error in the input structure."
      logger.warning(message)
      raise ValueError(message)
    return masses

  def _atom_view(self, atom_index: int) -> "Atom":
    """Create an immutable :class:`Atom` view for one atom index.

    Non-mandatory annotation columns are exposed through a read-only
    ``annotations`` mapping.
    """
    coord = self.atoms[atom_index]
    extra_annotations = {
      name: self._annotation_value(name, atom_index) for name in self._dtype_atom_annotations.names if name not in self._MANDATORY_ANNOTATIONS
    }
    return Atom(
      x=float(coord["x"]),
      y=float(coord["y"]),
      z=float(coord["z"]),
      chain_id=self._annotation_value("chain_id", atom_index),
      res_id=self._annotation_value("res_id", atom_index),
      ins_code=self._annotation_value("ins_code", atom_index),
      res_name=self._annotation_value("res_name", atom_index),
      hetero=self._annotation_value("hetero", atom_index),
      atom_name=self._annotation_value("atom_name", atom_index),
      element=self._annotation_value("element", atom_index),
      annotations=MappingProxyType(extra_annotations),
    )
[docs] @dataclass(frozen=True, slots=True) class Atom: """Immutable atom-level hierarchy view.""" x: float y: float z: float chain_id: str res_id: int ins_code: str res_name: str hetero: bool atom_name: str element: str annotations: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) @property def coord(self) -> np.ndarray: """Return the atom coordinates as a length-3 NumPy array.""" return np.array([self.x, self.y, self.z], dtype=np.float32)
@dataclass(frozen=True, slots=True)
class Residue:
  """Read-only view of one residue in a single-model structure.

  A residue is the set of atoms sharing a chain identifier, residue number,
  insertion code, residue name, and hetero flag. The view is meant for
  traversal and analysis, not in-place editing.

  Identity (equality and hashing) is defined by :meth:`key` only. The atoms
  and atom indices carried by the view are ignored.

  Attributes:
    chain_id: Chain identifier containing the residue.
    res_id: Residue sequence number.
    ins_code: PDB insertion code for the residue.
    res_name: Residue name / CCD code.
    hetero: ``True`` for heterogens and ``False`` for polymer ``ATOM`` records.
  """

  chain_id: str
  res_id: int
  ins_code: str
  res_name: str
  hetero: bool
  _atoms: Tuple[Atom, ...] = field(repr=False)
  _atom_indices: Tuple[int, ...] = field(repr=False)

  def atoms(self) -> List[Atom]:
    """List this residue's atoms.

    Returns:
      A new list of immutable :class:`Atom` views, in atom-table order.
    """
    return [*self._atoms]

  def atom_indices(self) -> List[int]:
    """List the atom-table positions of this residue's atoms.

    Returns:
      A new list of integer atom indices, in atom-table order.
    """
    return [*self._atom_indices]

  def key(self) -> Tuple[str, int, str, str, bool]:
    """Build the identity tuple of this residue.

    Use it to track residues in dictionaries or sets outside the view itself.

    Returns:
      ``(chain_id, res_id, ins_code, res_name, hetero)``
    """
    identity = (self.chain_id, self.res_id, self.ins_code, self.res_name, self.hetero)
    return identity

  def __hash__(self):
    """Hash the identity tuple returned by :meth:`key`."""
    return hash(self.key())

  def __eq__(self, other):
    """Treat two residue views as equal when their identity tuples match."""
    if isinstance(other, Residue):
      return self.key() == other.key()
    return NotImplemented
@dataclass(frozen=True, slots=True)
class Chain:
  """Immutable chain-level hierarchy view.

  A :class:`Chain` is a read-only hierarchy view over the residues associated
  with one chain identifier in a single :class:`Structure`. It provides
  chain-level traversal plus convenience helpers for sequence extraction and
  simple residue-number gap detection.

  Attributes:
    chain_id: Chain identifier represented by this view.
  """

  chain_id: str
  _residues: Tuple[Residue, ...] = field(repr=False)

  def __iter__(self) -> Iterator[Residue]:
    """Iterate over residues in residue order."""
    return iter(self._residues)

  def residues(self) -> List[Residue]:
    """Return the residues that belong to this chain.

    Returns:
      List of immutable :class:`Residue` views in residue order.
    """
    return list(self._residues)

  def __getitem__(self, res_id: int) -> Residue:
    """Return a residue view by residue ID, not by positional index.

    Parameters:
      res_id: Residue sequence number to retrieve.

    Returns:
      The first :class:`Residue` in this chain with the requested residue ID.

    Raises:
      TypeError: If ``res_id`` is not an integer residue ID.
      KeyError: If no residue with the requested ID is present in the chain.

    Notes:
      This method looks up residues by their residue ID rather than by list
      position. If multiple residues share the same residue ID, such as inserted
      residues distinguished by insertion codes, the first matching residue is
      returned and a warning is emitted.
    """
    if not isinstance(res_id, (int, np.integer)):
      raise TypeError("Chain indices must be residue IDs as integers.")
    matches = [residue for residue in self._residues if residue.res_id == int(res_id)]
    if not matches:
      raise KeyError(f'Residue ID {res_id} was not found in chain "{self.chain_id}".')
    if len(matches) > 1:
      logger.warning(
        'Chain "%s" contains multiple residues with residue ID %d; returning the first match.',
        self.chain_id,
        int(res_id),
      )
    return matches[0]

  def sequence(
    self,
    polymer_type: Literal["auto", "protein", "dna", "rna", "nucleotide"] = "auto",
    include_modifications: bool = False,
    modification_mode: Literal["inline", "parent"] = "inline",
    on_unknown_modified: Literal["raise", "unknown"] = "raise",
  ) -> str:
    """Return the polymer sequence for this chain.

    Protein, DNA, and RNA sequences are supported. Small molecules and other
    non-polymer residues in the chain are ignored. Modified residues can be
    skipped, emitted inline as ``(CCD)``, or mapped to their parent sequence
    code when available.

    Parameters:
      polymer_type: Polymer family to extract. ``"auto"`` infers the family
        from the chain contents. ``"nucleotide"`` accepts either DNA or RNA,
        but raises if both are present.
      include_modifications: Whether modified residues should contribute to the
        sequence. If ``False``, modified residues are skipped entirely.
      modification_mode: How included modifications are emitted. ``"inline"``
        inserts ``(CCD)`` tokens, while ``"parent"`` uses the inferred parent
        residue code.
      on_unknown_modified: Behavior when ``modification_mode="parent"`` is
        requested but no parent code can be inferred. ``"raise"`` raises a
        :class:`ValueError`; ``"unknown"`` inserts ``"X"``.

    Returns:
      Sequence string for the selected polymer family. Returns an empty string
      if the chain contains no residues from the requested polymer family.

    Raises:
      ValueError: If an argument value is invalid, if the chain mixes polymer
        families in a way that conflicts with ``polymer_type``, or if an unknown
        modified residue cannot be mapped in ``"parent"`` mode.
    """
    if polymer_type not in {"auto", "protein", "dna", "rna", "nucleotide"}:
      raise ValueError('polymer_type must be one of "auto", "protein", "dna", "rna", or "nucleotide".')
    if modification_mode not in {"inline", "parent"}:
      raise ValueError('modification_mode must be either "inline" or "parent".')
    if on_unknown_modified not in {"raise", "unknown"}:
      raise ValueError('on_unknown_modified must be either "raise" or "unknown".')

    # "nucleotide" is a family filter, not a concrete polymer type. Like "auto",
    # it starts undetected and locks onto the first DNA or RNA residue. Seeding
    # it with "nucleotide" would accept any DNA/RNA mixture, contradicting the
    # documented "raises if both are present" behavior.
    detected_polymer_type = None if polymer_type in {"auto", "nucleotide"} else polymer_type
    sequence_parts = []
    for residue in self._residues:
      residue_polymer_type = _classify_polymer_residue(residue)
      if residue_polymer_type is None:
        continue

      # Detection locks onto the first polymer family encountered and then
      # rejects incompatible mixtures later in the chain.
      if detected_polymer_type is None:
        detected_polymer_type = residue_polymer_type
      elif not _polymer_types_compatible(detected_polymer_type, residue_polymer_type):
        raise ValueError(f'Chain "{self.chain_id}" mixes polymer residue types; found both "{detected_polymer_type}" and "{residue_polymer_type}".')
      if polymer_type != "auto" and not _polymer_types_compatible(polymer_type, residue_polymer_type):
        raise ValueError(
          f'Chain "{self.chain_id}" contains "{residue_polymer_type}" residues, which are incompatible with polymer_type="{polymer_type}".'
        )

      residue_name = residue.res_name.strip().upper()
      if residue_polymer_type == "protein":
        residue_token = self._protein_token(residue_name, include_modifications, modification_mode, on_unknown_modified)
      else:
        residue_token = self._nucleotide_token(residue_name, residue_polymer_type, include_modifications, modification_mode, on_unknown_modified)
      if residue_token is not None:
        sequence_parts.append(residue_token)
    return "".join(sequence_parts)

  @staticmethod
  def _protein_token(
    residue_name: str,
    include_modifications: bool,
    modification_mode: str,
    on_unknown_modified: str,
  ) -> Optional[str]:
    """Return the sequence token for one protein residue, or ``None`` to skip it."""
    # Standard amino acids map directly to one-letter codes. Modified amino
    # acids are skipped, emitted inline as CCD tokens, or mapped to their
    # declared parent residue when available.
    residue_record = AA_RECORDS[residue_name]
    if residue_record.code is not None:
      return residue_record.code
    if not include_modifications:
      return None
    if modification_mode == "inline":
      return f"({residue_record.abr})"
    parent_code = None
    if residue_record.standard_equiv_abr is not None:
      parent_record = AA_RECORDS.get(residue_record.standard_equiv_abr)
      if parent_record is not None:
        parent_code = parent_record.code
    if parent_code is None:
      if on_unknown_modified == "unknown":
        return "X"
      raise ValueError(f'Could not infer a parent sequence code for modified residue "{residue_name}".')
    return parent_code

  @staticmethod
  def _nucleotide_token(
    residue_name: str,
    residue_polymer_type: str,
    include_modifications: bool,
    modification_mode: str,
    on_unknown_modified: str,
  ) -> Optional[str]:
    """Return the sequence token for one DNA/RNA residue, or ``None`` to skip it."""
    if residue_name in NUC_RNA_CODES:
      return residue_name
    if residue_name in NUC_DNA_CODES:
      # The second character of the DNA code is used as the base letter
      # (presumably codes look like "DA" -> "A").
      return residue_name[1]
    if not include_modifications:
      return None
    if modification_mode == "inline":
      return f"({residue_name})"
    # Heuristic parent inference: the last canonical base letter appearing in
    # the modified residue's name.
    allowed_codes = {"A", "C", "G", "T"} if residue_polymer_type == "dna" else {"A", "C", "G", "U"}
    parent_code = next((char for char in reversed(residue_name) if char in allowed_codes), None)
    if parent_code is None:
      if on_unknown_modified == "unknown":
        return "X"
      raise ValueError(f'Could not infer a parent sequence code for modified residue "{residue_name}".')
    return parent_code

  def missing_residue_ids(self) -> List[int]:
    """Return missing residue numbers inferred from gaps in the chain.

    Hetero residues are ignored so ligand or solvent numbering does not create
    artificial gaps in the polymer residue sequence.

    Returns:
      Sorted list of integer residue IDs that are absent between observed
      non-hetero residue numbers.
    """
    residue_ids = sorted({residue.res_id for residue in self._residues if not residue.hetero})
    missing_ids: List[int] = []
    # IDs are unique and sorted, so each neighbour pair contributes the
    # (possibly empty) open interval between them.
    for current_residue_id, next_residue_id in zip(residue_ids, residue_ids[1:]):
      missing_ids.extend(range(current_residue_id + 1, next_residue_id))
    return missing_ids
def _validate_structure_model(model: Structure):
  """Validate that an object is a structurally consistent ``Structure``.

  Parameters:
    model: Candidate structure to validate.

  Raises:
    TypeError: If the input is not a :class:`Structure`.
    ValueError: If ``atoms`` and ``atom_annotations`` differ in length, or if the
      ``atoms`` dtype fields are not exactly ``("x", "y", "z")`` in that order.
  """
  if not isinstance(model, Structure):
    raise TypeError(f"Expected a Structure instance, found {type(model).__name__}.")
  if len(model.atoms) != len(model.atom_annotations):
    raise ValueError("Structure atoms and atom_annotations must have the same length.")
  # The field-name tuple is compared exactly, so extra fields or a different
  # field order are rejected as well.
  if model.atoms.dtype.names != ("x", "y", "z"):
    raise ValueError('Structure atoms dtype must contain the coordinate fields "x", "y", and "z".')


def _structure_chain_ids(structure: Structure) -> List[str]:
  """Return chain identifiers from a structure in hierarchy order.

  Empty chain identifiers are rendered as ``"<blank>"`` so they stay visible in
  compact summaries.
  """
  return [chain.chain_id if chain.chain_id else "<blank>" for chain in structure.chains()]


def _model_position_from_id(model_ids: List[int], model_id: int) -> int:
  """Return the positional index for a model identifier.

  ``model_id`` is coerced with :func:`int` before the lookup. If several entries
  share the same identifier, the position of the first one is returned.

  Parameters:
    model_ids: Model identifiers in stored order.
    model_id: Identifier to look up.

  Returns:
    Zero-based position of ``model_id`` within ``model_ids``.

  Raises:
    KeyError: If the requested model identifier is not present, including when
      :func:`int` rejects ``model_id`` with a ``ValueError`` (e.g. a
      non-numeric string).
  """
  try:
    return model_ids.index(int(model_id))
  except ValueError:
    # Suppress the internal ValueError (from ``int`` or ``list.index``) so the
    # traceback shows only the lookup failure that callers are meant to handle.
    raise KeyError(f"Model ID {model_id} was not found.") from None


def _polymer_types_compatible(requested_polymer_type: str, residue_polymer_type: str) -> bool:
  """Return ``True`` if a residue polymer family matches a requested family.

  ``"nucleotide"`` acts as a wildcard that matches both ``"dna"`` and ``"rna"``.
  Any other requested type must equal ``residue_polymer_type`` exactly.
  """
  if requested_polymer_type == "nucleotide":
    return residue_polymer_type in {"dna", "rna"}
  return requested_polymer_type == residue_polymer_type


# Upper-cased backbone atom names, built once at import time instead of on every
# ``_classify_polymer_residue`` call.
_BACKBONE_ATOMS_RNA_UPPER = frozenset(atom_name.upper() for atom_name in BACKBONE_ATOMS_RNA)
_BACKBONE_ATOMS_DNA_UPPER = frozenset(atom_name.upper() for atom_name in BACKBONE_ATOMS_DNA)


def _classify_polymer_residue(residue: Residue) -> Optional[str]:
  """Classify a residue as protein, DNA, RNA, or non-polymer.

  The stripped, upper-cased residue name is checked first against the amino-acid
  records, then the DNA and RNA residue codes. Names that match none of these
  fall back to an atom-name heuristic:

  - If an ``O2'`` atom is present and at least three RNA backbone atom names
    match, the residue is RNA.
  - Otherwise, at least three matching DNA backbone atom names make it DNA.

  Parameters:
    residue: Residue to classify.

  Returns:
    ``"protein"``, ``"dna"``, ``"rna"``, or ``None`` if no rule matched.
  """
  residue_name = residue.res_name.strip().upper()
  if residue_name in AA_RECORDS:
    return "protein"
  if residue_name in NUC_DNA_CODES:
    return "dna"
  if residue_name in NUC_RNA_CODES:
    return "rna"
  atom_names = {atom.atom_name.strip().upper() for atom in residue._atoms}
  # RNA is only considered when the ribose 2'-oxygen (``O2'``) is present.
  # Otherwise, or when too few RNA backbone atoms match, fall through to DNA.
  if "O2'" in atom_names and len(atom_names & _BACKBONE_ATOMS_RNA_UPPER) >= 3:
    return "rna"
  if len(atom_names & _BACKBONE_ATOMS_DNA_UPPER) >= 3:
    return "dna"
  return None
class StructureEnsemble:
  """Ordered collection of independent ``Structure`` models.

  Unlike :class:`StructureStack`, models in an ensemble do not need to have the
  same atoms, annotations, or bonds. Models are stored by reference (not
  copied). Integer indexing looks models up by ``model_id`` rather than by
  position.

  Parameters:
    models: Optional initial list of models.
    model_ids: Optional identifiers corresponding to ``models``.
    metadata: Optional ensemble-level metadata dictionary (shallow-copied).
  """

  def __init__(
    self,
    models: Optional[List[Structure]] = None,
    *,
    model_ids: Optional[List[int]] = None,
    metadata: Optional[Mapping[str, Any]] = None,
  ):
    """Initialize an ordered collection of independent models.

    Raises:
      ValueError: If ``model_ids`` is given with a different length than
        ``models``, or if a model fails validation.
      TypeError: If an entry of ``models`` is not a :class:`Structure`.
    """
    self.metadata: Dict[str, Any] = dict(metadata or {})
    self._models: List[Structure] = []
    self.model_ids: List[int] = []
    models = models or []
    if model_ids is not None and len(model_ids) != len(models):
      raise ValueError("model_ids must have the same length as models.")
    for index, model in enumerate(models):
      model_id = None if model_ids is None else model_ids[index]
      self.append(model, model_id=model_id)

  def __len__(self) -> int:
    """Return the number of models in the ensemble."""
    return len(self._models)

  def __repr__(self) -> str:
    """Return a compact string summary of the ensemble.

    Chain identifiers are de-duplicated across models in first-seen order. The
    atom count is summed over all models.
    """
    seen = set()
    chain_ids = []
    for model in self._models:
      for chain_id in _structure_chain_ids(model):
        if chain_id in seen:
          continue
        seen.add(chain_id)
        chain_ids.append(chain_id)
    atom_count = sum(len(model) for model in self._models)
    return f"<Structure Ensemble: Models={len(self)} Chains=[{', '.join(chain_ids)}] Atoms={atom_count}>"

  def to_dataframe(self) -> pd.DataFrame:
    """Export the ensemble as a pandas dataframe with a ``model`` column.

    This dataframe is derived on demand from the current models and is never
    cached on the ensemble. An empty ensemble yields an empty dataframe that
    still carries the full column layout.
    """
    frames = []
    for model_id, model in zip(self.model_ids, self._models):
      frame = model.to_dataframe()
      frame.insert(0, "model", model_id)
      frames.append(frame)
    if not frames:
      return pd.DataFrame(columns=("model",) + _STRUCTURE_DATAFRAME_COLUMNS)
    return pd.concat(frames, ignore_index=True)

  def __iter__(self) -> Iterator[Structure]:
    """Iterate over the stored models in order."""
    return iter(self._models)

  def __getitem__(self, index):
    """Return a model by model ID or a sliced sub-ensemble by position.

    Integer access uses ``model_id`` lookup rather than positional indexing,
    so ``ensemble[5]`` returns the model whose ID is ``5``. Slice access keeps
    normal positional semantics to preserve standard Python iteration and
    slicing behavior. The sub-ensemble shares model objects with this one.

    Raises:
      KeyError: If an integer model ID is requested but not present.
    """
    if isinstance(index, slice):
      return StructureEnsemble(self._models[index], model_ids=self.model_ids[index], metadata=self.metadata)
    model_position = _model_position_from_id(self.model_ids, index)
    return self._models[model_position]

  def append(self, model: Structure, *, model_id: Optional[int] = None):
    """Append a validated model to the ensemble.

    The model is stored by reference, and its ``metadata["model_id"]`` entry is
    overwritten with the assigned identifier.

    Parameters:
      model: Model to append.
      model_id: Optional model identifier. Defaults to one greater than the
        largest identifier currently in the ensemble (``1`` for an empty
        ensemble). Automatically assigned IDs therefore never collide with
        existing ones, even after :meth:`remove_model`.

    Raises:
      TypeError: If ``model`` is not a :class:`Structure`.
      ValueError: If ``model`` has inconsistent coordinate/annotation tables.
    """
    _validate_structure_model(model)
    # ``len(self.model_ids) + 1`` would reuse an ID after a removal (IDs [1, 2],
    # remove 1, append -> 2 again), making ID-based lookup ambiguous.
    assigned_model_id = max(self.model_ids, default=0) + 1 if model_id is None else int(model_id)
    self._models.append(model)
    self.model_ids.append(assigned_model_id)
    model.metadata["model_id"] = assigned_model_id

  def remove_model(self, model_id: int) -> Structure:
    """Remove and return a model by model ID.

    Parameters:
      model_id: Model identifier to remove.

    Returns:
      The removed :class:`Structure`.

    Raises:
      KeyError: If the requested model ID is not present.
    """
    model_position = _model_position_from_id(self.model_ids, model_id)
    self.model_ids.pop(model_position)
    return self._models.pop(model_position)

  def renumber(self, start: int = 1):
    """Renumber model identifiers in-place.

    Each model's ``metadata["model_id"]`` is updated to match its new ID.

    Parameters:
      start: Starting model ID. Defaults to ``1``.
    """
    self.model_ids = list(range(int(start), int(start) + len(self)))
    for model_id, model in zip(self.model_ids, self._models):
      model.metadata["model_id"] = model_id

  def models(self) -> List[Structure]:
    """Return the models as a shallow copied list."""
    return list(self._models)

  def first(self) -> Structure:
    """Return the first model in the ensemble.

    Returns:
      The first :class:`Structure` in stored order.

    Raises:
      IndexError: If the ensemble is empty.
    """
    if not self._models:
      raise IndexError("Cannot fetch the first model from an empty StructureEnsemble.")
    return self._models[0]

  def to_stack(self) -> "StructureStack":
    """Convert the ensemble into a ``StructureStack``.

    Raises:
      ValueError: If the models are not stack-compatible.
    """
    return StructureStack(self._models, model_ids=self.model_ids, metadata=self.metadata)
class StructureStack:
  """Shared-annotation, shared-bond multi-model fast path.

  All models in a stack must share the same atom ordering, per-atom
  annotations, and bonds. Only the coordinates vary between models.

  Storage layout:

  - Coordinates for all models live in :attr:`coord`, a ``float32`` array of
    shape ``(n_models, n_atoms, 3)``.
  - :attr:`atom_annotations` and :attr:`bonds` are stored once for all models.

  Individual models are materialized on demand as independent
  :class:`Structure` copies.

  Parameters:
    models: Optional initial list of stack-compatible models.
    model_ids: Optional identifiers corresponding to ``models``.
    metadata: Optional stack-level metadata dictionary (shallow-copied).
  """

  def __init__(
    self,
    models: Optional[List[Structure]] = None,
    *,
    model_ids: Optional[List[int]] = None,
    metadata: Optional[Mapping[str, Any]] = None,
  ):
    """Initialize an empty or pre-populated stack of compatible models.

    Raises:
      ValueError: If ``model_ids`` is given with a different length than
        ``models``, or if the models are not stack-compatible.
    """
    self.metadata: Dict[str, Any] = dict(metadata or {})
    self.model_ids: List[int] = []
    # Use an empty Structure instance to seed the default dtypes for an empty
    # stack before the first model is added.
    template = Structure(remove_annotations=False)
    self._dtype_atoms = template._dtype_atoms
    self._dtype_atom_annotations = template._dtype_atom_annotations
    self._dtype_bond = template._dtype_bond
    self.coord = np.zeros((0, 0, 3), dtype=np.float32)
    self.atom_annotations = np.zeros(0, dtype=self._dtype_atom_annotations)
    self.bonds = np.zeros(0, dtype=self._dtype_bond)
    models = models or []
    if model_ids is not None and len(model_ids) != len(models):
      raise ValueError("model_ids must have the same length as models.")
    for index, model in enumerate(models):
      model_id = None if model_ids is None else model_ids[index]
      self.append(model, model_id=model_id)

  def __len__(self) -> int:
    """Return the number of models in the stack."""
    return self.coord.shape[0]

  def __repr__(self) -> str:
    """Return a compact string summary of the stack."""
    chain_ids = [] if len(self) == 0 else _structure_chain_ids(self._model_to_structure(0))
    atom_count = len(self) * self.atom_count
    return f"<Structure Stack: Models={len(self)} Chains=[{', '.join(chain_ids)}] Atoms={atom_count}>"

  def to_dataframe(self) -> pd.DataFrame:
    """Export the stack as a pandas dataframe with a ``model`` column.

    This dataframe is derived on demand from the current stack contents and is
    never cached on the stack. An empty stack yields an empty dataframe that
    still carries the full column layout.
    """
    frames = []
    for model_index, model_id in enumerate(self.model_ids):
      frame = self._model_to_structure(model_index).to_dataframe()
      frame.insert(0, "model", model_id)
      frames.append(frame)
    if not frames:
      return pd.DataFrame(columns=("model",) + _STRUCTURE_DATAFRAME_COLUMNS)
    return pd.concat(frames, ignore_index=True)

  def __iter__(self) -> Iterator[Structure]:
    """Iterate over the stack as materialized ``Structure`` models."""
    for model_index in range(len(self)):
      yield self._model_to_structure(model_index)

  def first(self) -> Structure:
    """Return the first model in the stack.

    Returns:
      The first :class:`Structure` in stored order.

    Raises:
      IndexError: If the stack is empty.
    """
    if len(self) == 0:
      raise IndexError("Cannot fetch the first model from an empty StructureStack.")
    return self._model_to_structure(0)

  def __getitem__(self, index):
    """Return a materialized model by model ID or a sliced sub-stack by position.

    Integer access uses ``model_id`` lookup rather than positional indexing,
    so ``stack[5]`` returns the model whose ID is ``5``. Slice access keeps
    normal positional semantics to preserve standard Python slicing behavior.
    The sub-stack owns copies of all arrays and keeps this stack's atom dtype.

    Raises:
      KeyError: If an integer model ID is requested but not present.
    """
    if isinstance(index, slice):
      # ``_from_parts`` copies every array it receives, so no extra copies are
      # needed here for the sub-stack to be independent of this one.
      return StructureStack._from_parts(
        self.coord[index],
        self.atom_annotations,
        self.bonds,
        model_ids=self.model_ids[index],
        metadata=self.metadata,
        dtype_atoms=self._dtype_atoms,
      )
    model_position = _model_position_from_id(self.model_ids, index)
    return self._model_to_structure(model_position)

  @property
  def atom_count(self) -> int:
    """Return the number of shared atoms per model."""
    return self.coord.shape[1]

  def append(self, model: Structure, *, model_id: Optional[int] = None):
    """Append a stack-compatible model.

    Parameters:
      model: Model to append.
      model_id: Optional model identifier. Defaults to one greater than the
        largest identifier currently in the stack (``1`` for an empty stack).
        Automatically assigned IDs therefore never collide with existing ones,
        even after :meth:`remove_model`.

    Raises:
      TypeError: If ``model`` is not a :class:`Structure`.
      ValueError: If the candidate model is not compatible with the existing
        stack, or if ``model_id`` cannot be converted to ``int``.
    """
    _validate_structure_model(model)
    # Resolve the identifier before mutating any state, so an unconvertible
    # ``model_id`` cannot leave ``coord`` and ``model_ids`` out of sync.
    assigned_model_id = max(self.model_ids, default=0) + 1 if model_id is None else int(model_id)
    coord = self._coord_matrix_from_structure(model)
    if len(self) == 0:
      # The first model defines the shared annotation and bond schema for the
      # entire stack.
      self._dtype_atoms = model._dtype_atoms
      self._dtype_atom_annotations = model._dtype_atom_annotations
      self._dtype_bond = model._dtype_bond
      self.coord = coord[np.newaxis, ...]
      self.atom_annotations = np.array(model.atom_annotations, dtype=model.atom_annotations.dtype, copy=True)
      self.bonds = np.array(model.bonds, dtype=model.bonds.dtype, copy=True)
    else:
      reference = self._model_to_structure(0)
      self._ensure_stack_compatible(reference, model)
      self.coord = np.concatenate((self.coord, coord[np.newaxis, ...]), axis=0)
    self.model_ids.append(assigned_model_id)

  def remove_model(self, model_id: int) -> Structure:
    """Remove and return a model by model ID.

    Parameters:
      model_id: Model identifier to remove.

    Returns:
      The removed :class:`Structure`.

    Raises:
      KeyError: If the requested model ID is not present.
    """
    model_position = _model_position_from_id(self.model_ids, model_id)
    removed_model = self._model_to_structure(model_position)
    self.coord = np.delete(self.coord, model_position, axis=0)
    self.model_ids.pop(model_position)
    return removed_model

  def renumber(self, start: int = 1):
    """Renumber model identifiers in-place.

    Parameters:
      start: Starting model ID. Defaults to ``1``.
    """
    self.model_ids = list(range(int(start), int(start) + len(self)))

  def models(self) -> List[Structure]:
    """Materialize and return all models in the stack."""
    return [self._model_to_structure(index) for index in range(len(self))]

  def to_ensemble(self) -> StructureEnsemble:
    """Convert the stack into an independent ``StructureEnsemble``."""
    return StructureEnsemble(self.models(), model_ids=self.model_ids, metadata=self.metadata)

  @classmethod
  def from_ensemble(cls, ensemble: StructureEnsemble) -> "StructureStack":
    """Build a stack from an ensemble of compatible models.

    Raises:
      ValueError: If the ensemble's models are not stack-compatible.
    """
    return cls(ensemble.models(), model_ids=ensemble.model_ids, metadata=ensemble.metadata)

  @classmethod
  def _from_parts(
    cls,
    coord: np.ndarray,
    atom_annotations: np.ndarray,
    bonds: np.ndarray,
    *,
    model_ids: Optional[List[int]] = None,
    metadata: Optional[Mapping[str, Any]] = None,
    dtype_atoms: Optional[np.dtype] = None,
  ) -> "StructureStack":
    """Construct a stack directly from shared coordinates and annotations.

    All arrays are copied, so the new stack never aliases the inputs.

    Parameters:
      coord: Coordinates of shape ``(n_models, n_atoms, 3)``; cast to ``float32``.
      atom_annotations: Shared per-atom annotations of length ``n_atoms``.
      bonds: Shared bond table.
      model_ids: Optional identifiers; defaults to ``1..n_models``.
      metadata: Optional stack-level metadata dictionary.
      dtype_atoms: Optional structured dtype used when materializing model
        coordinates. Defaults to the template ``Structure`` dtype.

    Raises:
      ValueError: If ``coord`` has the wrong shape, or if ``atom_annotations``
        or ``model_ids`` do not match its dimensions.
    """
    coord = np.asarray(coord, dtype=np.float32)
    if coord.ndim != 3 or coord.shape[2] != 3:
      raise ValueError("StructureStack coordinates must have shape (n_models, n_atoms, 3).")
    if len(atom_annotations) != coord.shape[1]:
      raise ValueError("Shared atom annotations must match the atom dimension of coord.")
    stack = cls(metadata=metadata)
    stack.coord = coord.copy()
    if dtype_atoms is not None:
      stack._dtype_atoms = np.dtype(dtype_atoms)
    stack._dtype_atom_annotations = atom_annotations.dtype
    stack.atom_annotations = np.array(atom_annotations, dtype=atom_annotations.dtype, copy=True)
    stack._dtype_bond = bonds.dtype
    stack.bonds = np.array(bonds, dtype=bonds.dtype, copy=True)
    stack.model_ids = list(range(1, coord.shape[0] + 1)) if model_ids is None else [int(x) for x in model_ids]
    if len(stack.model_ids) != coord.shape[0]:
      raise ValueError("model_ids must match the number of models in coord.")
    return stack

  def _model_to_structure(self, model_index: int) -> Structure:
    """Materialize a single model from the shared stack representation.

    The stack metadata is copied into the model, with ``model_id`` set to the
    model's identifier.
    """
    atoms = self._atoms_from_coord_matrix(self.coord[model_index], self._dtype_atoms)
    metadata = dict(self.metadata)
    metadata["model_id"] = self.model_ids[model_index]
    return self._structure_from_parts(atoms, self.atom_annotations, self.bonds, metadata=metadata)

  @staticmethod
  def _coord_matrix_from_structure(model: Structure) -> np.ndarray:
    """Extract a ``float32`` ``(n_atoms, 3)`` coordinate matrix from a structure."""
    if len(model.atoms) == 0:
      return np.zeros((0, 3), dtype=np.float32)
    return np.column_stack((model.atoms["x"], model.atoms["y"], model.atoms["z"])).astype(np.float32, copy=False)

  @staticmethod
  def _atoms_from_coord_matrix(coord: np.ndarray, dtype: np.dtype) -> np.ndarray:
    """Create a coordinate structured array from an ``(n_atoms, 3)`` matrix.

    Raises:
      ValueError: If ``coord`` is not two-dimensional with three columns.
    """
    coord = np.asarray(coord, dtype=np.float32)
    if coord.ndim != 2 or coord.shape[1] != 3:
      raise ValueError("Coordinate matrix must have shape (n_atoms, 3).")
    atoms = np.empty(coord.shape[0], dtype=dtype)
    atoms["x"] = coord[:, 0]
    atoms["y"] = coord[:, 1]
    atoms["z"] = coord[:, 2]
    return atoms

  @staticmethod
  def _structure_from_parts(
    atoms: np.ndarray,
    atom_annotations: np.ndarray,
    bonds: np.ndarray,
    metadata: Optional[Mapping[str, Any]] = None,
  ) -> Structure:
    """Build an independent ``Structure`` from array components (all copied)."""
    model = Structure(remove_annotations=False)
    model._dtype_atoms = atoms.dtype
    model._dtype_atom_annotations = atom_annotations.dtype
    model._dtype_bond = bonds.dtype
    model.atoms = np.array(atoms, dtype=atoms.dtype, copy=True)
    model.atom_annotations = np.array(atom_annotations, dtype=atom_annotations.dtype, copy=True)
    model.bonds = np.array(bonds, dtype=bonds.dtype, copy=True)
    model.metadata = dict(metadata or {})
    return model

  @staticmethod
  def _ensure_stack_compatible(reference: Structure, candidate: Structure):
    """Validate that two structures can coexist in the same stack.

    Raises:
      ValueError: If atom counts, annotation schemas, annotation values, or
        bonds differ between ``reference`` and ``candidate``.
    """
    if len(reference) != len(candidate):
      raise ValueError("StructureStack requires each model to have the same number of atoms.")
    if reference._dtype_atom_annotations != candidate._dtype_atom_annotations:
      raise ValueError("StructureStack requires identical annotation schemas across all models.")
    if not np.array_equal(reference.atom_annotations, candidate.atom_annotations):
      raise ValueError("StructureStack requires identical atom annotations across all models.")
    if reference.bonds.dtype != candidate.bonds.dtype or not np.array_equal(reference.bonds, candidate.bonds):
      raise ValueError("StructureStack requires identical bonds across all models.")