Source code for neurosnap.io.pdb

"""Parser and writer for PDB files.

This module provides Neurosnap-native :func:`parse_pdb` and :func:`save_pdb`
helpers for reading and writing
:class:`~neurosnap.structure.structure.Structure`,
:class:`~neurosnap.structure.structure.StructureEnsemble`, and
:class:`~neurosnap.structure.structure.StructureStack` objects.
"""

import io
import pathlib
from collections import Counter
from dataclasses import field
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union

import numpy as np

from neurosnap._compat import compat_dataclass
from neurosnap.constants.chemistry import ATOMIC_MASSES
from neurosnap.log import logger
from neurosnap.structure.structure import Structure, StructureEnsemble, StructureStack

# Public API of this module; every other name is an internal helper.
__all__ = ["parse_pdb", "save_pdb"]

# Accepted values for ``parse_pdb(return_type=...)``.
ReturnType = Literal["ensemble", "stack", "auto"]
# Accepted values for ``parse_pdb(malformed_conect=...)``.
ConectErrorMode = Literal["strict", "warn", "ignore"]
# Atom identity: (chain_id, res_id, ins_code, res_name, hetero, atom_name).
AtomKey = Tuple[str, int, str, str, bool, str]
# Residue identity: (chain_id, res_id, ins_code, res_name, hetero).
ResidueKey = Tuple[str, int, str, str, bool]
# An atom site that carried an altloc identifier: (model_id, atom_key).
AltlocSite = Tuple[int, AtomKey]
# A parsed CONECT record: (source_serial, target_serials, line_number).
ConectRecord = Tuple[int, List[int], int]


@compat_dataclass(slots=True)
class _AtomRecord:
  """Internal representation of a parsed PDB atom record.

  One instance is produced per ``ATOM`` / ``HETATM`` line and handed to the
  per-model accumulator, which copies the values into its annotation lists.
  """

  # Cartesian coordinates as read from the record.
  x: float
  y: float
  z: float
  # Residue identity fields; blank source columns are stored as "".
  chain_id: str
  res_id: int
  ins_code: str
  res_name: str
  # True when the record type was HETATM, False for ATOM.
  hetero: bool
  atom_name: str
  # Element symbol, upper-cased; inferred from the atom name when the column is blank.
  element: str
  # Atom serial number from the record (the value CONECT records refer to).
  atom_id: int
  # Defaults to 0.0 when the column is blank.
  b_factor: float
  # Defaults to 1.0 when the column is blank.
  occupancy: float
  # Signed integer charge; 0 when the column is blank.
  charge: int
  # Alternate-location identifier; "" when the atom has none.
  altloc: str
  # (chain_id, res_id, ins_code, res_name, hetero, atom_name) -- deliberately
  # excludes ``altloc`` so alternate conformers share one key.
  atom_key: AtomKey


