# Source code for topo_metrics.knots

from __future__ import annotations

import math

import numpy as np
import numpy.typing as npt

# numba is optional: when it is missing, ``nb`` is None and every caller
# falls back to the pure-Python/NumPy implementations below.
try:
    import numba as nb  # pyright: ignore[reportMissingImports]
except ImportError:  # pragma: no cover
    nb = None

# Type alias for floating-point NumPy arrays used throughout this module.
Array = npt.NDArray[np.floating]
def _drop_duplicate_endpoint(P: Array) -> Array:
    """Drop the last point if it matches the first (``np.allclose`` tolerance)."""
    if len(P) >= 2 and np.allclose(P[0], P[-1]):
        return P[:-1]
    return P


def _as_points(points: npt.ArrayLike) -> Array:
    """Convert input to a float (N, 3) array of points, dropping a duplicate endpoint.

    Raises:
        ValueError: if the input is not two-dimensional with exactly 3 columns.
    """
    P = np.asarray(points, dtype=float)
    if P.ndim != 2 or P.shape[1] != 3:
        raise ValueError("points must have shape (N,3)")
    return _drop_duplicate_endpoint(P)


def _segments(points: Array, closed: bool) -> tuple[Array, Array]:
    """Return (A, B) where segment i runs from A[i] to B[i].

    A closed chain wraps the last point back to the first (N segments); an
    open chain yields N - 1 segments. Fewer than two points gives empty arrays.
    """
    if len(points) < 2:
        return points[:0], points[:0]
    if closed:
        A = points
        B = np.roll(points, -1, axis=0)
    else:
        A = points[:-1]
        B = points[1:]
    return A, B


def _are_adjacent(i: int, j: int, m: int, closed: bool) -> bool:
    """Return True if segments i and j (out of m) are identical or share a vertex.

    In a closed chain, segment 0 and segment m - 1 are also neighbours.
    """
    if i == j:
        return True
    if closed:
        return (j == (i + 1) % m) or (i == (j + 1) % m)
    return abs(i - j) == 1


def _unit(v: Array, eps: float) -> Array | None:
    """Return v / |v|, or None when |v| < eps."""
    n = float(np.linalg.norm(v))
    if n < eps:
        return None
    return v / n


def _cross2(a: Array, b: Array) -> float:
    """Scalar 2D cross product a[0]*b[1] - a[1]*b[0]."""
    return float(a[0] * b[1] - a[1] * b[0])


def _as_contig_f64(P: Array) -> np.ndarray:
    """Return P as a C-contiguous float64 array (helps numba)."""
    return np.ascontiguousarray(np.asarray(P, dtype=np.float64))


def _directional_writhe_Wrz(points: Array, *, eps: float = 1e-12) -> float:
    """Directional writhe for the projection onto the xy plane (Eq 32).

    The chain is treated as closed. Every proper crossing between two
    non-adjacent segments in the xy projection contributes +1 or -1, according
    to the sign of the triple product (S_j x S_i) . (A_j - A_i). Projected
    segments that are (near-)parallel (|cross| < eps) are skipped. So are
    intersections whose parameter lies within eps of a segment end.
    """
    A, B = _segments(points, closed=True)
    S = B - A
    m = len(S)
    Wrz = 0.0
    for i in range(m):
        for j in range(i + 1, m):
            if _are_adjacent(i, j, m, closed=True):
                continue
            # Solve p + ti*r == q + tj*s for the two projected segments.
            p = A[i][:2]
            r = S[i][:2]
            q = A[j][:2]
            s = S[j][:2]
            denom = _cross2(r, s)
            if abs(denom) < eps:
                continue
            qp = q - p
            ti = _cross2(qp, s) / denom
            tj = _cross2(qp, r) / denom
            if not (eps < ti < 1.0 - eps and eps < tj < 1.0 - eps):
                continue
            # The in-plane parts of A_j - A_i lie along S_i / S_j and cancel in
            # the triple product, leaving only the over/under separation.
            triple = float(np.dot(np.cross(S[j], S[i]), (A[j] - A[i])))
            if triple > 0:
                Wrz += 1.0
            elif triple < 0:
                Wrz -= 1.0
    return Wrz


def _clean_writhe(wr: float, acn: float, atol: float = 1e-12) -> float:
    """Snap a round-off-sized writhe to exactly 0.0.

    The threshold is max(atol, 1e-12 * max(1, acn)), so it grows with the
    average crossing number.
    """
    tol = max(atol, 1e-12 * max(1.0, acn))
    return 0.0 if abs(wr) < tol else wr


def _gauss_pair_method_1a(
    p1: Array, p2: Array, p3: Array, p4: Array, eps: float = 1e-12
) -> tuple[float, float]:
    """Solid-angle contribution of segment p1->p2 against segment p3->p4.

    Method 1a (spherical quadrilateral), Klenin & Langowski (2000),
    Eqs (15)-(16).

    Returns:
        (V* * sign, |V*|), where sign is the sign of (r34 x r12) . r13.
        Returns (0.0, 0.0) if any of the four face normals is degenerate
        (cross-product norm below eps).
    """

    def ucr(a, b):
        return _unit(np.cross(a, b), eps)

    r12 = p2 - p1
    r34 = p4 - p3
    r13 = p3 - p1
    r14 = p4 - p1
    r23 = p3 - p2
    r24 = p4 - p2
    n1 = ucr(r13, r14)
    n2 = ucr(r14, r24)
    n3 = ucr(r24, r23)
    n4 = ucr(r23, r13)
    if n1 is None or n2 is None or n3 is None or n4 is None:
        return 0.0, 0.0
    # Clip guards asin against dot products drifting just outside [-1, 1].
    t1 = float(np.clip(np.dot(n1, n2), -1.0, 1.0))
    t2 = float(np.clip(np.dot(n2, n3), -1.0, 1.0))
    t3 = float(np.clip(np.dot(n3, n4), -1.0, 1.0))
    t4 = float(np.clip(np.dot(n4, n1), -1.0, 1.0))
    V_star = math.asin(t1) + math.asin(t2) + math.asin(t3) + math.asin(t4)
    triple = float(np.dot(np.cross(r34, r12), r13))
    sgn = 1.0 if triple > 0 else (-1.0 if triple < 0 else 0.0)
    return V_star * sgn, abs(V_star)


