使用 Sympy 生成 C 代码。将 Pow(x,2) 替换为 x*x

Posted

技术标签:

【中文标题】使用 Sympy 生成 C 代码。将 Pow(x,2) 替换为 x*x【英文标题】:Generate C code with Sympy. Replace Pow(x,2) by x*x 【发布时间】:2021-04-08 13:53:31 【问题描述】:

我正在使用通用子表达式消除 (CSE) 例程和 ccode 打印机生成带有 sympy 的 C 代码。

但是,我希望将幂表达式设为 (x*x) 而不是 pow(x,2)。

无论如何要这样做?

例子:

import sympy as sp
a= sp.MatrixSymbol('a',3,3)
b=sp.Matrix(a)*sp.Matrix(a)

res = sp.cse(b)

lines = []
     
for tmp in res[0]:
    lines.append(sp.ccode(tmp[1], tmp[0]))

for i,result in enumerate(res[1]):
    lines.append(sp.ccode(result,"result_%i"%i))

将输出:

x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1, 2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8, 2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11, 2) + x12 + x7;

最好的问候

【问题讨论】:

也许这有帮助:***.com/questions/39173019/… 【参考方案1】:

有一个名为create_expand_pow_optimization 的函数可以创建一个包装器来优化您在这方面的表达式。它将被显式乘法取代的最高幂作为参数。

包装器返回一个UnevaluatedExpr,该UnevaluatedExpr 受到保护,防止自动简化恢复此更改。

import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

expand_opt = create_expand_pow_optimization(3)

a = sp.Matrix(sp.MatrixSymbol('a',3,3))
res = sp.cse(a@a)

for i,result in enumerate(res[1]):
    print(sp.ccode(expand_opt(result),"result_%i"%i))

最后,请注意,对于足够高的优化级别,您的编译器会处理这个问题(并且可能在这方面做得更好)。

【讨论】:

这个功能看起来很有趣!但是,我无法在ccode user_functions 中工作。此外,它似乎只适用于符号 ^ 整数的情况。如果我有:sp.ccode(expand_opt (a+b)**2) 它会给我'pow(a + b, 2) 那应该是sp.ccode(expand_opt((a+b)**2)),但即使这样也行不通。看起来像一个错误。 我在此创建了an issue。【参考方案2】:

您可以将code printer 子类化,并且只更改您想要不同的一个功能。您需要调查 the original sympy code 以找到正确的函数名称和默认实现,因此您可以确保不会出错。稍加注意,所需的括号可以在需要的时间和地点自动生成。

这是一个最小的例子:

import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence
from sympy.abc import x

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp == 2:
            return '(0 * 0)'.format(self.parenthesize(expr.base, PREC))
        else:
            return super()._print_Pow(expr)

default_printer = C99CodePrinter().doprint
custom_printer = CustomCodePrinter().doprint

expressions = [x, (2 + x) ** 2, x ** 3, x ** 15, sp.sqrt(5), sp.sqrt(x)**4, 1 / x, 1 / (x * x)]
print("Default: ".format(default_printer(expressions)))
print("Custom: ".format(custom_printer(expressions)))

输出:

Default: [x, pow(x + 2, 2), pow(x, 3), pow(x, 15), sqrt(5), pow(x, 2), 1.0/x, pow(x, -2)]
Custom: [x, ((x + 2) * (x + 2)), pow(x, 3), pow(x, 15), sqrt(5), (x * x), 1.0/x, pow(x, -2)]

PS:为了支持更广泛的指数,您可以使用例如

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp in range(2, 7):
            return '*'.join([self.parenthesize(expr.base, PREC)] * int(expr.exp))
        elif expr.exp in range(-6, 0):
            return '1.0/(' + ('*'.join([self.parenthesize(expr.base, PREC)] * int(-expr.exp))) + ')'
        else:
            return super()._print_Pow(expr)

【讨论】:

【参考方案3】:

我想我会使用 user_function 方法:

正如上面评论中所建议的,我将使用 sp.ccode 的 user_functions 功能: 假设我们有一个像a^3这样的数字

sp.ccode(a**3, user_functions='Pow': [(lambda x, y: y.is_integer, lambda x, y: '*'.join(['('+x+')']*int(y))),(lambda x, y: not y.is_integer, 'pow')])

应该输出: '(a)*(a)*(a)'

以后我会尝试改进这个功能,只在需要的时候加括号。

欢迎任何改进!

【讨论】:

你仍然需要检查负面力量。例如。 1/aa**(-2)1/(a**2+1) 出错了。对于高功率,您可能想要使用exponentiation via squaring,这可能会导致很多加速。 (另请注意,大多数 C 编译器会自动进行这种优化。)

以上是关于使用 Sympy 生成 C 代码。将 Pow(x,2) 替换为 x*x的主要内容,如果未能解决你的问题,请参考以下文章

使用python的sympy解符号方程组后,如何将结果带入之后的符号表达式

Python sympy做代数运算解Cholesky求逆的L和L逆矩阵示例

Python sympy做代数运算解Cholesky求逆的L和L逆矩阵示例

SymPy可以渲染LaTeX以在GUI中使用吗?

在 sympy 中评估系数而不是指数

将 sympy 中的变量定义为 CONSTANT