@compat_dataclass(slots=True)
class _ModelAccumulator:
  """Mutable builder that collects the atoms and bonds of one PDB model.

  Atoms are appended through :meth:`add_atom`, bonds through
  :meth:`add_directed_bond`, and :meth:`to_structure` turns the collected
  data into a :class:`Structure`.
  """

  model_id: int
  atoms: List[Tuple[float, float, float]] = field(default_factory=list)
  # One parallel list per annotation column, index-aligned with ``atoms``.
  annotations: Dict[str, List[object]] = field(
    default_factory=lambda: {
      column: []
      for column in (
        "chain_id",
        "res_id",
        "ins_code",
        "res_name",
        "hetero",
        "atom_name",
        "element",
        "atom_id",
        "b_factor",
        "occupancy",
        "charge",
        "sym_id",
      )
    }
  )
  # Atom serial -> row index, used to resolve CONECT records.
  serial_to_index: Dict[int, int] = field(default_factory=dict)
  # Row index of each altloc-bearing atom site.
  _atom_key_to_index: Dict[AtomKey, int] = field(default_factory=dict)
  # (occupancy, altloc, serial) of the conformer currently kept per site.
  _selected_altloc: Dict[AtomKey, Tuple[float, str, int]] = field(default_factory=dict)
  # Per-residue tally of atom names seen on altloc-free atoms.
  _residue_atom_name_counts: Dict[ResidueKey, Counter[str]] = field(default_factory=dict)
  # (source_index, target_index) -> number of times the bond was declared.
  directed_bonds: Counter[Tuple[int, int]] = field(default_factory=Counter)

  @staticmethod
  def _annotation_values(atom: _AtomRecord) -> Tuple[Tuple[str, object], ...]:
    """Return ``(column, value)`` pairs describing ``atom`` in column order."""
    return (
      ("chain_id", atom.chain_id),
      ("res_id", atom.res_id),
      ("ins_code", atom.ins_code),
      ("res_name", atom.res_name),
      ("hetero", atom.hetero),
      ("atom_name", atom.atom_name),
      ("element", atom.element),
      ("atom_id", atom.atom_id),
      ("b_factor", atom.b_factor),
      ("occupancy", atom.occupancy),
      ("charge", atom.charge),
      ("sym_id", ""),
    )

  def _append_atom(self, atom: _AtomRecord) -> int:
    """Append ``atom`` as a new row and return the row index it received."""
    row = len(self.atoms)
    self.atoms.append((atom.x, atom.y, atom.z))
    for column, value in self._annotation_values(atom):
      self.annotations[column].append(value)
    self.serial_to_index[atom.atom_id] = row
    return row

  def add_atom(self, atom: _AtomRecord):
    """Add an atom to the model, or let it replace a stored altloc conformer.

    Atoms without an altloc identifier are always appended; a repeated atom
    name inside the same residue is renamed first so both atoms survive.
    Atoms with an altloc identifier share one row per ``atom_key``: the first
    conformer seen creates the row and later ones overwrite it only when
    :meth:`_should_replace_atom` says so.
    """
    if not atom.altloc:
      self._append_atom(self._rename_duplicate_atom(atom))
      return

    site = atom.atom_key
    row = self._atom_key_to_index.get(site)
    if row is None:
      # First conformer seen for this site.
      self._atom_key_to_index[site] = self._append_atom(atom)
      self._selected_altloc[site] = (atom.occupancy, atom.altloc, atom.atom_id)
      return

    kept_occupancy, kept_altloc, kept_serial = self._selected_altloc[site]
    if not self._should_replace_atom(kept_occupancy, kept_altloc, atom.occupancy, atom.altloc):
      return

    # Overwrite the stored conformer in place so row indices stay stable.
    self.atoms[row] = (atom.x, atom.y, atom.z)
    for column, value in self._annotation_values(atom):
      self.annotations[column][row] = value
    # The displaced conformer's serial must stop resolving to this row.
    if self.serial_to_index.get(kept_serial) == row:
      del self.serial_to_index[kept_serial]
    self.serial_to_index[atom.atom_id] = row
    self._selected_altloc[site] = (atom.occupancy, atom.altloc, atom.atom_id)

  def _rename_duplicate_atom(self, atom: _AtomRecord) -> _AtomRecord:
    """Return ``atom`` unchanged, or a renamed copy if its name was already used.

    The first occurrence of a name within a residue is kept as is. Each later
    occurrence is logged and given a name derived by ``_unique_atom_name``.
    """
    residue = (atom.chain_id, atom.res_id, atom.ins_code, atom.res_name, atom.hetero)
    seen_names = self._residue_atom_name_counts.setdefault(residue, Counter())
    times_seen = seen_names[atom.atom_name]
    seen_names[atom.atom_name] += 1
    if times_seen == 0:
      return atom

    new_name = _unique_atom_name(atom.atom_name, times_seen)
    logger.warning(
      'Duplicate atom name "%s" found in residue "%s" %s%s%s%s; renaming it to "%s".',
      atom.atom_name,
      atom.res_name,
      atom.chain_id or "<blank chain>",
      atom.res_id,
      atom.ins_code or "",
      " (hetero)" if atom.hetero else "",
      new_name,
    )
    return _AtomRecord(
      x=atom.x,
      y=atom.y,
      z=atom.z,
      chain_id=atom.chain_id,
      res_id=atom.res_id,
      ins_code=atom.ins_code,
      res_name=atom.res_name,
      hetero=atom.hetero,
      atom_name=new_name,
      element=atom.element,
      atom_id=atom.atom_id,
      b_factor=atom.b_factor,
      occupancy=atom.occupancy,
      charge=atom.charge,
      altloc=atom.altloc,
      atom_key=residue + (new_name,),
    )

  def add_directed_bond(self, source_serial: int, target_serial: int) -> bool:
    """Count one ``source -> target`` bond declaration.

    Returns ``False`` without recording anything when either serial is unknown
    to this model or when both serials resolve to the same row.
    """
    source_row = self.serial_to_index.get(source_serial)
    target_row = self.serial_to_index.get(target_serial)
    if source_row is None or target_row is None:
      return False
    if source_row == target_row:
      return False

    self.directed_bonds[(source_row, target_row)] += 1
    return True

  def to_structure(self) -> Structure:
    """Build a :class:`Structure` from the accumulated atoms and bonds."""
    structure = Structure(remove_annotations=False)
    structure.metadata = {"model_id": self.model_id}

    if not self.atoms:
      structure.atoms = np.zeros(0, dtype=structure._dtype_atoms)
      structure.atom_annotations = np.zeros(0, dtype=structure._dtype_atom_annotations)
    else:
      annotation_dtype = structure._dtype_atom_annotations
      structure.atoms = np.array(self.atoms, dtype=structure._dtype_atoms)
      structure.atom_annotations = np.empty(len(self.atoms), dtype=annotation_dtype)
      for column, values in self.annotations.items():
        structure.atom_annotations[column] = np.array(values, dtype=annotation_dtype.fields[column][0])

    # ``CONECT`` records may list a bond from both of its atoms, so the bond
    # order of an undirected pair is the larger of its two directed counts.
    strongest: Dict[Tuple[int, int], int] = {}
    for (atom_i, atom_j), count in self.directed_bonds.items():
      pair = (min(atom_i, atom_j), max(atom_i, atom_j))
      strongest[pair] = max(strongest.get(pair, 0), count)

    bond_rows = [(atom_i, atom_j, order) for (atom_i, atom_j), order in sorted(strongest.items())]
    if bond_rows:
      structure.bonds = np.array(bond_rows, dtype=structure._dtype_bond)
    else:
      structure.bonds = np.zeros(0, dtype=structure._dtype_bond)

    structure._remove_empty_annotations()
    return structure

  @staticmethod
  def _should_replace_atom(prev_occupancy: float, prev_altloc: str, new_occupancy: float, new_altloc: str) -> bool:
    """Decide whether a newly seen conformer should displace the stored one.

    Higher occupancy wins. On a tie the stored conformer is kept, except when
    the stored one has an altloc identifier and the new one has none.
    """
    if new_occupancy > prev_occupancy:
      return True
    if new_occupancy < prev_occupancy:
      return False
    return bool(prev_altloc) and not new_altloc