def _gauss_pair_method_1b(
    p1: Array, p2: Array, p3: Array, p4: Array, eps: float = 1e-12
) -> tuple[float, float]:
    """Analytic Gauss-integral contribution of segment p1->p2 against p3->p4.

    Method 1b, Klenin & Langowski (2000), Eqs (17)-(25).

    Returns:
        (V, |V|), where V = 4*pi times the corner combination of F. Returns
        (0.0, 0.0) for a segment shorter than eps, for near-parallel segments
        (sin^2(beta) < eps), and for near-coplanar pairs (|a0| < eps).
    """
    s1 = p2 - p1
    s2 = p4 - p3
    l1 = float(np.linalg.norm(s1))
    l2 = float(np.linalg.norm(s2))
    if l1 < eps or l2 < eps:
        return 0.0, 0.0
    e1 = s1 / l1
    e2 = s2 / l2
    cosb = float(np.dot(e1, e2))
    sin2b = 1.0 - cosb * cosb
    if sin2b < eps:
        return 0.0, 0.0
    r12 = p3 - p1
    a1 = float(np.dot(r12, (e2 * cosb - e1)) / sin2b)
    a2 = float(np.dot(r12, (e2 - e1 * cosb)) / sin2b)
    a0 = float(np.dot(r12, np.cross(e1, e2)) / sin2b)
    if abs(a0) < eps:
        return 0.0, 0.0

    def F(t1: float, t2: float) -> float:
        # Antiderivative of the Gauss integrand in segment coordinates.
        rad = t1 * t1 + t2 * t2 - 2.0 * t1 * t2 * cosb + a0 * a0 * sin2b
        if rad <= 0.0:  # pragma: no cover
            return 0.0
        denom = a0 * math.sqrt(rad)
        num = t1 * t2 + a0 * a0 * cosb
        return -(1.0 / (4.0 * math.pi)) * math.atan(num / denom)

    V_over_4pi = (
        F(a1 + l1, a2 + l2) - F(a1 + l1, a2) - F(a1, a2 + l2) + F(a1, a2)
    )
    V = 4.0 * math.pi * V_over_4pi
    return V, abs(V)


