from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Hashable, Iterable, NamedTuple
import ase
import numpy as np
import numpy.typing as npt
from pymatgen.core.lattice import Lattice as PymatgenLattice
from topo_metrics._julia_wrappers import (
get_all_rings,
get_coordination_sequences,
get_topological_genome,
run_rings,
run_strong_rings,
)
from topo_metrics.clusters import (
Cluster,
get_carvs_vector,
get_clusters,
get_vertex_symbol,
)
from topo_metrics.io.cgd import parse_cgd, process_neighbour_list
from topo_metrics.io.conflink import parse_conflink
from topo_metrics.io.lammps_data import load_lammps_data
from topo_metrics.neighbours import (
autoreduce_neighborlist,
graph_edges_by_cutoff,
)
from topo_metrics.ring_geometry import RingGeometry
from topo_metrics.rings import RingSizeCounts
from topo_metrics.symbols import VertexSymbol
from topo_metrics.utils import uniform_repr
[docs]
class RingsResults(NamedTuple):
""" The depth to which the rings were searcher for. """
""" Whether the rings were filtered to strong rings only. """
[docs]
ring_size_count: RingSizeCounts
"""
The number of rings of a given size, of shape ``(**, 2)`` where the first
column indicates the ring size, and the second indicates the number of rings
of that size.
"""
[docs]
clusters: list[Cluster]
"""
Each node can be characterised in terms of the rings in which it
participates. This can be summarised using common metrics such as Vertex
Symbols.
"""
def __repr__(self) -> str:
vs_name = "VertexSymbol"
vertex_symbols_str = ""
vertex_symbols = {
x.to_str()
for x in get_vertex_symbol(self.clusters)
if isinstance(x, VertexSymbol)
}
if len(vertex_symbols) > 1:
vs_name = "VertexSymbols"
vertex_symbols_str += "{\n\t"
if len(vertex_symbols) > 10:
displayed_symbols = (
list(vertex_symbols)[:5] + ["..."] + list(vertex_symbols)[-5:]
)
else:
displayed_symbols = vertex_symbols
vertex_symbols_str += ",\n\t".join(displayed_symbols)
if len(vertex_symbols) > 1:
vertex_symbols_str += "\n}"
info = {
"depth": self.depth,
"strong_rings": self.rings_are_strong,
"ring_size_count": self.ring_size_count,
vs_name: vertex_symbols_str,
"CARVS": get_carvs_vector(self.clusters),
}
return uniform_repr(
"RingsResults", **info, indent_size=4, stringify=False
)
@dataclass
[docs]
class Topology:
"""
A class detailing the topology of a network, based on nodes and edges.
"""
[docs]
edges: npt.NDArray[np.int_] = field(
default_factory=lambda: np.empty((0, 5), dtype=int)
)
[docs]
lattice: PymatgenLattice | None = None
[docs]
properties: dict[Hashable, Any] = field(default_factory=dict)
@classmethod
[docs]
def from_ase(
cls,
ase_atoms: ase.Atoms,
cutoff: float = 0.0,
pair_cutoffs: dict[tuple[str, str], float] | None = None,
remove_types: Iterable[Any] | None = None,
remove_degree2: bool = False,
) -> Topology:
"""
Creates a Topology object from an ASE Atoms object.
Parameters
----------
ase_atoms
The ASE Atoms object representing the structure.
Returns
-------
A Topology object representing the network as nodes and edges.
"""
lattice = None
cart_coords = None
frac_coords = [None] * len(ase_atoms)
if all(ase_atoms.pbc):
lattice = PymatgenLattice(ase_atoms.cell)
frac_coords = ase_atoms.get_scaled_positions()
cart_coords = ase_atoms.get_positions()
symbols = ase_atoms.get_chemical_symbols()
edges = graph_edges_by_cutoff(
ase_atoms, cutoff=cutoff, pair_cutoffs=pair_cutoffs, one_based=True
)
if remove_types is not None or remove_degree2:
_reduced = autoreduce_neighborlist(
cart_coords=cart_coords,
frac_coords=frac_coords,
symbols=symbols,
edges=edges,
remove_types=remove_types,
remove_degree2=remove_degree2,
)
cart_coords, frac_coords, symbols, edges, _ = _reduced
nodes = []
for idx, (xc, xs, symbol) in enumerate(
zip(cart_coords, frac_coords, symbols), start=1
):
nodes.append(
Node(
node_id=idx,
node_type=symbol,
cart_coord=xc,
frac_coord=xs,
)
)
return cls(nodes=nodes, edges=edges, lattice=lattice)
@classmethod
[docs]
def from_cgd(cls, filename: Path | str) -> Topology:
"""
Parses and loads a CGD file with an adjacency matrix.
Parameters
----------
filename
The path to the CGD file.
Returns
-------
A Topology object representing the network as nodes and edges.
"""
if not os.path.exists(filename):
raise FileNotFoundError(f"File '{filename}' not found.")
lattice, atom_labels, atoms, edges = parse_cgd(filename)
neighbour_list = process_neighbour_list(edges, atoms, atom_labels)
# Create Node instances
all_nodes = []
if atoms is not None:
all_nodes = [
Node(
node_id=i + 1,
node_type=label,
frac_coord=frac_coord,
)
for i, (label, frac_coord) in enumerate(zip(atom_labels, atoms))
]
else:
all_nodes = [
Node(
node_id=i + 1,
node_type=label,
frac_coord=None,
)
for i, label in enumerate(atom_labels)
]
return cls(nodes=all_nodes, edges=neighbour_list, lattice=lattice)
@classmethod
[docs]
def from_conflink(
cls,
filename: str,
node_type: str = "Si",
index: int | None = None,
) -> Topology:
"""
Parses and loads a conflink file.
Parameters
----------
filename
The path to the conflink file.
node_type
The type to assign to all nodes in the topology.
Returns
-------
A Topology object representing the network as nodes and edges.
"""
configs = parse_conflink(filename)
if not configs:
raise ValueError(f"No configurations found in '{filename}'.")
index = -1 if index is None else index
components = configs[index].to_nodes_edges_lattice(node_type=node_type)
nodes = [
Node(
node_id=atom_id + 1,
node_type=node_type,
cart_coord=xc,
frac_coord=xs,
)
for atom_id, (xc, xs) in enumerate(
zip(components["cart_coords"], components["frac_coords"])
)
]
return cls(
nodes=nodes,
edges=components["edges"],
lattice=components["lattice"],
)
@classmethod
[docs]
def from_lammps_data(
cls,
filename: Path | str,
*,
atom_style: str = "atomic",
units: str = "metal",
sort_by_id: bool = True,
prefer_bonds: bool = True,
cutoff: float = 0.0,
pair_cutoffs: dict[tuple[str, str], float] | None = None,
contract_neighborlist: bool = False,
remove_types: Iterable[Any] | None = None,
remove_degree2: bool = False,
omit_node_types: bool = False,
) -> Topology:
"""Create a Topology from a LAMMPS data file.
If the file contains a `Bonds` section, bonds are used as the edge list,
and periodic image shifts are inferred assuming MIC. Otherwise edges are
inferred by cutoff.
"""
comps = load_lammps_data(
filename,
atom_style=atom_style,
units=units,
sort_by_id=sort_by_id,
prefer_bonds=prefer_bonds,
cutoff=cutoff,
pair_cutoffs=pair_cutoffs,
contract_neighborlist=contract_neighborlist,
remove_types=remove_types,
remove_degree2=remove_degree2,
omit_node_types=omit_node_types,
)
lattice = None
if all(comps.ase_atoms.pbc):
lattice = PymatgenLattice(comps.ase_atoms.cell)
nodes: list[Node] = []
for idx, (xc, sym) in enumerate(
zip(comps.cart_coords, comps.symbols), start=1
):
xs = (
None
if isinstance(comps.frac_coords, list)
else comps.frac_coords[idx - 1]
)
nodes.append(
Node(
node_id=idx,
node_type=sym,
cart_coord=xc,
frac_coord=xs,
)
)
return cls(nodes=nodes, edges=comps.edges, lattice=lattice)
[docs]
def get_rings(self, depth: int = 12) -> list[RingGeometry]:
"""Computes or retrieves unique rings in the network.
Parameters
----------
depth
The maximum depth to search for rings.
Notes
-----
- In the previous implementation, this method returned the clusters of
rings at each node. This is obtained instead via the `get_clusters`
method. This method now returns all unique rings in the network as
RingGeometry objects.
"""
if self.lattice is None:
raise ValueError(
"Currently, ring geometries require a defined lattice. This "
"will be supported in future releases because it is not a "
"fundamental requirement for ring finding."
)
rings = get_all_rings(self.edges, depth)
final_rings = []
for ring in rings:
this_ring = []
for node_id, image in ring:
node = self.nodes[node_id - 1]
node = node.apply_image_shift(self.lattice, image)
this_ring.append(node)
final_rings.append(RingGeometry(tuple(this_ring)))
return final_rings
[docs]
def get_clusters(
self, depth: int = 12, strong: bool = False
) -> RingsResults:
"""Computes or retrieves ring statistics for the network.
Parameters
----------
depth
The maximum depth to search for rings.
strong
Whether to filter the rings to strong rings only.
Returns
-------
A dictionary containing the ring statistics.
"""
label = ("strong_rings" if strong else "rings", depth)
# check if cached results exist and return them.
if label in self.properties:
result = self.properties[label]
assert isinstance(result, RingsResults)
return result
# compute rings.
compute_rings = run_strong_rings if strong else run_rings
rcount, rnodes = compute_rings(self.edges, depth)
# store and return results.
results = RingsResults(
depth=depth,
rings_are_strong=strong,
ring_size_count=RingSizeCounts(*rcount.T),
clusters=get_clusters(self, rnodes),
)
self.properties[label] = results
return results
[docs]
def get_topological_genome(self) -> str:
"""
Returns a the topology code for the framework.
Notes
-----
- The topological genome is a finite series of numbers that is provably
unique for each net.
- It can be comptued in polynomial time with respect to the size of the
net.
"""
if "topological_genome" in self.properties:
return str(self.properties["topological_genome"])
nodes = get_all_node_frac_coords(self.nodes)
assert self.lattice is not None, (
"Lattice must be defined to compute genome."
)
cell_lengths = self.lattice.lengths
cell_angles = self.lattice.angles
topology_genome = get_topological_genome(
nodes,
self.edges,
cell_lengths,
cell_angles,
)
self.properties["topological_genome"] = topology_genome
return topology_genome
[docs]
def get_coordination_sequences(
self, max_shell: int = 10, node_ids: Iterable[int] | int | None = None
) -> npt.NDArray[np.int_]:
"""
Return coordination sequences for specified nodes.
Parameters
----------
max_shell
The maximum shell to compute coordination sequences to.
node_ids
The node IDs for which to compute coordination sequences. If None,
coordination sequences for all nodes are returned.
"""
cs = get_coordination_sequences(self.edges, max_shell)
if node_ids is None:
return cs
if isinstance(node_ids, int):
node_ids = [node_ids]
return cs[np.array(node_ids, dtype=int) - 1]
@property
[docs]
def cartesian_coordinates(self) -> npt.NDArray[np.floating]:
"""Return the Cartesian positions of all nodes in the network."""
return np.array(
[node.cart_coord for node in self.nodes], dtype=np.float64
)
@property
[docs]
def fractional_coordinates(self) -> npt.NDArray[np.floating]:
"""Return the fractional positions of all nodes in the network."""
return np.array(
[node.frac_coord for node in self.nodes], dtype=np.float64
)
def __getitem__(self, node_id: int) -> Node:
"""Retrieve a Node object by its node number."""
if not (1 <= node_id <= len(self.nodes)):
raise IndexError(
f"node_id {node_id} is out of bounds "
f"(valid range: 1 to {len(self.nodes)})"
)
return self.nodes[node_id - 1]
def __repr__(self) -> str:
info = {
"nodes": len(self.nodes),
"edges": len(self.edges),
"has_lattice": self.lattice is not None,
}
return uniform_repr("Topology", **info, indent_size=4)
@dataclass(order=True)
[docs]
class Node:
"""
A representation of a node in a network.
"""
[docs]
node_type: str | None = "Si"
[docs]
frac_coord: npt.NDArray[np.floating] | None = field(default=None)
[docs]
cart_coord: npt.NDArray[np.floating] | None = field(default=None)
[docs]
is_shifted: bool = field(default=False)
[docs]
def apply_image_shift(
self,
lattice: PymatgenLattice,
image_shift: npt.NDArray[np.int_],
) -> Node:
"""Apply the image shift to this node and return a new Node object.
Parameters
----------
lattice
The lattice object for the network.
image_shift
The shift vector to apply to the node coordinates.
Returns
-------
A new Node object with the shifted coordinates.
"""
if self.frac_coord is None and self.cart_coord is None:
raise ValueError(
"Both `frac_coord` and `cart_coord` are missing; "
"cannot compute shifted coordinates."
)
if self.frac_coord is not None:
frac_coord = self.frac_coord
else:
assert self.cart_coord is not None
frac_coord = lattice.get_fractional_coords(self.cart_coord)
shifted_frac = (frac_coord + image_shift).astype(np.float64)
shifted_cart = lattice.get_cartesian_coords(shifted_frac)
return Node(
node_id=self.node_id,
node_type=self.node_type,
frac_coord=shifted_frac,
cart_coord=shifted_cart,
is_shifted=True,
)
def __post_init__(self) -> None:
"""Ensure coordinates are NumPy arrays if provided."""
self.sort_index = self.node_id
if self.frac_coord is not None:
self.frac_coord = np.array(self.frac_coord, dtype=np.float64)
if self.cart_coord is not None:
self.cart_coord = np.array(self.cart_coord, dtype=np.float64)
def __repr__(self) -> str:
name = "Node" if not self.is_shifted else "ShiftedNode"
info = {"node_id": self.node_id, "node_type": self.node_type}
if self.frac_coord is not None:
formatted_coords = np.array2string(
np.round(self.frac_coord, 2),
precision=2,
separator=", ",
floatmode="fixed",
)
info["frac_coord"] = formatted_coords
if self.cart_coord is not None:
formatted_coords = np.array2string(
np.round(self.cart_coord, 2),
precision=2,
separator=", ",
floatmode="fixed",
)
info["cart_coord"] = formatted_coords
return uniform_repr(name, **info, indent_size=4)
############################### HELPERS ###############################
[docs]
def get_all_node_frac_coords(nodes: list[Node]) -> npt.NDArray[np.floating]:
"""Return the fractional coordinates of all nodes in the network."""
return np.array([node.frac_coord for node in nodes], dtype=np.float64)