def parse_pdb(
  pdb: Union[str, pathlib.Path, io.IOBase],
  return_type: ReturnType = "auto",
  malformed_conect: ConectErrorMode = "warn",
) -> Union[StructureEnsemble, StructureStack]:
  """Parse a PDB file into Neurosnap structure containers.

  Parsing follows the fixed-width PDB record layout used by BioPython's parser
  but builds Neurosnap :class:`Structure` entities. Parsed models are first
  collected into a :class:`StructureEnsemble` and are optionally converted
  into a :class:`StructureStack` at the end.

  HETATM records and ``CONECT`` records are parsed directly so ligands and
  custom covalent bonds are preserved more faithfully than in the old
  BioPython-backed path.

  Alternate locations are always ignored. When alternate locations are
  present, the parser keeps only the highest-occupancy conformer for each atom
  site and emits a :func:`logger.warning` so the user knows this happened.

  If a residue contains duplicate atom names without altloc identifiers, the
  parser preserves all such atoms by automatically renaming later duplicates
  to unique PDB-style atom names and emits a :func:`logger.warning` describing
  the rename. This avoids incorrectly collapsing distinct ligand atoms that
  happen to share generic names such as ``C`` or ``O``.

  Repeated ``CONECT`` records are interpreted as higher bond order using a
  directed-count collapse::

    bond_type = max(count(atom_i -> atom_j), count(atom_j -> atom_i))

  This means repeated records in one direction can encode double or triple
  bonds, while mirrored ``CONECT`` entries from both atoms do not artificially
  inflate bond order.

  Parameters:
    pdb: PDB filepath or open file handle.
    return_type: Output container type. ``"ensemble"`` always returns a
      :class:`StructureEnsemble`, ``"stack"`` requires stack-compatible
      models, and ``"auto"`` returns a :class:`StructureStack` when possible
      or falls back to a :class:`StructureEnsemble`.
    malformed_conect: How malformed ``CONECT`` records should be handled.
      ``"strict"`` raises immediately, ``"warn"`` logs a warning and skips the
      bad record, and ``"ignore"`` silently skips it.

  Returns:
    A :class:`StructureEnsemble` or :class:`StructureStack` depending on
    ``return_type`` and model compatibility.

  Raises:
    ValueError: If ``return_type`` or ``malformed_conect`` is not one of the
      accepted values, if the input contains no lines, or if parsing fails.
      With ``return_type="stack"`` any error raised by
      :meth:`StructureStack.from_ensemble` is propagated.
  """
  if return_type not in {"ensemble", "stack", "auto"}:
    raise ValueError('return_type must be one of "ensemble", "stack", or "auto".')
  if malformed_conect not in {"strict", "warn", "ignore"}:
    raise ValueError('malformed_conect must be one of "strict", "warn", or "ignore".')

  lines = _read_lines(pdb)
  if not lines:
    raise ValueError("Empty file.")

  # Filled in by the model parser with every (model_id, atom_key) that carried
  # an altloc identifier, so a single summary warning can be emitted here.
  altloc_sites: set[AltlocSite] = set()
  ensemble = _parse_pdb_models(lines, altloc_sites, malformed_conect=malformed_conect)
  ensemble.metadata["source_format"] = "pdb"

  if altloc_sites:
    logger.warning(
      "Ignoring alternate locations for %d atom site(s); using the highest-occupancy conformer for each.",
      len(altloc_sites),
    )

  if return_type == "ensemble":
    return ensemble
  if return_type == "stack":
    # Explicit request: let incompatibility errors reach the caller.
    return StructureStack.from_ensemble(ensemble)

  # "auto": prefer a stack, fall back to the ensemble when models are incompatible.
  try:
    return StructureStack.from_ensemble(ensemble)
  except ValueError:
    return ensemble