if nb is not None:  # pragma: no cover

    @nb.njit(cache=True)
    def _dot3(ax, ay, az, bx, by, bz):
        """Dot product of two 3-vectors given by components."""
        return ax * bx + ay * by + az * bz

    @nb.njit(cache=True)
    def _cross3(ax, ay, az, bx, by, bz):
        """Cross product of two 3-vectors given by components."""
        return (ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx)

    @nb.njit(cache=True)
    def _norm3(x, y, z):
        """Euclidean norm of a 3-vector."""
        return math.sqrt(x * x + y * y + z * z)

    @nb.njit(cache=True)
    def _clamp(x, lo, hi):
        """Clamp x into [lo, hi]."""
        if x < lo:
            return lo
        if x > hi:
            return hi
        return x

    @nb.njit(cache=True)
    def _unit_cross(ax, ay, az, bx, by, bz, eps):
        """Normalised a x b plus an ok flag (False when |a x b| < eps)."""
        cx, cy, cz = _cross3(ax, ay, az, bx, by, bz)
        n = _norm3(cx, cy, cz)
        if n < eps:
            return 0.0, 0.0, 0.0, False
        inv = 1.0 / n
        return cx * inv, cy * inv, cz * inv, True

    @nb.njit(cache=True)
    def _writhe_sum_1a_numba(A, B, closed_flag, eps):
        """
        Returns (sumV, sumVabs) for method 1a on segments A[i]->B[i].

        Mimics _gauss_pair_method_1a and the _are_adjacent skipping rules
        (closed_flag == 1 means closed chain) over all pairs i < j.
        """
        m = A.shape[0]
        sumV = 0.0
        sumVabs = 0.0
        for i in range(m):
            p1x, p1y, p1z = A[i, 0], A[i, 1], A[i, 2]
            p2x, p2y, p2z = B[i, 0], B[i, 1], B[i, 2]
            r12x, r12y, r12z = p2x - p1x, p2y - p1y, p2z - p1z
            for j in range(i + 1, m):
                # Skip neighbouring segments (they share a vertex).
                if closed_flag == 1:
                    if j == i + 1:
                        continue
                    if i == 0 and j == m - 1:
                        continue
                else:
                    if j == i + 1:
                        continue
                p3x, p3y, p3z = A[j, 0], A[j, 1], A[j, 2]
                p4x, p4y, p4z = B[j, 0], B[j, 1], B[j, 2]
                r34x, r34y, r34z = p4x - p3x, p4y - p3y, p4z - p3z
                r13x, r13y, r13z = p3x - p1x, p3y - p1y, p3z - p1z
                r14x, r14y, r14z = p4x - p1x, p4y - p1y, p4z - p1z
                r23x, r23y, r23z = p3x - p2x, p3y - p2y, p3z - p2z
                r24x, r24y, r24z = p4x - p2x, p4y - p2y, p4z - p2z
                n1x, n1y, n1z, ok1 = _unit_cross(
                    r13x, r13y, r13z, r14x, r14y, r14z, eps
                )
                if not ok1:
                    continue
                n2x, n2y, n2z, ok2 = _unit_cross(
                    r14x, r14y, r14z, r24x, r24y, r24z, eps
                )
                if not ok2:
                    continue
                n3x, n3y, n3z, ok3 = _unit_cross(
                    r24x, r24y, r24z, r23x, r23y, r23z, eps
                )
                if not ok3:
                    continue
                n4x, n4y, n4z, ok4 = _unit_cross(
                    r23x, r23y, r23z, r13x, r13y, r13z, eps
                )
                if not ok4:
                    continue
                t1 = _clamp(_dot3(n1x, n1y, n1z, n2x, n2y, n2z), -1.0, 1.0)
                t2 = _clamp(_dot3(n2x, n2y, n2z, n3x, n3y, n3z), -1.0, 1.0)
                t3 = _clamp(_dot3(n3x, n3y, n3z, n4x, n4y, n4z), -1.0, 1.0)
                t4 = _clamp(_dot3(n4x, n4y, n4z, n1x, n1y, n1z), -1.0, 1.0)
                V_star = (
                    math.asin(t1) + math.asin(t2) + math.asin(t3) + math.asin(t4)
                )
                V_abs = -V_star if V_star < 0.0 else V_star
                sumVabs += V_abs
                cx, cy, cz = _cross3(r34x, r34y, r34z, r12x, r12y, r12z)
                triple = _dot3(cx, cy, cz, r13x, r13y, r13z)
                if triple > 0.0:
                    sumV += V_star
                elif triple < 0.0:
                    sumV -= V_star
        return sumV, sumVabs

    @nb.njit(cache=True)
    def _link_sum_1a_numba(A1, B1, A2, B2, sx, sy, sz, eps):
        """
        Returns sumV (signed solid-angle sum) for linking between two rings,
        with ring2 shifted by (sx,sy,sz) applied as: p -> p - shift.
        Every segment of ring1 is paired with every segment of ring2.
        """
        sumV = 0.0
        for i in range(A1.shape[0]):
            p1x, p1y, p1z = A1[i, 0], A1[i, 1], A1[i, 2]
            p2x, p2y, p2z = B1[i, 0], B1[i, 1], B1[i, 2]
            r12x, r12y, r12z = p2x - p1x, p2y - p1y, p2z - p1z
            for j in range(A2.shape[0]):
                p3x, p3y, p3z = A2[j, 0] - sx, A2[j, 1] - sy, A2[j, 2] - sz
                p4x, p4y, p4z = B2[j, 0] - sx, B2[j, 1] - sy, B2[j, 2] - sz
                r34x, r34y, r34z = p4x - p3x, p4y - p3y, p4z - p3z
                r13x, r13y, r13z = p3x - p1x, p3y - p1y, p3z - p1z
                r14x, r14y, r14z = p4x - p1x, p4y - p1y, p4z - p1z
                r23x, r23y, r23z = p3x - p2x, p3y - p2y, p3z - p2z
                r24x, r24y, r24z = p4x - p2x, p4y - p2y, p4z - p2z
                n1x, n1y, n1z, ok1 = _unit_cross(
                    r13x, r13y, r13z, r14x, r14y, r14z, eps
                )
                if not ok1:
                    continue
                n2x, n2y, n2z, ok2 = _unit_cross(
                    r14x, r14y, r14z, r24x, r24y, r24z, eps
                )
                if not ok2:
                    continue
                n3x, n3y, n3z, ok3 = _unit_cross(
                    r24x, r24y, r24z, r23x, r23y, r23z, eps
                )
                if not ok3:
                    continue
                n4x, n4y, n4z, ok4 = _unit_cross(
                    r23x, r23y, r23z, r13x, r13y, r13z, eps
                )
                if not ok4:
                    continue
                t1 = _clamp(_dot3(n1x, n1y, n1z, n2x, n2y, n2z), -1.0, 1.0)
                t2 = _clamp(_dot3(n2x, n2y, n2z, n3x, n3y, n3z), -1.0, 1.0)
                t3 = _clamp(_dot3(n3x, n3y, n3z, n4x, n4y, n4z), -1.0, 1.0)
                t4 = _clamp(_dot3(n4x, n4y, n4z, n1x, n1y, n1z), -1.0, 1.0)
                V_star = (
                    math.asin(t1) + math.asin(t2) + math.asin(t3) + math.asin(t4)
                )
                cx, cy, cz = _cross3(r34x, r34y, r34z, r12x, r12y, r12z)
                triple = _dot3(cx, cy, cz, r13x, r13y, r13z)
                if triple > 0.0:
                    sumV += V_star
                elif triple < 0.0:
                    sumV -= V_star
        return sumV

    # --- segment-segment distance (numba) for disjointness & ranking ---
    @nb.njit(cache=True)
    def _segseg_dist2_numba(
        p1x, p1y, p1z, p2x, p2y, p2z, q1x, q1y, q1z, q2x, q2y, q2z
    ):
        """
        Ericson-style segment-segment distance squared, numerically stable.
        Same algorithm as the pure-Python _segseg_dist2 (with eps = 1e-15).
        """
        ux, uy, uz = p2x - p1x, p2y - p1y, p2z - p1z
        vx, vy, vz = q2x - q1x, q2y - q1y, q2z - q1z
        wx, wy, wz = p1x - q1x, p1y - q1y, p1z - q1z
        a = ux * ux + uy * uy + uz * uz
        b = ux * vx + uy * vy + uz * vz
        c = vx * vx + vy * vy + vz * vz
        d = ux * wx + uy * wy + uz * wz
        e = vx * wx + vy * wy + vz * wz
        D = a * c - b * b
        sN, sD = 0.0, D
        tN, tD = 0.0, D
        eps = 1e-15
        if D < eps:
            # Nearly parallel: fix s = 0 and solve for t.
            sN, sD = 0.0, 1.0
            tN, tD = e, c
        else:
            sN = b * e - c * d
            tN = a * e - b * d
            if sN < 0.0:
                sN = 0.0
                tN = e
                tD = c
            elif sN > sD:
                sN = sD
                tN = e + b
                tD = c
        # Clamp t to [0, 1] and recompute s on the corresponding edge.
        if tN < 0.0:
            tN = 0.0
            if -d < 0.0:
                sN = 0.0
            elif -d > a:
                sN = sD
            else:
                sN = -d
                sD = a
        elif tN > tD:
            tN = tD
            if (-d + b) < 0.0:
                sN = 0.0
            elif (-d + b) > a:
                sN = sD
            else:
                sN = -d + b
                sD = a
        sc = 0.0 if abs(sN) < eps else (sN / sD)
        tc = 0.0 if abs(tN) < eps else (tN / tD)
        dx = wx + sc * ux - tc * vx
        dy = wy + sc * uy - tc * vy
        dz = wz + sc * uz - tc * vz
        return dx * dx + dy * dy + dz * dz

    @nb.njit(cache=True)
    def _min_seg_dist2_shift_numba(A1, B1, A2, B2, sx, sy, sz, early_exit2):
        """
        Min squared distance between segments A1->B1 and shifted A2->B2.
        Shift applied as: q -> q - (sx,sy,sz).

        If early_exit2 >= 0 and best < early_exit2, returns early; the value
        returned is then only an upper bound on the true minimum, but is still
        strictly below early_exit2. The strict comparison matches callers'
        ``dist2 < tol**2`` rejection test: a pair at exactly the threshold must
        not hide a closer pair later in the scan.
        """
        best = 1.0e300
        for i in range(A1.shape[0]):
            p1x, p1y, p1z = A1[i, 0], A1[i, 1], A1[i, 2]
            p2x, p2y, p2z = B1[i, 0], B1[i, 1], B1[i, 2]
            for j in range(A2.shape[0]):
                q1x, q1y, q1z = A2[j, 0] - sx, A2[j, 1] - sy, A2[j, 2] - sz
                q2x, q2y, q2z = B2[j, 0] - sx, B2[j, 1] - sy, B2[j, 2] - sz
                d2 = _segseg_dist2_numba(
                    p1x, p1y, p1z, p2x, p2y, p2z, q1x, q1y, q1z, q2x, q2y, q2z
                )
                if d2 < best:
                    best = d2
                    if early_exit2 >= 0.0 and best < early_exit2:
                        return best
        return best
