"""Parser and writer for SDF files.
This module provides Neurosnap-native :func:`parse_sdf` and :func:`save_sdf`
helpers for reading and writing
:class:`~neurosnap.structure.structure.Structure`,
:class:`~neurosnap.structure.structure.StructureEnsemble`, and
:class:`~neurosnap.structure.structure.StructureStack` objects.
The implementation intentionally follows RDKit's own SDF reading and writing
logic as closely as possible: molecules are parsed through RDKit suppliers,
sanitized using RDKit defaults, and written back using RDKit's SD writer.
"""
import io
import pathlib
from typing import Dict, List, Literal, Tuple, Union
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdchem
from rdkit.Geometry import Point3D
from neurosnap.structure.structure import Structure, StructureEnsemble, StructureStack
__all__ = ["parse_sdf", "save_sdf"]
# Values accepted by the ``return_type`` argument of :func:`parse_sdf`.
ReturnType = Literal["ensemble", "stack", "auto"]

# Integer ``bond_type`` code used for each supported RDKit bond type.
_RDKIT_BOND_TO_INT = {
    rdchem.BondType.SINGLE: 1,
    rdchem.BondType.DOUBLE: 2,
    rdchem.BondType.TRIPLE: 3,
    rdchem.BondType.AROMATIC: 4,
    rdchem.BondType.QUADRUPLE: 5,
}

# Reverse lookup (integer code -> RDKit bond type), derived from the table
# above so the two mappings always stay in sync.
_INT_TO_RDKIT_BOND = {code: rd_type for rd_type, code in _RDKIT_BOND_TO_INT.items()}
[docs]
def parse_sdf(
    sdf: Union[str, pathlib.Path, io.IOBase],
    return_type: ReturnType = "auto",
) -> Union[StructureEnsemble, StructureStack]:
    """Read an SDF file into a Neurosnap structure container.

    Every SDF record is read through an RDKit supplier and converted into one
    :class:`Structure` model whose model id is the 1-based record number. The
    models are collected in a :class:`StructureEnsemble`, which is then
    optionally converted with :meth:`StructureStack.from_ensemble`.

    SDF carries no chain or residue hierarchy, so atoms without RDKit monomer
    information are assigned to a single heterogen residue ``LIG`` in chain
    ``A``.

    Parameters:
      sdf: SDF filepath or open file handle.
      return_type: Output container type. ``"ensemble"`` returns the
        :class:`StructureEnsemble` as is. ``"stack"`` returns the result of
        :meth:`StructureStack.from_ensemble` and lets any error it raises
        propagate. ``"auto"`` attempts the same conversion but falls back to
        the ensemble when the conversion raises ``ValueError``.

    Returns:
      A :class:`StructureEnsemble` or :class:`StructureStack` depending on
      ``return_type`` and on whether the stack conversion succeeds.

    Raises:
      ValueError: If ``return_type`` is not a recognised value, a record
        cannot be parsed by RDKit, or the input contains no molecules.
    """
    if return_type not in {"ensemble", "stack", "auto"}:
        raise ValueError('return_type must be one of "ensemble", "stack", or "auto".')

    records = _rdkit_supplier_from_sdf(sdf)
    ensemble = StructureEnsemble()
    for record_index, rd_mol in enumerate(records, start=1):
        # The supplier yields None for records RDKit could not parse.
        if rd_mol is None:
            raise ValueError(f"Failed to parse SDF record {record_index}.")
        ensemble.append(
            _structure_from_rdkit_mol(rd_mol, model_id=record_index),
            model_id=record_index,
        )

    if len(ensemble) == 0:
        raise ValueError("No molecules were found in the SDF file.")
    ensemble.metadata["source_format"] = "sdf"

    if return_type == "ensemble":
        return ensemble

    try:
        return StructureStack.from_ensemble(ensemble)
    except ValueError:
        # Only "auto" may silently fall back; "stack" must surface the error.
        if return_type == "stack":
            raise
        return ensemble