def save_pdb(structure: Union[Structure, StructureEnsemble, StructureStack], pdb: Union[str, pathlib.Path, io.IOBase]):
  """Save a Neurosnap structure container as a PDB file.

  Parameters:
    structure: Structure container to write.
    pdb: Output filepath or open file handle.

  Raises:
    TypeError: If ``structure`` is not a supported container type.
    ValueError: If a value cannot be represented in the fixed-width PDB
      fields (raised by the formatting helpers).

  Notes:
    Multi-model outputs are written using ``MODEL`` / ``ENDMDL`` records.
    ``CONECT`` records are written for single-model outputs and for
    multi-model outputs only when all models share identical bonds and atom
    serials so the topology can be represented unambiguously in PDB format.
    When they are omitted and any model has bonds, a warning is logged once.
  """
  models = _models_for_pdb_output(structure)
  multi_model = len(models) > 1

  shared_conect_lines: Optional[List[str]] = None
  if multi_model:
    shared_conect_lines = _shared_conect_lines(models)
    # Warn whenever bonds are being dropped, whichever model carries them;
    # checking only the first model would hide bonds present in later ones.
    if shared_conect_lines is None and any(len(model.bonds) > 0 for _, model in models):
      logger.warning("Omitting CONECT records for multi-model PDB output because the models do not share identical bonds and atom serials.")

  lines: List[str] = []
  for model_id, model in models:
    serials = _atom_serials_for_model(model)
    if multi_model:
      lines.append(_format_model_record(model_id))

    lines.extend(_atom_record_lines(model, serials))
    if len(model) > 0:
      lines.append("TER")

    if multi_model:
      lines.append("ENDMDL")
    else:
      lines.extend(_conect_lines_for_model(model, serials))

  # Shared topology is written once, after the last ENDMDL.
  if shared_conect_lines is not None:
    lines.extend(shared_conect_lines)

  lines.append("END")
  _write_pdb_lines(pdb, lines)
def _parse_pdb_models(
  lines: Iterable[str],
  altloc_sites: set[AltlocSite],
  malformed_conect: ConectErrorMode,
) -> StructureEnsemble:
  """Parse PDB coordinate records into a :class:`StructureEnsemble`.

  The parser accumulates each model independently and applies ``CONECT``
  records only after all atoms are known so serial-number lookups are
  complete.

  Parameters:
    lines: Lines of the PDB file, without line terminators.
    altloc_sites: Output set; receives ``(model_id, atom_key)`` for every atom
      record that carries an altloc identifier.
    malformed_conect: Policy for malformed ``CONECT`` records.

  Raises:
    ValueError: If no model or atom record is found, or a record is invalid.
  """
  pending_conect: List[ConectRecord] = []
  models: List[_ModelAccumulator] = []
  current_model: Optional[_ModelAccumulator] = None
  implicit_model_id = 1

  for line_number, raw_line in enumerate(lines, start=1):
    if not raw_line.strip():
      continue

    padded_line = raw_line.ljust(80)
    # Record names occupy columns 1-6 and are blank-padded; compare the
    # stripped name so the match does not depend on exact padding literals.
    record_type = padded_line[0:6].rstrip()

    if record_type == "MODEL":
      model_id = _parse_model_id(padded_line, line_number)
      current_model = _ModelAccumulator(model_id=model_id)
      models.append(current_model)
      implicit_model_id += 1
      continue

    if record_type == "ENDMDL":
      current_model = None
      continue

    if record_type in ("ATOM", "HETATM"):
      if current_model is None:
        # Coordinates outside MODEL/ENDMDL open an implicitly numbered model.
        current_model = _ModelAccumulator(model_id=implicit_model_id)
        models.append(current_model)
        implicit_model_id += 1

      atom = _parse_atom_record(padded_line, record_type, line_number)
      if atom.altloc:
        # Altlocs are dropped later inside ``add_atom()``, but the parser keeps
        # track of affected sites so it can emit one summary warning.
        altloc_sites.add((current_model.model_id, atom.atom_key))
      current_model.add_atom(atom)
      continue

    if record_type == "CONECT":
      conect = _parse_conect_record(padded_line, line_number, malformed_conect=malformed_conect)
      if conect is not None:
        pending_conect.append(conect)
      continue

    # TER and every other record type carry nothing this parser needs.

  if not models:
    raise ValueError("No models or atoms were found in the PDB file.")

  _apply_conect_records(models, pending_conect, malformed_conect=malformed_conect)

  ensemble = StructureEnsemble()
  for model in models:
    ensemble.append(model.to_structure(), model_id=model.model_id)
  return ensemble


