Source code for capytaine.tools.symbolic_multiplication
import numpy as np
from functools import wraps, total_ordering
[docs]
@total_ordering
class SymbolicMultiplication:
    """A value multiplied by a symbolic factor such as ``"0"`` or ``"∞"``.

    An instance stands for the product ``symbol × value``, where ``symbol`` is
    a string and ``value`` is a number or a NumPy array. Multiplication,
    matrix multiplication, indexing and reshaping act on ``value`` and keep
    ``symbol``. Addition, comparison and conversion to ``float`` first turn
    the product into a concrete value, which is only defined for the symbols
    ``"0"`` and ``"∞"``.
    """

    def __init__(self, symbol, value=1.0):
        self.symbol = symbol
        self.value = value

    def __format__(self, format_spec):
        # The format spec is applied to the numeric part only.
        return f"{self.symbol}×{self.value.__format__(format_spec)}"

    # Higher than ndarray's priority (0.0), presumably so that NumPy binary
    # operators return NotImplemented and Python falls back to this class's
    # reflected methods (__rmul__, __rmatmul__, ...).
    __array_priority__ = 1.0

    def __array_function__(self, func, types, args, kwargs):
        """Apply ``np.real``, ``np.imag`` or ``np.sum`` to ``value``, keeping the symbol.

        Follows the NumPy ``__array_function__`` protocol: ``args`` and
        ``kwargs`` are the original call's arguments. They are forwarded with
        this object replaced by its ``value``, so options such as ``axis`` are
        honoured. Any other NumPy function gets ``NotImplemented``.
        """
        if func not in {np.real, np.imag, np.sum}:
            return NotImplemented
        unwrapped_args = tuple(arg.value if arg is self else arg for arg in args)
        unwrapped_kwargs = {key: (val.value if val is self else val)
                            for key, val in kwargs.items()}
        return SymbolicMultiplication(self.symbol, func(*unwrapped_args, **unwrapped_kwargs))

    def __str__(self):
        return f"{self.symbol}×{self.value}"

    def __repr__(self):
        return f"SymbolicMultiplication(\"{self.symbol}\", {repr(self.value)})"

    def __add__(self, x):
        # Addition has no symbolic form: concretize first.
        return self._concretize() + x

    def __radd__(self, x):
        return x + self._concretize()

    def __mul__(self, x):
        return SymbolicMultiplication(self.symbol, self.value * x)

    def __rmul__(self, x):
        return SymbolicMultiplication(self.symbol, x * self.value)

    def __pow__(self, n):
        """Only squaring (``n == 2``) is supported."""
        if n == 2:
            return self * self
        else:
            raise NotImplementedError

    def __truediv__(self, x):
        """Divide by ``x``; identical symbols cancel and a plain value is returned."""
        if hasattr(x, 'symbol') and self.symbol == x.symbol:
            return self.value / x.value
        else:
            return SymbolicMultiplication(self.symbol, self.value / x)

    def __rtruediv__(self, x):
        """Compute ``x / self``: dividing by ``0×v`` gives ``∞×(x/v)`` and vice versa.

        Raises NotImplementedError for any other symbol.
        """
        if hasattr(x, 'symbol') and self.symbol == x.symbol:
            return x.value / self.value
        elif self.symbol == "0":
            return SymbolicMultiplication("∞", x/self.value)
        elif self.symbol == "∞":
            return SymbolicMultiplication("0", x/self.value)
        else:
            raise NotImplementedError

    def __matmul__(self, x):
        return SymbolicMultiplication(self.symbol, self.value @ x)

    def __rmatmul__(self, x):
        return SymbolicMultiplication(self.symbol, x @ self.value)

    def __getitem__(self, item):
        return SymbolicMultiplication(self.symbol, self.value[item])

    # NOTE(review): __eq__ compares via float(), so instances with the same
    # symbol but different values compare equal, yet __hash__ uses
    # (symbol, value) and gives them different hashes; it also raises
    # TypeError when value is an ndarray. Left unchanged because callers may
    # key caches on these objects — confirm before changing.
    def __eq__(self, x):
        return float(self) == x

    def __lt__(self, x):
        return float(self) < x

    def __hash__(self):
        return hash((self.symbol, self.value))

    def _concretize(self):
        """Return the concrete value of ``symbol × value``.

        For an ndarray value, returns zeros (symbol ``"0"``) or ``+inf``
        (symbol ``"∞"``) with the same shape. Other values go through
        ``float(self)``.

        Raises
        ------
        NotImplementedError
            If the symbol is neither ``"0"`` nor ``"∞"``.
        """
        if not isinstance(self.value, np.ndarray):
            return float(self)
        if self.symbol == "0":
            return np.zeros_like(self.value)
        if self.symbol == "∞":
            # Integer/boolean arrays cannot hold inf. Promote to a dtype that
            # can, instead of letting an unsafe cast produce garbage integers.
            return np.full_like(self.value, np.inf,
                                dtype=np.result_type(self.value, np.inf))
        # Previously this path fell through and silently returned None.
        raise NotImplementedError

    def __float__(self):
        """Return 0.0 for symbol ``"0"`` and inf for ``"∞"``; the value is ignored."""
        if self.symbol == "0":
            return 0.0
        elif self.symbol == "∞":
            return np.inf
        else:
            raise NotImplementedError

    def reshape(self, *args):
        """Reshape ``value`` (assumed to have a ``reshape`` method), keeping the symbol."""
        return SymbolicMultiplication(self.symbol, self.value.reshape(*args))
[docs]
def supporting_symbolic_multiplication(f):
    """Decorate ``f(a, x, ...)`` so that it also accepts a symbolic ``x``.

    If ``x`` has a ``symbol`` attribute (e.g. a ``SymbolicMultiplication``),
    ``f`` is applied to ``x.value`` and the result is wrapped again with the
    same symbol. Otherwise ``f`` is called unchanged. Extra positional and
    keyword arguments are forwarded to ``f`` in both cases.

    This is only meaningful when ``f`` is linear in ``x``. The decorator
    assumes this and does not check it.
    """
    @wraps(f)
    def wrapped_f(a, x, *args, **kwargs):
        if hasattr(x, 'symbol'):
            return SymbolicMultiplication(x.symbol, f(a, x.value, *args, **kwargs))
        return f(a, x, *args, **kwargs)
    return wrapped_f