import itertools as it
import logging
from collections import abc, defaultdict
from dataclasses import dataclass
from functools import partial
import stk
from rdkit import Chem
from stko._internal.molecular.molecular_utilities import separate_molecule
logger = logging.getLogger(__name__)
[docs]
class UnreactedTopologyGraph:
"""A class containing topology graphs and performing subset of reactions.
Use this to get partially reacted topology graphs.
.. testcode:: unreacted-topology-graph
import stk
import stko
bb1 = stk.BuildingBlock(
smiles="NCCN", functional_groups=(stk.PrimaryAminoFactory(),)
)
bb2 = stk.BuildingBlock(
smiles="O=CC(C=O)C=O", functional_groups=(stk.AldehydeFactory(),)
)
cage_graphs = stko.topology_functions.UnreactedTopologyGraph(
stk.cage.TwoPlusThree((bb1, bb2))
)
# Get a pool of NamedIntermediates with only 1 reaction performed. It
# contains the reacted species plus the building blocks (there are 2).
# You can iterate through that pool to get each named intermediate,
# which contains an stk molecule and other information about it.
pool = cage_graphs.get_named_intermediates(n=1)
.. testcode:: unreacted-topology-graph
:hide:
assert len(cage_graphs.get_available_reactions()) == 6
assert len(pool) == 3
.. moldoc::
import moldoc.molecule as molecule
import stk
import stko
bb1 = stk.BuildingBlock(
smiles="NCCN", functional_groups=(stk.PrimaryAminoFactory(),)
)
bb2 = stk.BuildingBlock(
smiles="O=CC(C=O)C=O", functional_groups=(stk.AldehydeFactory(),)
)
cage_graphs = stko.topology_functions.UnreactedTopologyGraph(
stk.cage.TwoPlusThree((bb1, bb2))
)
# Get a NamedIntermediate with only 1 reaction, which contains an
# stk molecule and other information about the intermediate.
pool = cage_graphs.get_named_intermediates(n=1)
moldoc_display_molecule = molecule.Molecule(
atoms=(
molecule.Atom(
atomic_number=atom.get_atomic_number(),
position=position,
) for atom, position in zip(
pool.intermediates[0].molecule.get_atoms(),
pool.intermediates[0].molecule.get_position_matrix(),
)
),
bonds=(
molecule.Bond(
atom1_id=bond.get_atom1().get_id(),
atom2_id=bond.get_atom2().get_id(),
order=bond.get_order(),
) for bond in pool.intermediates[0].molecule.get_bonds()
),
)
"""
def __init__(self, topology_graph: stk.TopologyGraph) -> None:
    """Store the topology graph and precompute its reactions.

    The construction state is taken from the topology graph and passed
    through its building-block placement step; one reaction is then
    requested from the graph's reaction factory for each edge group,
    and the result of every reaction is cached.

    Parameters:
        topology_graph:
            The `stk` topology graph to contain.

    """
    self._topology_graph = topology_graph
    unplaced_state = topology_graph._get_construction_state()  # noqa: SLF001
    self._state = topology_graph._place_building_blocks(unplaced_state)  # noqa: SLF001
    reaction_factory = topology_graph._reaction_factory  # noqa: SLF001
    # One reaction per edge group, all built against the placed state.
    self._reactions = tuple(
        reaction_factory.get_reaction(self._state, edge_group)
        for edge_group in topology_graph._edge_groups  # noqa: SLF001
    )
    # Cache each reaction's result in the same order as the reactions.
    self._results = tuple(
        candidate.get_result() for candidate in self._reactions
    )
[docs]
def get_available_reactions(self) -> abc.Sequence[stk.Reaction]:
    """Return every reaction built at initialisation.

    Returns:
        The stored sequence of reactions (one per edge group of the
        topology graph). The stored object itself is returned, not a
        copy.

    """
    available = self._reactions
    return available