def _models_for_pdb_output(
  structure: Union[Structure, StructureEnsemble, StructureStack],
) -> List[Tuple[int, Structure]]:
  """Return a normalized list of ``(model_id, model)`` pairs for writing.

  Raises:
    TypeError: If ``structure`` is not one of the supported container types.
  """
  if isinstance(structure, Structure):
    model_id = int(structure.metadata.get("model_id", 1))
    return [(model_id, structure)]
  if isinstance(structure, (StructureEnsemble, StructureStack)):
    return list(zip(structure.model_ids, structure.models()))
  raise TypeError(f"Unsupported structure type for PDB output: {type(structure).__name__}.")


def _shared_conect_lines(models: List[Tuple[int, Structure]]) -> Optional[List[str]]:
  """Return shared ``CONECT`` lines for compatible multi-model outputs.

  Returns ``None`` when any model differs from the first one in its bond table
  or in its atom serials; returns ``[]`` for an empty ``models`` list.
  """
  if not models:
    return []

  serial_maps = [_atom_serials_for_model(model) for _, model in models]
  reference_model = models[0][1]
  reference_serials = serial_maps[0]

  for (_, model), serials in zip(models[1:], serial_maps[1:]):
    if len(model.bonds) != len(reference_model.bonds):
      return None
    if not np.array_equal(model.bonds, reference_model.bonds):
      return None
    if not np.array_equal(serials, reference_serials):
      return None

  return _conect_lines_for_model(reference_model, reference_serials)


def _atom_serials_for_model(model: Structure) -> np.ndarray:
  """Return atom serial numbers for a model, preserving them when possible.

  Stored ``atom_id`` values are reused only if they are all positive, unique
  and at most 99999; otherwise atoms are renumbered ``1..len(model)``.

  Raises:
    ValueError: If the model has more than 99999 atoms.
  """
  if len(model) > 99999:
    raise ValueError("PDB output supports at most 99999 atoms per model.")

  if "atom_id" in model.atom_annotations.dtype.names:
    serials = np.asarray(model.atom_annotations["atom_id"], dtype=np.int32)
    if serials.size and np.all(serials > 0) and len(np.unique(serials)) == len(serials) and np.max(serials) <= 99999:
      return serials.copy()

  return np.arange(1, len(model) + 1, dtype=np.int32)


def _format_model_record(model_id: int) -> str:
  """Return a ``MODEL`` record with the serial right-justified in columns 11-14.

  Raises:
    ValueError: If ``model_id`` is outside 1-9999.
  """
  if model_id < 1 or model_id > 9999:
    raise ValueError(f'MODEL serial "{model_id}" is outside the supported PDB range 1-9999.')
  # Pad the record name to 10 columns so the serial lands where
  # ``_parse_model_id`` reads it (``line[10:14]``).
  return f"{'MODEL':<10}{model_id:4d}"


def _atom_record_lines(model: Structure, serials: np.ndarray) -> List[str]:
  """Return ``ATOM`` / ``HETATM`` lines for a model.

  A bare ``TER`` line is inserted wherever the chain ID changes between
  consecutive atoms.
  """
  chain_ids = model.atom_annotations["chain_id"]
  lines = []
  previous_chain_id = None

  for atom_index in range(len(model)):
    chain_id = str(chain_ids[atom_index])
    if previous_chain_id is not None and chain_id != previous_chain_id:
      lines.append("TER")
    lines.append(_format_atom_record(model, atom_index, int(serials[atom_index])))
    previous_chain_id = chain_id

  return lines