def writhe_method_1a(
    points: npt.ArrayLike, *, closed: bool = True, eps: float = 1e-12
) -> tuple[float, float]:
    """Writhe and average crossing number from pairwise solid angles (method 1a).

    Implements Eqs 13 and 15-16: every non-adjacent segment pair contributes a
    signed solid angle, and the totals are scaled by 1 / (2*pi). The numba
    kernel is used when numba is importable, otherwise a Python double loop.

    Args:
        points: (N, 3) vertices; a final vertex equal to the first is dropped.
        closed: Treat the chain as a closed ring (adds the wrap-around segment).
        eps: Degeneracy threshold for the pair kernel and for zero-snapping.

    Returns:
        (writhe, acn); (0.0, 0.0) when there are fewer than two segments.
    """
    starts, ends = _segments(_as_points(points), closed)
    n_seg = len(starts)
    if n_seg < 2:
        return 0.0, 0.0

    if nb is None:
        signed_total = 0.0
        abs_total = 0.0
        for i in range(n_seg):
            for j in range(i + 1, n_seg):
                if _are_adjacent(i, j, n_seg, closed):
                    continue
                v, v_abs = _gauss_pair_method_1a(
                    starts[i], ends[i], starts[j], ends[j], eps=eps
                )
                signed_total += v
                abs_total += v_abs
    else:
        signed_total, abs_total = _writhe_sum_1a_numba(
            _as_contig_f64(starts), _as_contig_f64(ends), 1 if closed else 0, eps
        )

    scale = 1.0 / (2.0 * math.pi)
    acn = abs_total * scale
    writhe = _clean_writhe(signed_total * scale, acn, atol=eps)
    return float(writhe), float(acn)
def writhe_method_1b(
    points: npt.ArrayLike, *, closed: bool = True, eps: float = 1e-12
) -> tuple[float, float]:
    """Writhe and average crossing number from the analytic Gauss integral (method 1b).

    Implements Eqs 13 and 24-25: each non-adjacent segment pair is integrated
    in closed form, and the totals are scaled by 1 / (2*pi).

    Args:
        points: (N, 3) vertices; a final vertex equal to the first is dropped.
        closed: Treat the chain as a closed ring (adds the wrap-around segment).
        eps: Degeneracy threshold for the pair kernel and for zero-snapping.

    Returns:
        (writhe, acn); (0.0, 0.0) when there are fewer than two segments.
    """
    starts, ends = _segments(_as_points(points), closed)
    count = len(starts)
    if count < 2:
        return 0.0, 0.0

    signed_total = 0.0
    abs_total = 0.0
    for i in range(count):
        for j in range(i + 1, count):
            if _are_adjacent(i, j, count, closed):
                continue
            v, v_abs = _gauss_pair_method_1b(
                starts[i], ends[i], starts[j], ends[j], eps=eps
            )
            signed_total += v
            abs_total += v_abs

    scale = 1.0 / (2.0 * math.pi)
    acn = abs_total * scale
    writhe = _clean_writhe(signed_total * scale, acn, atol=eps)
    return float(writhe), float(acn)