[docs]
def get_reaction_results(self) -> abc.Sequence[stk.ReactionResult]:
    """Return the cached result of every available reaction.

    Returns:
        The stored sequence of reaction results, in the same order as
        the reactions they were obtained from. The stored object itself
        is returned, not a copy.

    """
    cached = self._results
    return cached
[docs]
def yield_constructed_molecules(
    self,
    n: int | None = None,
) -> abc.Iterator[stk.ConstructedMolecule]:
    """Yield one constructed molecule per combination of reactions.

    Combinations of ``1`` up to ``len(results) - 1`` reactions are
    considered, so the combination containing every reaction is never
    produced. Leaving `n` as ``None`` walks through all of those sizes,
    which could be a combinatorial nightmare.

    Parameters:
        n:
            If given, only combinations of exactly `n` reactions are
            used; a value outside the considered sizes yields nothing.

    Yields:
        A constructed molecule in which only the chosen reactions have
        been applied to the placed building blocks.

    """
    reaction_counts = (
        count
        for count in range(1, len(self._results))
        if n is None or count == n
    )
    for count in reaction_counts:
        for chosen_reactions in it.combinations(self._reactions, count):
            chosen_results = tuple(
                chosen.get_result() for chosen in chosen_reactions
            )
            reacted_state = self._state.with_reaction_results(
                chosen_reactions, chosen_results
            )
            construction_result = (
                self._topology_graph._get_construction_result(  # noqa: SLF001
                    reacted_state
                )
            )
            yield stk.ConstructedMolecule.init_from_construction_result(
                construction_result
            )
[docs]
def get_reacted_smiles(self, n: int | None = None) -> set[str]:
    """Collect the distinct SMILES of all species after reacting.

    Every constructed molecule produced by
    :meth:`yield_constructed_molecules` is split with
    ``separate_molecule`` and each separated molecule is converted to a
    canonical SMILES string.

    Parameters:
        n:
            Forwarded to :meth:`yield_constructed_molecules`.

    Returns:
        The set of canonical SMILES strings found.

    Raises:
        RuntimeError:
            If a separated molecule still gives a SMILES containing
            ``.``.

    """
    unique_smiles: set[str] = set()
    for constructed in self.yield_constructed_molecules(n=n):
        for fragment, _ in separate_molecule(constructed):
            canonical = Chem.CanonSmiles(
                Chem.MolToSmiles(fragment.to_rdkit_mol())
            )
            if "." in canonical:
                msg = "Found `.` in smiles."
                raise RuntimeError(msg)
            # A set ignores repeats, so no membership check is needed.
            unique_smiles.add(canonical)
    return unique_smiles
[docs]
def get_present_building_blocks(
    self,
    const_mol: stk.ConstructedMolecule,
    subset_ids: list[int],
) -> dict[stk.BuildingBlock, list[int]]:
    """Get the building blocks present in a constructed molecule.

    Only atom infos whose atom id is in `subset_ids` are considered.

    Parameters:
        const_mol:
            The constructed molecule whose atom infos are scanned.

        subset_ids:
            The ids of the atoms to consider.

    Returns:
        A ``defaultdict(list)`` mapping each value returned by
        ``get_building_block()`` to the distinct
        ``get_building_block_id()`` values seen for it, in first-seen
        order. Building blocks with no atom in `subset_ids` get no key.

    """
    # Hash the ids once so each atom costs an O(1) lookup instead of a
    # linear scan of the list.
    wanted_atom_ids = frozenset(subset_ids)
    bbs: dict[stk.BuildingBlock, list[int]] = defaultdict(list)
    for atom_info in const_mol.get_atom_infos():
        if atom_info.get_atom().get_id() not in wanted_atom_ids:
            continue
        building_block = atom_info.get_building_block()
        building_block_id = atom_info.get_building_block_id()
        # Indexing the defaultdict creates the key on first sight.
        known_ids = bbs[building_block]  # type: ignore[index]
        if building_block_id not in known_ids:
            known_ids.append(building_block_id)  # type: ignore[arg-type]
    return bbs