def _format_atom_record(model: Structure, atom_index: int, serial: int) -> str:
  """Format one ``ATOM`` or ``HETATM`` record in fixed-width PDB columns.

  The column positions mirror the slices read by ``_parse_atom_record`` so
  that written files can be parsed back.

  Raises:
    ValueError: If the serial is outside 1-99999, a required annotation is
      blank, or a value does not fit its fixed-width field.
  """
  if serial < 1 or serial > 99999:
    raise ValueError(f'Atom serial "{serial}" is outside the supported PDB range 1-99999.')

  annotations = model.atom_annotations
  atom_name = str(annotations["atom_name"][atom_index]).strip().upper()
  residue_name = str(annotations["res_name"][atom_index]).strip().upper()
  chain_id = str(annotations["chain_id"][atom_index]).strip()
  insertion_code = str(annotations["ins_code"][atom_index]).strip()
  element = str(annotations["element"][atom_index]).strip().upper()
  hetero = bool(annotations["hetero"][atom_index])
  residue_id = int(annotations["res_id"][atom_index])

  if not atom_name:
    raise ValueError(f"Atom {atom_index + 1} is missing an atom_name and cannot be written to PDB.")
  if len(atom_name) > 4:
    raise ValueError(f'Atom name "{atom_name}" exceeds the 4-character PDB limit.')
  if not residue_name:
    raise ValueError(f"Atom {atom_index + 1} is missing a res_name and cannot be written to PDB.")
  if len(residue_name) > 3:
    raise ValueError(f'Residue name "{residue_name}" exceeds the 3-character PDB limit.')
  if len(chain_id) > 1:
    raise ValueError(f'Chain ID "{chain_id}" exceeds the 1-character PDB limit.')
  if len(insertion_code) > 1:
    raise ValueError(f'Insertion code "{insertion_code}" exceeds the 1-character PDB limit.')
  if not element:
    raise ValueError(f"Atom {atom_index + 1} is missing an element and cannot be written to PDB.")
  if len(element) > 2:
    raise ValueError(f'Element "{element}" exceeds the 2-character PDB limit.')
  # Measured on the rendered text so the minus sign of negative IDs counts.
  if len(f"{residue_id:d}") > 4:
    raise ValueError(f'Residue ID "{residue_id}" exceeds the 4-character PDB limit.')

  occupancy = _annotation_value_for_pdb(model, "occupancy", atom_index, 1.0)
  b_factor = _annotation_value_for_pdb(model, "b_factor", atom_index, 0.0)
  charge = _annotation_value_for_pdb(model, "charge", atom_index, 0)

  atom_name_field = _format_atom_name(atom_name, element)
  charge_field = _format_charge_field(int(charge))
  record_name = ("HETATM" if hetero else "ATOM").ljust(6)  # columns 1-6

  return "".join(
    (
      record_name,  # columns 1-6
      f"{serial:5d}",  # columns 7-11
      " ",  # column 12 (blank)
      atom_name_field,  # columns 13-16
      " ",  # column 17 (altloc, never written)
      f"{residue_name:>3}",  # columns 18-20
      " ",  # column 21 (blank)
      f"{chain_id[:1]:1}",  # column 22
      f"{residue_id:4d}",  # columns 23-26
      f"{insertion_code[:1]:1}",  # column 27
      " " * 3,  # columns 28-30 (blank)
      f"{float(model.atoms['x'][atom_index]):8.3f}",  # columns 31-38
      f"{float(model.atoms['y'][atom_index]):8.3f}",  # columns 39-46
      f"{float(model.atoms['z'][atom_index]):8.3f}",  # columns 47-54
      f"{float(occupancy):6.2f}",  # columns 55-60
      f"{float(b_factor):6.2f}",  # columns 61-66
      " " * 10,  # columns 67-76 (blank)
      f"{element:>2}",  # columns 77-78
      f"{charge_field:>2}",  # columns 79-80
    )
  )


def _annotation_value_for_pdb(model: Structure, name: str, atom_index: int, default):
  """Return an annotation value with a fallback default for PDB output.

  ``default`` is returned when the model has no ``name`` annotation column.
  NumPy scalars are converted to plain Python values.
  """
  if name not in model.atom_annotations.dtype.names:
    return default

  value = model.atom_annotations[name][atom_index]
  if isinstance(value, np.generic):
    return value.item()
  return value


def _format_atom_name(atom_name: str, element: str) -> str:
  """Return a 4-character atom name field.

  Names for one-letter elements that start with a letter are shifted one
  column right; all other names shorter than four characters are
  right-justified.
  """
  if len(atom_name) == 4:
    return atom_name
  if len(element) == 1 and atom_name and atom_name[0].isalpha():
    return f" {atom_name:<3}"
  return f"{atom_name:>4}"


def _format_charge_field(charge: int) -> str:
  """Return the PDB charge text (digit then sign), or ``""`` for zero.

  Raises:
    ValueError: If the magnitude of ``charge`` exceeds 9.
  """
  if charge == 0:
    return ""
  if abs(charge) > 9:
    raise ValueError(f'Atom charge "{charge}" exceeds the 1-digit PDB charge limit.')
  sign = "+" if charge > 0 else "-"
  return f"{abs(charge)}{sign}"


def _conect_lines_for_model(model: Structure, serials: np.ndarray) -> List[str]:
  """Return ``CONECT`` lines for a model.

  Each bond is written in one direction only, with the target repeated
  ``bond_type`` times (at least once), which is how the parser in this module
  reads bond order back.

  Raises:
    ValueError: If a bond refers to an atom index outside the atom table.
  """
  if len(model.bonds) == 0:
    return []

  atom_index_to_serial = {atom_index: int(serial) for atom_index, serial in enumerate(serials)}
  directed_counts: Counter[Tuple[int, int]] = Counter()

  for bond in model.bonds:
    atom_i = int(bond["atom_i"])
    atom_j = int(bond["atom_j"])
    bond_type = max(1, int(bond["bond_type"]))
    if atom_i not in atom_index_to_serial or atom_j not in atom_index_to_serial:
      raise ValueError("Bond table contains atom indices outside the atom table.")
    directed_counts[(atom_i, atom_j)] += bond_type

  lines = []
  for (atom_i, atom_j), count in sorted(directed_counts.items()):
    source_serial = atom_index_to_serial[atom_i]
    target_serial = atom_index_to_serial[atom_j]
    lines.extend(_format_conect_records(source_serial, target_serial, count))
  return lines