def writhe_method_2a(points: npt.ArrayLike, *, eps: float = 1e-12) -> float:
    """Writhe via the twist decomposition Wr = Twz + Wrz - Tw (method 2a).

    Follows Eqs 30-34, with the z axis as projection direction. The chain is
    always treated as closed.

    Returns:
        The writhe. Returns 0.0 for fewer than four points. Returns nan when
        any vertex binormal S[k-1] x S[k] has norm below eps (consecutive
        segments (anti)parallel, or a zero-length segment).
    """
    verts = _as_points(points)
    n = len(verts)
    if n < 4:
        return 0.0
    starts, ends = _segments(verts, closed=True)
    seg = ends - starts

    # Unit binormal at vertex k; seg[-1] wraps to the last segment.
    binormals = np.zeros((n, 3), dtype=float)
    for k in range(n):
        b = _unit(np.cross(seg[k - 1], seg[k]), eps)
        if b is None:
            return float("nan")
        binormals[k] = b

    twist_total = 0.0
    twz = 0.0
    for k in range(n):
        nxt = (k + 1) % n
        here = binormals[k]
        there = binormals[nxt]
        # Handedness of the rotation from one binormal to the next.
        handed = float(np.sign(np.dot(here, seg[nxt])))
        twist_total += math.acos(float(np.clip(np.dot(here, there), -1.0, 1.0))) * handed
        # Binormals whose z-components change sign contribute to Twz.
        if here[2] * there[2] < 0.0:
            twz += handed
    twz *= 0.5

    tw = twist_total / (2.0 * math.pi)
    wrz = _directional_writhe_Wrz(verts, eps=eps)
    return float(twz + wrz - tw)
def writhe_method_2b(points: npt.ArrayLike, *, eps: float = 1e-12) -> float:
    """Writhe via Wr = Wrz - Tw with z-based reference vectors (method 2b, le Bret-style).

    Uses a_k = (z x s_k) / |z x s_k| (Eqs 35-38). The chain is always treated
    as closed.

    Returns:
        The writhe. Returns 0.0 for fewer than four points. Returns nan when a
        segment is (near-)parallel to z (a_k undefined), or when a vertex
        binormal S[k-1] x S[k] has norm below eps.
    """
    verts = _as_points(points)
    n = len(verts)
    if n < 4:
        return 0.0
    starts, ends = _segments(verts, closed=True)
    seg = ends - starts
    z_axis = np.array([0.0, 0.0, 1.0], dtype=float)

    ref = np.zeros((n, 3), dtype=float)
    for k in range(n):
        u = _unit(np.cross(z_axis, seg[k]), eps)
        if u is None:
            return float("nan")
        ref[k] = u

    binormals = np.zeros((n, 3), dtype=float)
    for k in range(n):
        u = _unit(np.cross(seg[k - 1], seg[k]), eps)
        if u is None:
            return float("nan")
        binormals[k] = u

    twist_total = 0.0
    for k in range(n):
        b = binormals[k]
        # ref[k - 1] wraps to the last segment's reference vector for k == 0.
        before = math.acos(float(np.clip(np.dot(ref[k - 1], b), -1.0, 1.0)))
        after = math.acos(float(np.clip(np.dot(b, ref[k]), -1.0, 1.0)))
        twist_total += (before - after) * float(np.sign(b[2]))
    tw = twist_total / (2.0 * math.pi)

    wrz = _directional_writhe_Wrz(verts, eps=eps)
    return float(wrz - tw)
def _median_segment_length(points: Array, closed: bool = True) -> float:
    """Median segment length of ``points`` (duplicate endpoint dropped); 0.0 if no segments."""
    P = _as_points(points)
    A, B = _segments(P, closed=closed)
    if len(A) == 0:
        return 0.0
    seg = np.linalg.norm(B - A, axis=1)
    return float(np.median(seg))


def _auto_disjoint_tol(
    ring1: Array,
    ring2: Array,
    *,
    rel: float = 1e-3,
    abs_: float = 1e-8,
) -> float:
    """
    Heuristic: disjoint_tol ~ rel * (typical segment length), floored by abs_.

    The typical length is the smaller of the two rings' median segment lengths
    (or whichever is positive). If neither is positive, ``abs_`` is returned.

    rel=1e-3: "numerical disjointness" default.
    rel=1e-2: more conservative if you see near-singular Gauss contributions.
    """
    m1 = _median_segment_length(ring1, closed=True)
    m2 = _median_segment_length(ring2, closed=True)
    m = min(m1, m2) if (m1 > 0 and m2 > 0) else max(m1, m2)
    if m <= 0:
        return float(abs_)
    return float(max(abs_, rel * m))


