# -*- coding: utf-8 -*-
"""
Started on Wed Apr 9 13:39:41 2025
@author: Vincent MAHE
Sympy functions useful to the use of MMS functions.
"""
#%% Imports
import copy
from sympy import sqrt, solve, lambdify
import warnings
#%% Functions
[docs]
def sub_deep(expr, sub):
r"""
Performs deep substitutions of an expression.
Parameters
----------
expr : sympy.Expr
Expression on which substitutions are to be performed.
sub : list of tuples
The substitutions to perform.
Returns
-------
expr_sub : sympy.Expr
The expression with substitutions performed.
Notes
-----
Deep substitutions are needed when a substitution involves terms that can still be substituted.
For instance, one wants to substitute :math:`a_1` and :math:`a_2` by expressions, but :math:`a_1` is actually a function of :math:`a_2`, so at least 2 substitutions are required.
"""
expr_sub = copy.copy(expr)
expr_init = 0
while expr_init != expr_sub: # Check if the 2 are the same
expr_init = copy.copy(expr_sub) # Update expr_init -> now the 2 are the same
expr_sub = expr_sub.subs(sub).doit() # Update expr_sub -> now the two are different if substitutions were performed
return expr_sub
[docs]
def solve_poly2(poly, x):
r"""
Finds the roots of a polynomial of degree 2.
Parameters
----------
poly: sympy.Expr
polynomial whose roots are to be computed
x: sympy.Symbol
Variable of the polynomial
Returns
-------
x_sol: list of sympy.Expr
list containing the two roots of the polynomial
Notes
-----
The polynomial of degree 2 takes the form
.. math::
p(x) = a x^2 + bx + c.
Note that :math:`b` can be null but not :math:`a` nor :math:`c`.
It is a workaround to using :func:`~sympy.solvers.solvers.solve` or :func:`~sympy.solvers.solveset.solveset`.
These two work but can be very long when coefficients :math:`a,\; b,\; c` are expressions involving many parameters.
Note that :func:`~sympy.solvers.solvers.solve` is significantly slower than :func:`~sympy.solvers.solveset.solveset`.
"""
# Check the solvability
if not check_solvability(poly, x):
print("The polynomial cannot be solved")
return False
# Polynomial terms
dic_x = polynomial_terms(poly, x)
keys = set(dic_x.keys())
# Solve
if keys == set([x**2, x, 1]):
a = dic_x[x**2].factor()
b = dic_x[x].factor()
c = dic_x[1]
D = (b**2 - 4*a*c).factor()
T1 = (-b/(2*a)).factor()
T2 = sqrt(D)/(2*a)
x1 = T1 - T2
x2 = T1 + T2
x_sol = [x1,x2]
elif keys == set([x**2, 1]):
a = dic_x[x**2].factor()
c = dic_x[1]
T2 = sqrt(- 4*a*c)/(2*a)
x_sol = [-T2, T2]
elif keys == set([x, 1]) or keys == set([x**2, x]):
x_sol = solve(poly.expand(), x)
else:
x_sol = []
print('Trying to use solve_poly2() with a polynomial different from p(x) = a*x**2 + b*x + c')
return x_sol
[docs]
def polynomial_terms(poly, x):
r"""
Identify the terms of a polynomial.
Parameters
----------
poly : sympy.Expr
The polynomial considered.
x : sympy.Symbol
The variable to solve for.
Returns
-------
dic_x: dict
The polynomial terms.
Notes
-----
If the expression given for poly is of the form
.. math::
p(x) = q(x) x^{-n},
where the powers of :math:`x` in :math:`q(x)` are all superior or equal to :math:`0`,
then an auxiliary polynomial
.. math::
P(x) = \dfrac{p(x)}{x^{-n}}
is constructed. It is the terms of that positive powers polynomial that are returned.
"""
# Polynomial terms
dic_x = poly.expand().collect(x, evaluate=False)
keys = set(dic_x.keys())
# Increase the polynomial order if it contains negative powers of x so the lowest possible order is x**0=1
min_power = min(list(keys), key=lambda expr: get_exponent(expr, x))
min_expo = get_exponent(min_power, x)
if min_expo<=0:
poly = (poly/min_power).expand()
# Terms of the increased-order polynomial
dic_x = poly.expand().collect(x, evaluate=False)
return dic_x
[docs]
def check_solvability(poly, x):
r"""
Check the solvability of a polynomial :math:`p(x)`.
Parameters
----------
poly : sympy.Expr
The polynomial considered.
x : sympy.Symbol
The variable to solve for.
Returns
-------
bool : bool,
True is solvable, False otherwise.
"""
dic_x = polynomial_terms(poly, x)
poly_terms = set(dic_x.keys())
min_power = min(poly_terms, key=lambda expr: get_exponent(expr, x))
poly_terms = set([poly_term/min_power for poly_term in poly_terms])
if poly_terms in [set([x**2, x, 1]), set([x**2, 1]), set([x, 1])]:
return True
else:
return False
[docs]
def get_exponent(expr, x):
r"""
Get the exponent of :math:`x` in an expression of the type :math:`\lambda x^n` where :math:`\lambda` is a constant while :math:`n` is an integer or rational.
Parameters
----------
expr: sympy.Expr
The expression in which one wants to identify the exponent of x.
x: sympy.Symbol
The variable whose exponent is to be known.
"""
# This assumes expr is a power of x
if expr.is_Number:
return 0
elif expr == x:
return 1
elif expr.is_Pow and expr.base == x:
return expr.exp
else:
return float('inf') # Handle unexpected expressions
[docs]
def get_block_diagonal_indices(matrix, block_sizes):
r"""
Generate a list of :math:`(i, j)` indices for all elements in the diagonal blocks of a block-diagonal matrix.
Parameters
----------
matrix: sympy.Matrix
The matrix to check for block diagonality.
block_sizes: int or list of int
Size(s) of the diagonal blocks.
Returns
-------
indices : list
A list of tuples `(i, j)` representing the indices of elements in the diagonal blocks.
"""
if isinstance(block_sizes, int):
block_sizes = [block_sizes]*(matrix.rows // block_sizes)
indices = []
start = 0
for size in block_sizes:
end = start + size
# Iterate over the current block
for ii in range(start, end):
for jj in range(start, end):
indices.append((ii, jj))
start = end
return indices
[docs]
def is_block_diagonal(matrix, block_sizes):
"""
Check if a matrix is block-diagonal given block sizes.
Parameters
----------
matrix: sympy.Matrix
The matrix to check for block diagonality.
block_sizes: int or list of int
Size(s) of the diagonal blocks.
Returns
-------
bool: `True` if the matrix is block-diagonal, `False` otherwise.
"""
n = matrix.rows
if isinstance(block_sizes, int):
block_sizes = [block_sizes]*(n // block_sizes)
if sum(block_sizes) != n:
return False
# Get the block-diagonal elements indices
indices_diag_blocks = get_block_diagonal_indices(matrix, block_sizes)
# Iterate over the matrix elements
for i in range(n):
for j in range(n):
# Skip elements inside the diagonal blocks
if (i, j) not in indices_diag_blocks:
if matrix[i, j] != 0:
return False
return True
[docs]
def sympy_to_numpy(expr_sy, param):
"""
Transform a sympy expression into a numpy array.
Parameters
----------
expr_sy : sympy.Expr
A sympy expression.
param : dict
A dictionnary whose values are tuples with 2 elements:
1. The sympy symbol of a parameter
2. The numerical value(s) taken by that parameter
Returns
-------
expr_np : numpy.ndarray
The numerical values taken by the sympy expression evaluated.
"""
args, values = zip(*param.values())
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="invalid value encountered in sqrt")
expr_np = lambdify(args, expr_sy, modules="numpy")(*values)
return expr_np