def _format_conect_records(source_serial: int, target_serial: int, count: int) -> List[str]:
  """Return one or more ``CONECT`` records for a repeated bond.

  The target serial is repeated ``count`` times, at most four per record.

  Raises:
    ValueError: If either serial exceeds 99999.
  """
  if source_serial > 99999 or target_serial > 99999:
    raise ValueError("CONECT atom serial exceeds the 5-character PDB limit.")

  repeated_targets = [target_serial] * count
  lines = []
  for start in range(0, len(repeated_targets), 4):
    chunk = repeated_targets[start : start + 4]
    fields = "".join(f"{serial:5d}" for serial in chunk)
    lines.append(f"CONECT{source_serial:5d}{fields}")
  return lines


def _apply_conect_records(
  models: List[_ModelAccumulator],
  pending_conect: List[ConectRecord],
  malformed_conect: ConectErrorMode,
):
  """Apply stored ``CONECT`` records to all models that contain the referenced serials.

  A record for which no model accepts any of its bonds is reported through
  ``_handle_malformed_conect``.
  """
  for source_serial, target_serials, line_number in pending_conect:
    found_match = False
    for model in models:
      for target_serial in target_serials:
        if model.add_directed_bond(source_serial, target_serial):
          found_match = True

    if not found_match:
      _handle_malformed_conect(
        f"CONECT record references unknown atom serials for source atom {source_serial} at line {line_number}.",
        malformed_conect=malformed_conect,
      )


def _parse_model_id(line: str, line_number: int) -> int:
  """Parse the serial number from a ``MODEL`` record (columns 11-14).

  Raises:
    ValueError: If the field is blank or not an integer.
  """
  model_field = line[10:14].strip()
  if not model_field:
    raise ValueError(f"Missing MODEL serial number at line {line_number}.")
  try:
    return int(model_field)
  except ValueError as exc:
    raise ValueError(f'Invalid MODEL serial number "{model_field}" at line {line_number}.') from exc


def _parse_atom_record(line: str, record_type: str, line_number: int) -> _AtomRecord:
  """Parse a single ``ATOM`` or ``HETATM`` record.

  The parser is intentionally strict about required fields such as residue
  names and element assignments so ambiguous files fail early.

  Parameters:
    line: The record, padded to at least 80 characters.
    record_type: Record name; the atom is flagged hetero when it equals
      ``"HETATM"``.
    line_number: 1-based line number, used in error messages.

  Raises:
    ValueError: If a numeric field is invalid, the atom or residue name is
      blank, or the element is blank and cannot be inferred.
  """
  try:
    atom_id = int(line[6:11].strip())
    atom_name_field = line[12:16]
    atom_name = atom_name_field.strip()
    altloc = line[16].strip()
    res_name = line[17:20].strip()
    chain_id = line[21].strip()
    res_id = int(line[22:26].strip())
    ins_code = line[26].strip()
    x = float(line[30:38].strip())
    y = float(line[38:46].strip())
    z = float(line[46:54].strip())
    occupancy = _parse_float_field(line[54:60], default=1.0, line_number=line_number, label="occupancy")
    b_factor = _parse_float_field(line[60:66], default=0.0, line_number=line_number, label="B-factor")
    element = line[76:78].strip().upper()
  except ValueError as exc:
    raise ValueError(f"Invalid atom record ({exc}) at line {line_number}.") from exc

  if not atom_name:
    raise ValueError(f"Missing atom name at line {line_number}.")
  if not res_name:
    raise ValueError(f"Missing residue name at line {line_number}.")
  if not element:
    # The unstripped field is needed: the inference depends on column position.
    inferred_element = _infer_element_from_atom_name_field(atom_name_field)
    if inferred_element is None:
      raise ValueError(f"Missing element assignment at line {line_number}.")
    logger.warning('Missing element assignment at line %s; inferred element "%s" from atom name "%s".', line_number, inferred_element, atom_name)
    element = inferred_element

  charge = _parse_charge(line[78:80], line_number)
  hetero = record_type == "HETATM"
  # Altloc is excluded from the identity key so alternate conformers collapse
  # onto a single atom site during model accumulation.
  atom_key = (chain_id, res_id, ins_code, res_name, hetero, atom_name)

  return _AtomRecord(
    x=x,
    y=y,
    z=z,
    chain_id=chain_id,
    res_id=res_id,
    ins_code=ins_code,
    res_name=res_name,
    hetero=hetero,
    atom_name=atom_name,
    element=element,
    atom_id=atom_id,
    b_factor=b_factor,
    occupancy=occupancy,
    charge=charge,
    altloc=altloc,
    atom_key=atom_key,
  )