def _segseg_dist2(
    p1: Array, p2: Array, q1: Array, q2: Array, eps: float = 1e-15
) -> float:
    """
    Squared minimum distance between 3D segments p1->p2 and q1->q2.

    Robust clamp-based implementation (Ericson-style). Handles parallel and
    zero-length segments.
    """
    u = p2 - p1
    v = q2 - q1
    w = p1 - q1
    a = float(np.dot(u, u))
    b = float(np.dot(u, v))
    c = float(np.dot(v, v))
    d = float(np.dot(u, w))
    e = float(np.dot(v, w))
    D = a * c - b * b
    sN, sD = 0.0, D
    tN, tD = 0.0, D
    if D < eps:
        # Nearly parallel: fix s = 0 and solve for t.
        sN, sD = 0.0, 1.0
        tN, tD = e, c
    else:
        sN = b * e - c * d
        tN = a * e - b * d
        if sN < 0.0:
            sN = 0.0
            tN = e
            tD = c
        elif sN > sD:
            sN = sD
            tN = e + b
            tD = c
    # Clamp t to [0, 1] and recompute s on the corresponding edge.
    if tN < 0.0:
        tN = 0.0
        if -d < 0.0:
            sN = 0.0
        elif -d > a:
            sN = sD
        else:
            sN = -d
            sD = a
    elif tN > tD:
        tN = tD
        if (-d + b) < 0.0:
            sN = 0.0
        elif (-d + b) > a:
            sN = sD
        else:
            sN = -d + b
            sD = a
    sc = 0.0 if abs(sN) < eps else (sN / sD)
    tc = 0.0 if abs(tN) < eps else (tN / tD)
    dP = w + sc * u - tc * v
    return float(np.dot(dP, dP))


def _min_seg_dist2(
    A1: Array,
    B1: Array,
    A2: Array,
    B2: Array,
    *,
    early_exit2: float | None = None,
) -> float:
    """Min squared distance between any segment A1[i]->B1[i] and A2[j]->B2[j].

    If ``early_exit2`` is given, the scan stops as soon as a pair strictly
    below it is found. The returned value is then an upper bound on the true
    minimum that is still ``< early_exit2``. The strict comparison matches the
    callers' ``dist2 < tol**2`` rejection test: a pair at exactly the threshold
    must not hide a closer pair later in the scan.
    """
    best = float("inf")
    for i in range(A1.shape[0]):
        p1 = A1[i]
        p2 = B1[i]
        for j in range(A2.shape[0]):
            d2 = _segseg_dist2(p1, p2, A2[j], B2[j])
            if d2 < best:
                best = d2
                if early_exit2 is not None and best < early_exit2:
                    return float(best)
    return float(best)
def linking_number_method_1a(
    ring1: npt.ArrayLike,
    ring2: npt.ArrayLike,
    *,
    eps: float = 1e-12,
    disjoint_tol: float | None = None,
    disjoint_rel: float = 1e-3,
    return_nan_if_not_disjoint: bool = True,
) -> float:
    """Gauss linking number of two CLOSED polygonal rings (method 1a).

    Sums the signed solid angles of every segment pair (one segment from each
    ring) and divides by 4*pi. The numba kernel is used when numba is
    importable.

    Args:
        ring1, ring2: (N, 3) vertex arrays. Each is treated as a closed ring;
            a final vertex equal to the first is dropped.
        eps: Degeneracy threshold passed to the pairwise kernel.
        disjoint_tol: Minimum allowed segment-segment distance between the
            rings. ``None`` derives it from ``disjoint_rel`` and the median
            segment length. A value ``<= 0`` skips the disjointness check.
        disjoint_rel: Relative factor used when ``disjoint_tol`` is ``None``.
        return_nan_if_not_disjoint: If the rings come closer than
            ``disjoint_tol``, return ``nan`` when True, otherwise raise.

    Returns:
        The linking number (near-integer for well-separated rings). Returns
        0.0 if either ring has fewer than two segments.

    Raises:
        ValueError: if a ring does not have shape (N, 3), or if the rings are
            not disjoint and ``return_nan_if_not_disjoint`` is False.
    """
    P = _as_points(ring1)
    Q = _as_points(ring2)
    A1, B1 = _segments(P, closed=True)
    A2, B2 = _segments(Q, closed=True)
    if len(A1) < 2 or len(A2) < 2:
        return 0.0

    if disjoint_tol is None:
        disjoint_tol = _auto_disjoint_tol(P, Q, rel=disjoint_rel)
    if disjoint_tol > 0.0:
        tol2 = float(disjoint_tol) ** 2
        d2 = _min_seg_dist2(A1, B1, A2, B2, early_exit2=tol2)
        if d2 < tol2:
            if return_nan_if_not_disjoint:
                return float("nan")
            raise ValueError(
                "Rings are not disjoint (segment distance below disjoint_tol)."
            )

    if nb is not None:
        A1f, B1f = _as_contig_f64(A1), _as_contig_f64(B1)
        A2f, B2f = _as_contig_f64(A2), _as_contig_f64(B2)
        sumV = _link_sum_1a_numba(A1f, B1f, A2f, B2f, 0.0, 0.0, 0.0, eps)
        return float(sumV / (4.0 * math.pi))

    sumV = 0.0
    for i in range(len(A1)):
        for j in range(len(A2)):
            V, _ = _gauss_pair_method_1a(A1[i], B1[i], A2[j], B2[j], eps=eps)
            sumV += V
    return float(sumV / (4.0 * math.pi))
