from __future__ import annotations
from collections import deque
from typing import Any, Iterable
import ase
import numpy as np
import numpy.typing as npt
from ase.data import atomic_numbers
try:
from vesin import ase_neighbor_list
except ImportError:
from ase.neighborlist import neighbor_list as ase_neighbor_list
[docs]
def graph_edges_by_cutoff(
    atoms: ase.Atoms,
    cutoff: float = 0.0,
    pair_cutoffs: dict[tuple[str, str], float] | None = None,
    one_based: bool = True,
) -> npt.NDArray[np.int_]:
    """
    Build periodic edges with a global cutoff or element-specific cutoffs.

    Parameters
    ----------
    atoms
        Input structure (ase.Atoms).
    cutoff
        Global distance cutoff used when `pair_cutoffs` is None.
        If `pair_cutoffs` is provided, this acts as the default cutoff
        for element pairs not present in `pair_cutoffs`.
        Set to 0.0 to effectively forbid unspecified pairs.
    pair_cutoffs
        Per-element-pair cutoffs, e.g. {("Si", "O"): 2.1, ("Al", "O"): 2.0}.
        Pairs are treated as unordered; ("O", "Si") is the same as ("Si", "O").
        If both orientations of a pair are given, the one listed last wins.
        Elements that do not occur in `atoms` are allowed and simply have no
        effect. If None, a single global cutoff is used for all pairs.
    one_based
        If True, returned atom indices are 1-based. If False, they are 0-based.

    Returns
    -------
    edges
        ndarray of shape (n_edges, 5). Columns: [i, j, sx, sy, sz],
        where (sx, sy, sz) are integer cell shifts. The pairs are passed
        through exactly as the neighbour-list backend reports them
        (presumably every bond in both directions -- verify against the
        backend's documentation).

    Raises
    ------
    KeyError
        If a symbol in `pair_cutoffs` is not a known element symbol.
    """
    # --- Case 1: global cutoff only ---
    if pair_cutoffs is None:
        i_sel, j_sel, S_sel = ase_neighbor_list("ijS", atoms, cutoff=cutoff)

    # --- Case 2: element-specific cutoffs ---
    else:
        Z = np.asarray(atoms.get_atomic_numbers())

        # Translate symbols to atomic numbers once, up front.
        pair_z = [
            (atomic_numbers[sym1], atomic_numbers[sym2], c)
            for (sym1, sym2), c in pair_cutoffs.items()
        ]

        # The neighbour search must reach as far as the largest cutoff in use.
        search_cutoff = max([cutoff, *(c for _, _, c in pair_z)])

        # Size the lookup table so that every atomic number is a valid index,
        # whether it comes from the structure or only from `pair_cutoffs`.
        z_top = max(
            [int(Z.max()) if Z.size else 0]
            + [max(z1, z2) for z1, z2, _ in pair_z]
        )

        # Unlisted pairs fall back to `cutoff` (NOT to the search radius),
        # so that cutoff=0.0 really excludes them.
        cutoff_matrix = np.full((z_top + 1, z_top + 1), cutoff, dtype=float)
        for z1, z2, c in pair_z:
            cutoff_matrix[z1, z2] = c
            cutoff_matrix[z2, z1] = c

        # Build the neighbour list once with the widest radius ...
        i, j, d, S = ase_neighbor_list("ijdS", atoms, cutoff=search_cutoff)

        # ... then keep only the pairs within their own element-pair cutoff.
        mask = d <= cutoff_matrix[Z[i], Z[j]]
        i_sel = i[mask]
        j_sel = j[mask]
        S_sel = S[mask]

    if one_based:
        i_sel = i_sel + 1
        j_sel = j_sel + 1

    edges = np.column_stack([i_sel, j_sel, S_sel]).astype(np.int_)
    return edges