def _infer_element_from_atom_name_field(atom_name_field: str) -> Optional[str]:
  """Infer an element symbol from the raw 4-character PDB atom-name field.

  When the first column is a digit or a blank, only the second column is
  tried. Otherwise the first two columns are tried, then the first column
  alone. Candidates are accepted only if they are keys of ``ATOMIC_MASSES``;
  ``None`` is returned when nothing matches.
  """
  if not atom_name_field.strip():
    return None

  first_column = atom_name_field[0]
  if first_column.isdigit() or first_column == " ":
    candidate = atom_name_field[1:2].strip().title()
    return candidate if candidate in ATOMIC_MASSES else None

  for width in (2, 1):
    candidate = atom_name_field[:width].strip().title()
    if candidate in ATOMIC_MASSES:
      return candidate.upper()
  return None


def _parse_conect_record(line: str, line_number: int, malformed_conect: ConectErrorMode) -> Optional[Tuple[int, List[int], int]]:
  """Parse a ``CONECT`` record into source and target serial numbers.

  Returns ``(source_serial, target_serials, line_number)``, or ``None`` when a
  field is not an integer (reported according to ``malformed_conect``) or when
  the record holds fewer than two serials.
  """
  serials: List[int] = []
  # Serials are 5-column fields starting right after the 6-column record name.
  for start in range(6, len(line), 5):
    chunk = line[start : start + 5].strip()
    if not chunk:
      continue
    try:
      serials.append(int(chunk))
    except ValueError:
      _handle_malformed_conect(f'Invalid CONECT serial "{chunk}" at line {line_number}.', malformed_conect=malformed_conect)
      return None

  if len(serials) < 2:
    return None
  return serials[0], serials[1:], line_number


def _handle_malformed_conect(message: str, malformed_conect: ConectErrorMode):
  """Handle a malformed ``CONECT`` record according to parser policy.

  ``"strict"`` raises :class:`ValueError`, ``"warn"`` logs a warning, and any
  other value does nothing.
  """
  if malformed_conect == "strict":
    raise ValueError(message)
  if malformed_conect == "warn":
    logger.warning(message)


def _parse_float_field(field: str, *, default: float, line_number: int, label: str) -> float:
  """Parse a float field with fixed defaults for blank values.

  Raises:
    ValueError: If the field is non-blank and not a valid float.
  """
  field = field.strip()
  if not field:
    return default
  try:
    return float(field)
  except ValueError as exc:
    raise ValueError(f'Invalid {label} value "{field}" at line {line_number}.') from exc


def _parse_charge(field: str, line_number: int) -> int:
  """Parse the PDB atom charge field into a signed integer.

  Accepts digit-then-sign (``"2-"``), sign-then-digit (``"-2"``) and plain
  integers; a blank field yields 0.

  Raises:
    ValueError: If the field matches none of the accepted forms.
  """
  field = field.strip()
  if not field:
    return 0
  if len(field) == 2 and field[0].isdigit() and field[1] in "+-":
    return int(field[0]) * (1 if field[1] == "+" else -1)
  if len(field) == 2 and field[1].isdigit() and field[0] in "+-":
    return int(field[1]) * (1 if field[0] == "+" else -1)
  if field.endswith(("+", "-")) and field[:-1].isdigit():
    return int(field[:-1]) * (1 if field[-1] == "+" else -1)
  if field.lstrip("+-").isdigit():
    return int(field)
  raise ValueError(f'Invalid atom charge "{field}" at line {line_number}.')


def _unique_atom_name(atom_name: str, duplicate_index: int) -> str:
  """Return a best-effort unique atom name within the 4-character PDB limit.

  The name is truncated as needed and suffixed with ``duplicate_index + 1``;
  a blank name falls back to ``"X"``. The result is not checked against other
  names already present in the residue.
  """
  suffix = str(duplicate_index + 1)
  base = atom_name.strip() or "X"
  return f"{base[: max(1, 4 - len(suffix))]}{suffix}"[:4]


def _read_lines(file: Union[str, pathlib.Path, io.IOBase]) -> List[str]:
  """Read all lines from a filepath or file handle.

  Text and binary file handles are both accepted so the parser works with the
  same broad range of inputs as the rest of the I/O layer. Paths are opened
  as UTF-8 text and binary content is decoded as UTF-8.
  """
  if isinstance(file, (str, pathlib.Path)):
    with open(file, "rt", encoding="utf-8") as handle:
      return handle.read().splitlines()

  content = file.read()
  if isinstance(content, bytes):
    content = content.decode("utf-8")
  return content.splitlines()


def _write_pdb_lines(file: Union[str, pathlib.Path, io.IOBase], lines: List[str]):
  """Write PDB lines to a filepath or file handle.

  Lines are joined with ``"\\n"`` and a trailing newline is added. Handles
  that reject ``str`` with :class:`TypeError` receive UTF-8 bytes instead.
  """
  content = "\n".join(lines) + "\n"
  if isinstance(file, (str, pathlib.Path)):
    with open(file, "wt", encoding="utf-8", newline="\n") as handle:
      handle.write(content)
    return

  if isinstance(file, io.TextIOBase):
    file.write(content)
    return

  # Unknown handle type: try text first, fall back to bytes for binary handles.
  try:
    file.write(content)
  except TypeError:
    file.write(content.encode("utf-8"))