def linking_number_method_1b(
    ring1: npt.ArrayLike,
    ring2: npt.ArrayLike,
    *,
    eps: float = 1e-12,
    disjoint_tol: float | None = None,
    disjoint_rel: float = 1e-3,
    return_nan_if_not_disjoint: bool = True,
) -> float:
    """Gauss linking number of two CLOSED polygonal rings (method 1b, analytic).

    Integrates every segment pair (one from each ring) in closed form and
    divides the total by 4*pi. Disjointness handling is the same as in
    ``linking_number_method_1a``:

    - ``disjoint_tol=None`` derives the tolerance from ``disjoint_rel``;
    - ``disjoint_tol <= 0`` skips the check;
    - rings closer than the tolerance yield ``nan``, or raise ValueError when
      ``return_nan_if_not_disjoint`` is False.

    Returns 0.0 if either ring has fewer than two segments.
    """
    verts1 = _as_points(ring1)
    verts2 = _as_points(ring2)
    s1, e1 = _segments(verts1, closed=True)
    s2, e2 = _segments(verts2, closed=True)
    if min(len(s1), len(s2)) < 2:
        return 0.0

    tol = (
        _auto_disjoint_tol(verts1, verts2, rel=disjoint_rel)
        if disjoint_tol is None
        else disjoint_tol
    )
    if tol > 0.0:
        limit2 = float(tol) ** 2
        closest2 = _min_seg_dist2(s1, e1, s2, e2, early_exit2=limit2)
        if closest2 < limit2:
            if not return_nan_if_not_disjoint:
                raise ValueError(
                    "Rings are not disjoint (segment distance below disjoint_tol)."
                )
            return float("nan")

    total = 0.0
    for a, b in zip(s1, e1):
        for c, d in zip(s2, e2):
            v, _unused = _gauss_pair_method_1b(a, b, c, d, eps=eps)
            total += v
    return float(total / (4.0 * math.pi))
def lk_round(lk: float, tol: float = 1e-6) -> tuple[int, bool]:
    """Round a linking-number estimate and report whether it was near-integer.

    Returns ``(nearest, ok)``, where ``nearest = int(round(lk))`` (Python's
    ``round``, ties to even) and ``ok`` is True when ``|lk - nearest| <= tol``.
    A non-finite ``lk`` makes the rounding step raise.
    """
    nearest = int(round(lk))
    residual = abs(lk - nearest)
    return nearest, residual <= tol
def linking_number_int(lk: float, tol: float = 1e-6) -> int:
    """Return the integer nearest to ``lk``.

    ``tol`` is forwarded to ``lk_round``, but the near-integer flag is
    discarded, so it does not affect the result. Use ``lk_round`` directly when
    you need to know whether ``lk`` was actually close to an integer.
    """
    nearest, _near_integer = lk_round(lk, tol=tol)
    return nearest
def is_linked_from_lk(lk: float, *, tol: float = 1e-6) -> bool:
    """Decide whether two rings are linked from their linking-number value.

    Returns True only when ``lk`` is within ``tol`` of a nonzero integer.
    Non-finite values (e.g. the ``nan`` that the linking functions return for
    non-disjoint rings, or when no periodic image qualifies) give False rather
    than raising from the rounding step.
    """
    if not math.isfinite(lk):
        return False
    r, ok = lk_round(lk, tol=tol)
    return bool(ok and abs(r) > 0)
def _centroid(P: Array) -> Array:
    """Compute the centroid (mean over axis 0) of points P."""
    return np.mean(P, axis=0)


def _pbc_base_integer_shift(
    cA: Array,
    cB: Array,
    cell: Array,
    pbc: tuple[bool, bool, bool],
) -> Array:
    """Integer lattice shift that brings centroid cB to its minimum image near cA.

    ``cell`` holds the lattice vectors as rows (shifts are ``n @ cell``), so
    fractional coordinates are ``r @ inv(cell)``. The returned int array ``n0``
    is chosen so that shifting B by ``-(n0 @ cell)`` leaves the fractional
    centroid difference in [-0.5, 0.5] along every periodic axis. Non-periodic
    axes get 0.

    Rounding is half-to-even (``np.rint``), so a difference of exactly +/-0.5
    may be left as +0.5. The interval is therefore closed, not [-0.5, 0.5).
    """
    frac = (cB - cA) @ np.linalg.inv(cell)
    periodic = np.asarray(pbc, dtype=bool)
    return np.where(periodic, np.rint(frac), 0.0).astype(int)
def _pbc_image_shifts(
    n0: npt.ArrayLike,
    pbc: tuple[bool, bool, bool],
    n_images: int,
) -> list[tuple[int, int, int]]:
    """Distinct integer image shifts n0 + offset, offsets in [-n_images, n_images].

    Non-periodic axes are pinned to shift 0. Scanning offsets there would only
    regenerate the same image, wasting link evaluations and ``check_top_k``
    slots.
    """
    span = range(-n_images, n_images + 1)
    # Keep the candidate set empty when span is empty (negative n_images).
    pinned = range(1) if len(span) else range(0)
    axes = [span if pbc[d] else pinned for d in range(3)]
    base = [int(n0[d]) if pbc[d] else 0 for d in range(3)]
    shifts: list[tuple[int, int, int]] = []
    for i in axes[0]:
        for j in axes[1]:
            for k in axes[2]:
                shifts.append((base[0] + i, base[1] + j, base[2] + k))
    return shifts


def _pbc_select_best(
    candidates: list,
    lk_of,
    disjoint_tol2: float,
) -> tuple[float, tuple[int, int, int]]:
    """Pick the best periodic image from ``(dist2, n, payload)`` candidates.

    Skips two kinds of candidate: those closer than the disjointness threshold
    (``dist2 < disjoint_tol2`` when ``disjoint_tol2 >= 0``), and those whose
    ``lk_of(payload)`` is not finite. Among the rest, it prefers the largest
    ``|round(lk)|``, then the smaller ``dist2``, then the smaller residual
    ``|lk - round(lk)|``; earlier candidates win exact ties. Returns
    ``(nan, (0, 0, 0))`` if nothing qualifies.
    """
    best_lk = float("nan")
    best_n = (0, 0, 0)
    best_int = 0
    best_dist2 = float("inf")
    best_resid = float("inf")
    for dist2, n, payload in candidates:
        if disjoint_tol2 >= 0.0 and dist2 < disjoint_tol2:
            continue
        lk = lk_of(payload)
        if not np.isfinite(lk):
            continue
        lk_int = int(round(lk))
        resid = abs(lk - lk_int)
        if (
            abs(lk_int) > abs(best_int)
            or (abs(lk_int) == abs(best_int) and dist2 < best_dist2)
            or (
                abs(lk_int) == abs(best_int)
                and dist2 == best_dist2
                and resid < best_resid
            )
        ):
            best_lk = float(lk)
            best_n = n
            best_int = lk_int
            best_dist2 = dist2
            best_resid = resid
    return best_lk, best_n