[docs]
def autoreduce_neighborlist(
    cart_coords: npt.NDArray[np.float64] | list[None],
    frac_coords: npt.NDArray[np.float64] | list[None],
    symbols: list[str],
    edges: npt.NDArray[np.int_],
    remove_types: Iterable[Any] | None = None,
    remove_degree2: bool = False,
) -> tuple[
    npt.NDArray[np.float64] | list[None],
    npt.NDArray[np.float64] | list[None],
    list[str],
    npt.NDArray[np.int_],
    npt.NDArray[np.int_],
]:
    """
    Simplify a periodic bonded graph by contracting out selected atoms.

    A bond is identified by the triple (atom, neighbour, cell shift). Listing
    a bond more than once, or in both directions, does not change the result:
    the degree of an atom is the number of its *distinct* bonds.

    Parameters
    ----------
    cart_coords
        Per-atom cartesian coordinates; only indexed by atom, never used
        in the contraction itself.
    frac_coords
        Fractional coordinates of all atoms; handled like `cart_coords`.
    symbols
        Atomic symbols (length N).
    edges
        Columns: i+1, j+1, Sx, Sy, Sz (1-based atom indices). An empty edge
        list is accepted.
    remove_types
        If not None, atoms whose symbol is in this set are removed and their
        neighbors are connected together (clique over neighbors).
        Example: {"O"}. Atoms of these types that have no bonds at all are
        left in place.
    remove_degree2
        If True, atoms that are 2-connected are also removed (in addition to
        any atoms in `remove_types`). Atoms that become 2-connected during
        the reduction are removed as well.

    Returns
    -------
    new_cart_coords : `cart_coords` restricted to the N_keep surviving atoms
    new_frac_coords : `frac_coords` restricted to the N_keep surviving atoms
    new_symbols : list[str] length N_keep
    new_edges : (M_new,5) int ndarray
        Same format as input `edges` (1-based indices). Every bond appears
        exactly once, lower atom index first; for a bond between an atom and
        its own periodic image the shift with the lexicographically larger
        sign is reported.
    old_to_new : (N,) ndarray of int
        Mapping from old atom index (0-based) to new index (0-based). -1 for
        removed atoms.

    Raises
    ------
    IndexError
        If an atom index in `edges` is outside 1..N.
    """
    cart_coords = np.asarray(cart_coords)
    frac_coords = np.asarray(frac_coords)
    symbols = list(symbols)
    n_atoms = len(symbols)

    edges = np.asarray(edges, dtype=int)
    if edges.size == 0:
        # e.g. edges=[] arrives as shape (0,); give it the documented width
        edges = edges.reshape(0, 5)
    if edges.shape[0] and (
        edges[:, :2].min() < 1 or edges[:, :2].max() > n_atoms
    ):
        # A 0 here would otherwise silently wrap around to the last atom.
        raise IndexError(
            "edges must contain 1-based atom indices in the range "
            f"1..{n_atoms}"
        )

    # adjacency[u] is an insertion-ordered set (dict keys) of (v, sx, sy, sz),
    # where (sx, sy, sz) is the cell shift u -> v. Keeping the entries unique
    # makes len(adjacency[u]) the true number of bonds of u.
    adjacency: list[dict[tuple[int, int, int, int], None]] = [
        {} for _ in range(n_atoms)
    ]

    def link(u: int, v: int, shift: tuple[int, int, int]) -> None:
        """Record the undirected bond u<->v (shift is u -> v); idempotent."""
        sx, sy, sz = shift
        adjacency[u][(v, sx, sy, sz)] = None
        adjacency[v][(u, -sx, -sy, -sz)] = None

    for i1, j1, sx, sy, sz in edges[:, :5].tolist():
        link(i1 - 1, j1 - 1, (sx, sy, sz))

    remove_types_set = set(remove_types) if remove_types is not None else set()
    removable_by_type = [sym in remove_types_set for sym in symbols]

    removed = [False] * n_atoms
    scheduled = [False] * n_atoms

    # Initial queue: bonded atoms selected by type ...
    queue = deque()
    for idx in range(n_atoms):
        if removable_by_type[idx] and adjacency[idx]:
            queue.append(idx)
            scheduled[idx] = True

    # ... followed by the atoms that start out 2-connected, if requested.
    if remove_degree2:
        for idx in range(n_atoms):
            if not scheduled[idx] and len(adjacency[idx]) == 2:
                queue.append(idx)
                scheduled[idx] = True

    # Main contraction loop
    while queue:
        r = queue.popleft()
        if removed[r]:
            continue

        # Snapshot of r's live neighbours. A bond from r to its own image
        # cannot be contracted through and is simply dropped together with r.
        neighbors = [
            entry
            for entry in adjacency[r]
            if entry[0] != r and not removed[entry[0]]
        ]

        # Connect the neighbours pairwise through r.
        for ia, (a, *shift_ra) in enumerate(neighbors):
            for b, *shift_rb in neighbors[ia + 1:]:
                # NOTE(review): two different periodic images of the same
                # atom are deliberately not joined (no self-image bonds are
                # created here) -- confirm this is the wanted topology.
                if a == b:
                    continue
                # a -> b via r:  S_ab = -S_ra + S_rb
                shift_ab = tuple(
                    s_rb - s_ra for s_ra, s_rb in zip(shift_ra, shift_rb)
                )
                link(a, b, shift_ab)

        # Detach r from each distinct neighbour (first-seen order).
        for n in dict.fromkeys(entry[0] for entry in neighbors):
            adjacency[n] = {
                entry: None for entry in adjacency[n] if entry[0] != r
            }
            # Atoms that just became 2-connected join the queue, unless they
            # are handled by type or have been queued before.
            if (
                remove_degree2
                and not removable_by_type[n]
                and not scheduled[n]
                and len(adjacency[n]) == 2
            ):
                queue.append(n)
                scheduled[n] = True

        # Finally mark r as removed
        adjacency[r] = {}
        removed[r] = True

    # Build mapping old index -> new index for surviving atoms
    keep_indices = [idx for idx in range(n_atoms) if not removed[idx]]
    keep = np.asarray(keep_indices, dtype=int)
    old_to_new = np.full(n_atoms, -1, dtype=int)
    old_to_new[keep] = np.arange(keep.size)

    # Rebuild the edge list: exactly one canonical row per undirected bond.
    edge_keys = set()
    for u in keep_indices:
        for v, sx, sy, sz in adjacency[u]:
            if removed[v]:
                continue
            shift = (sx, sy, sz)
            flipped = (-sx, -sy, -sz)
            if u < v:
                edge_keys.add((u, v, *shift))
            elif u > v:
                edge_keys.add((v, u, *flipped))
            else:
                # Self-image bond: S and -S are the same bond, keep one sign.
                edge_keys.add((u, u, *max(shift, flipped)))

    if edge_keys:
        edge_array = np.array(sorted(edge_keys), dtype=int)
        new_edges = np.column_stack(
            (
                old_to_new[edge_array[:, 0]] + 1,
                old_to_new[edge_array[:, 1]] + 1,
                edge_array[:, 2:5],
            )
        ).astype(int)
    else:
        new_edges = np.zeros((0, 5), dtype=int)

    new_cart_coords = cart_coords[keep]
    new_frac_coords = frac_coords[keep]
    new_symbols = [symbols[idx] for idx in keep_indices]
    return new_cart_coords, new_frac_coords, new_symbols, new_edges, old_to_new