[docs]
def save_sdf(
    structure: Union[Structure, StructureEnsemble, StructureStack],
    sdf: Union[str, pathlib.Path, io.IOBase],
):
    """Write a Neurosnap structure container to an SDF file.

    Each model becomes one SDF record, written with RDKit's ``SDWriter``.
    Because SDF is a small-molecule format, chain and residue hierarchy are
    not carried over. Scalar metadata values are exported as SDF molecule
    properties.

    Parameters:
      structure: Structure container to write.
      sdf: Output filepath or open file handle. Handles are passed to the
        writer unchanged; any other value is converted with ``str``.

    Raises:
      ValueError: If the container yields no models.
    """
    model_pairs = _models_for_sdf_output(structure)
    if not model_pairs:
        raise ValueError("No models are available for SDF output.")

    if isinstance(sdf, io.IOBase):
        destination = sdf
    else:
        destination = str(sdf)

    writer = Chem.SDWriter(destination)
    try:
        for model_id, model in model_pairs:
            writer.write(_rdkit_mol_from_structure(model, model_id=model_id))
    finally:
        # Always close the writer, even if a model fails to convert.
        writer.close()
def _rdkit_supplier_from_sdf(sdf: Union[str, pathlib.Path, io.IOBase]) -> Chem.SDMolSupplier:
    """Create an RDKit ``SDMolSupplier`` for a path or an open file handle.

    Both input kinds use the same options: sanitization on, hydrogens kept,
    and strict parsing.

    Parameters:
      sdf: SDF filepath, or a file handle whose full content is read here.
        Bytes content is decoded as UTF-8.

    Returns:
      The configured supplier.
    """
    if not isinstance(sdf, io.IOBase):
        return Chem.SDMolSupplier(str(sdf), sanitize=True, removeHs=False, strictParsing=True)

    # File handles are consumed eagerly and fed to the supplier as text.
    text = sdf.read()
    if isinstance(text, bytes):
        text = text.decode("utf-8")
    in_memory_supplier = Chem.SDMolSupplier()
    in_memory_supplier.SetData(text, sanitize=True, removeHs=False, strictParsing=True)
    return in_memory_supplier