def linking_number_pbc(
    ringA: Array,
    ringB: Array,
    *,
    cell: Array,
    pbc: tuple[bool, bool, bool] = (True, True, True),
    n_images: int = 1,
    method: str = "1a",
    eps: float = 1e-12,
    check_top_k: int | None = None,
    disjoint_tol: float | None = None,
    disjoint_rel: float = 1e-3,
) -> tuple[float, tuple[int, int, int]]:
    """
    Compute Gauss linking number between ringA and ringB under PBC
    by scanning periodic images of ringB.

    Returns (best_lk, best_image_shift) where best_image_shift is integer n such that:
        ringB_shifted = ringB - (n @ cell)

    Images scanned: the centroid minimum-image shift plus offsets in
    [-n_images, n_images] along each periodic axis. Non-periodic axes use
    shift 0 only.

    Candidate scoring / selection:
      - compute min segment-segment dist^2 for each (distinct) shift
      - if check_top_k is given, keep only the k closest distinct images
      - discard candidates with dist < disjoint_tol
      - among remaining, choose the shift that maximizes |round(lk)|,
        tie-break by smaller dist^2, then smaller residual |lk-round(lk)|

    Returns (nan, (0, 0, 0)) if no image qualifies, and (0.0, (0, 0, 0)) if
    either ring has fewer than two segments.

    Raises:
        ValueError: if ``cell`` is not (3, 3), if a ring is not (N, 3), or if
            ``method`` is not '1a' or '1b'.
    """
    cell = np.asarray(cell, dtype=float)
    if cell.shape != (3, 3):
        raise ValueError("cell must have shape (3,3)")
    A = np.asarray(ringA, dtype=float)
    B = np.asarray(ringB, dtype=float)
    if A.ndim != 2 or A.shape[1] != 3 or B.ndim != 2 or B.shape[1] != 3:
        raise ValueError("rings must be arrays of shape (N,3)")
    # Validate up front. Previously a bad method could silently return nan
    # (or never be noticed) when every candidate was filtered out.
    if method not in ("1a", "1b"):
        raise ValueError("method must be '1a' or '1b' for linking number")

    if disjoint_tol is None:
        disjoint_tol = _auto_disjoint_tol(A, B, rel=disjoint_rel)
    disjoint_tol2 = float(disjoint_tol) ** 2 if disjoint_tol > 0.0 else -1.0

    # Base integer shift from centroid minimum image, then the scan window.
    n0 = _pbc_base_integer_shift(_centroid(A), _centroid(B), cell, pbc)
    shifts = _pbc_image_shifts(n0, pbc, n_images)

    P = _as_points(A)
    Q = _as_points(B)
    A1, B1 = _segments(P, closed=True)
    A2, B2 = _segments(Q, closed=True)
    if len(A1) < 2 or len(A2) < 2:
        return 0.0, (0, 0, 0)

    top_k = None if check_top_k is None else max(1, int(check_top_k))

    # --- Numba accelerated path for method 1a (no B copies) ---
    if nb is not None and method == "1a":
        A1f, B1f = _as_contig_f64(A1), _as_contig_f64(B1)
        A2f, B2f = _as_contig_f64(A2), _as_contig_f64(B2)
        fast: list[
            tuple[float, tuple[int, int, int], tuple[float, float, float]]
        ] = []
        for n in shifts:
            shift = np.array(n, dtype=int) @ cell
            sx, sy, sz = float(shift[0]), float(shift[1]), float(shift[2])
            dist2 = float(
                _min_seg_dist2_shift_numba(
                    A1f, B1f, A2f, B2f, sx, sy, sz, disjoint_tol2
                )
            )
            fast.append((dist2, n, (sx, sy, sz)))
        fast.sort(key=lambda cand: cand[0])
        if top_k is not None:
            fast = fast[:top_k]

        def _lk_numba(sxyz: tuple[float, float, float]) -> float:
            sx, sy, sz = sxyz
            sumV = float(_link_sum_1a_numba(A1f, B1f, A2f, B2f, sx, sy, sz, eps))
            return sumV / (4.0 * math.pi)

        return _pbc_select_best(fast, _lk_numba, disjoint_tol2)

    # --- Fallback path (python): build shifted copies ---
    linker = linking_number_method_1a if method == "1a" else linking_number_method_1b
    early = disjoint_tol2 if disjoint_tol2 >= 0.0 else None
    slow: list[tuple[float, tuple[int, int, int], Array]] = []
    for n in shifts:
        Bimg = B - np.array(n, dtype=int) @ cell
        A2i, B2i = _segments(_as_points(Bimg), closed=True)
        dist2 = _min_seg_dist2(A1, B1, A2i, B2i, early_exit2=early)
        slow.append((dist2, n, Bimg))
    slow.sort(key=lambda cand: cand[0])
    if top_k is not None:
        slow = slow[:top_k]

    def _lk_python(Bimg: Array) -> float:
        # Disjointness was already filtered by the selector; skip re-checking.
        return linker(A, Bimg, eps=eps, disjoint_tol=0.0)

    return _pbc_select_best(slow, _lk_python, disjoint_tol2)