def _structure_from_rdkit_mol(mol: Chem.Mol, model_id: int) -> Structure:
    """Build a Neurosnap :class:`Structure` from a single RDKit molecule.

    Coordinates are taken from the molecule's default conformer. Atoms that
    carry ``AtomPDBResidueInfo`` keep their chain, residue, and atom naming
    (with ``A`` / ``LIG`` / element symbol / 1-based index as fallbacks for
    empty values). All other atoms go into hetero residue ``LIG`` 1 of chain
    ``A`` and are named ``<ELEMENT><n>``, where ``n`` counts the atoms of that
    element seen so far among such atoms. Every atom gets a B-factor of 0.0,
    an occupancy of 1.0, and an empty ``sym_id``.

    Parameters:
      mol: RDKit molecule for one SDF record.
      model_id: Identifier stored under ``metadata["model_id"]``.

    Returns:
      The populated structure.

    Raises:
      ValueError: If the molecule has no atoms, has no conformer, or contains
        a bond whose type has no integer code and is not flagged aromatic.
    """
    if mol.GetNumAtoms() == 0:
        raise ValueError("SDF record contains no atoms.")
    if mol.GetNumConformers() == 0:
        raise ValueError("SDF record does not contain 3D coordinates.")
    conformer = mol.GetConformer()

    structure = Structure(remove_annotations=False)
    structure.metadata = {"model_id": model_id}
    if mol.HasProp("_Name"):
        structure.metadata["title"] = mol.GetProp("_Name")
    # Copy the public, non-computed SD properties; the title is handled above.
    # NOTE(review): a property literally named "model_id" or "title" replaces
    # the value set above (as a string) - confirm this is intended.
    for prop_name in mol.GetPropNames(includePrivate=False, includeComputed=False):
        if prop_name != "_Name":
            structure.metadata[prop_name] = mol.GetProp(prop_name)

    coordinates = []
    annotation_rows = []
    fallback_name_counts: Dict[str, int] = {}
    for idx, atom in enumerate(mol.GetAtoms()):
        symbol = atom.GetSymbol()
        upper_symbol = symbol.upper()
        monomer = atom.GetMonomerInfo()
        if isinstance(monomer, Chem.AtomPDBResidueInfo):
            row = {
                "chain_id": monomer.GetChainId().strip() or "A",
                "res_id": int(monomer.GetResidueNumber()),
                "ins_code": monomer.GetInsertionCode().strip(),
                "res_name": monomer.GetResidueName().strip() or "LIG",
                "hetero": bool(monomer.GetIsHeteroAtom()),
                "atom_name": monomer.GetName().strip() or symbol,
                "element": upper_symbol,
                "atom_id": int(monomer.GetSerialNumber() or (idx + 1)),
            }
        else:
            # No residue information: synthesize a name such as "C1", "C2", "N1".
            fallback_name_counts[upper_symbol] = fallback_name_counts.get(upper_symbol, 0) + 1
            row = {
                "chain_id": "A",
                "res_id": 1,
                "ins_code": "",
                "res_name": "LIG",
                "hetero": True,
                "atom_name": f"{upper_symbol}{fallback_name_counts[upper_symbol]}",
                "element": upper_symbol,
                "atom_id": idx + 1,
            }
        row["b_factor"] = 0.0
        row["occupancy"] = 1.0
        row["charge"] = int(atom.GetFormalCharge())
        row["sym_id"] = ""
        annotation_rows.append(row)

        point = conformer.GetAtomPosition(idx)
        coordinates.append((float(point.x), float(point.y), float(point.z)))

    bond_rows = []
    for bond in mol.GetBonds():
        bond_code = _RDKIT_BOND_TO_INT.get(bond.GetBondType())
        if bond_code is None:
            # Unmapped bond types are accepted only when flagged aromatic.
            if not bond.GetIsAromatic():
                raise ValueError(f"Unsupported RDKit bond type {bond.GetBondType()} in SDF record.")
            bond_code = 4
        bond_rows.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond_code))

    structure.atoms = np.array(coordinates, dtype=structure._dtype_atoms)
    structure.atom_annotations = np.zeros(len(annotation_rows), dtype=structure._dtype_atom_annotations)
    for row_index, row in enumerate(annotation_rows):
        for field_name, field_value in row.items():
            structure.atom_annotations[field_name][row_index] = field_value

    structure.bonds = (
        np.array(bond_rows, dtype=structure._dtype_bond)
        if bond_rows
        else np.zeros(0, dtype=structure._dtype_bond)
    )
    structure._remove_empty_annotations()
    return structure
def _models_for_sdf_output(structure: Union[Structure, StructureEnsemble, StructureStack]) -> List[Tuple[int, Structure]]:
    """Normalize any supported container into ``(model_id, model)`` pairs.

    A bare :class:`Structure` yields one pair whose id comes from
    ``metadata["model_id"]`` (default 1). Ensembles and stacks yield one pair
    per model, pairing ``model_ids`` with ``models()`` in order.

    Raises:
      TypeError: If ``structure`` is none of the supported container types.
    """
    if isinstance(structure, Structure):
        return [(int(structure.metadata.get("model_id", 1)), structure)]
    if isinstance(structure, (StructureEnsemble, StructureStack)):
        return [(model_id, model) for model_id, model in zip(structure.model_ids, structure.models())]
    raise TypeError(f"Unsupported structure type for SDF output: {type(structure).__name__}.")
def _rdkit_mol_from_structure(structure: Structure, model_id: int) -> Chem.Mol:
    """Convert one Neurosnap structure model into an RDKit molecule.

    Atoms are created from the ``element`` annotation with their formal charge
    (0 when no ``charge`` annotation exists), coordinates are stored in a
    single conformer, and bonds are rebuilt from the integer ``bond_type``
    codes. Atoms and bonds taking part in aromatic bonds are flagged aromatic.
    The molecule is not sanitized here.

    Parameters:
      structure: Structure model to convert.
      model_id: Identifier written as the ``model_id`` property and used for
        the fallback title ``model_<model_id>``.

    Returns:
      The RDKit molecule. Its ``_Name`` is the ``title`` or ``name`` metadata
      entry (or the fallback title), and the remaining str/int/float/bool
      metadata values are stored as string properties.

    Raises:
      ValueError: If an atom has an empty element or a bond uses an
        unsupported ``bond_type`` code.
    """
    rw_mol = Chem.RWMol()
    conformer = Chem.Conformer(len(structure))
    aromatic_atoms: set[int] = set()
    for atom_index in range(len(structure)):
        element = str(structure.atom_annotations["element"][atom_index]).strip()
        if not element:
            raise ValueError(f"Atom {atom_index + 1} is missing an element and cannot be written to SDF.")
        # RDKit looks element symbols up case-sensitively ("Cl", not "CL"),
        # while structures may store upper-cased symbols, so normalize to the
        # conventional capitalization before building the atom.
        rd_atom = Chem.Atom(element.capitalize())
        charge = _annotation_value_for_sdf(structure, "charge", atom_index, 0)
        rd_atom.SetFormalCharge(int(charge))
        rw_mol.AddAtom(rd_atom)
        conformer.SetAtomPosition(
            atom_index,
            Point3D(
                float(structure.atoms["x"][atom_index]),
                float(structure.atoms["y"][atom_index]),
                float(structure.atoms["z"][atom_index]),
            ),
        )

    for bond in structure.bonds:
        atom_i = int(bond["atom_i"])
        atom_j = int(bond["atom_j"])
        bond_type_value = int(bond["bond_type"])
        rd_bond_type = _INT_TO_RDKIT_BOND.get(bond_type_value)
        if rd_bond_type is None:
            raise ValueError(f"Unsupported bond_type {bond_type_value} for SDF output.")
        rw_mol.AddBond(atom_i, atom_j, rd_bond_type)
        # Remember both ends of aromatic bonds so the atoms can be flagged below.
        if rd_bond_type == rdchem.BondType.AROMATIC:
            aromatic_atoms.add(atom_i)
            aromatic_atoms.add(atom_j)

    mol = rw_mol.GetMol()
    mol.AddConformer(conformer, assignId=True)

    # Aromatic flags are set by hand on atoms and bonds.
    for atom_index in aromatic_atoms:
        mol.GetAtomWithIdx(atom_index).SetIsAromatic(True)
    for bond in mol.GetBonds():
        if bond.GetBondType() == rdchem.BondType.AROMATIC:
            bond.SetIsAromatic(True)

    title = structure.metadata.get("title") or structure.metadata.get("name") or f"model_{model_id}"
    mol.SetProp("_Name", str(title))
    mol.SetIntProp("model_id", int(model_id))
    for key, value in structure.metadata.items():
        # These keys are already represented by "_Name" and "model_id".
        if key in {"title", "name", "model_id"}:
            continue
        if isinstance(value, (str, int, float, bool)):
            mol.SetProp(str(key), str(value))
    return mol
def _annotation_value_for_sdf(structure: Structure, name: str, atom_index: int, default):
    """Look up one atom annotation, falling back to ``default``.

    Parameters:
      structure: Structure whose ``atom_annotations`` array is read.
      name: Annotation field name.
      atom_index: Row index of the atom.
      default: Value returned when the field does not exist.

    Returns:
      ``default`` when ``name`` is not a field of ``atom_annotations``.
      Otherwise the stored value, with NumPy scalars converted to plain
      Python objects via ``.item()``.
    """
    annotations = structure.atom_annotations
    if name in annotations.dtype.names:
        raw = annotations[name][atom_index]
        return raw.item() if isinstance(raw, np.generic) else raw
    return default