如何使用 rust 宏简化数学公式?
Posted
技术标签:
【中文标题】如何使用 rust 宏简化数学公式?【英文标题】:How to simplify mathematical formulas with rust macros? 【发布时间】:2019-08-27 19:29:47 【问题描述】:我必须承认我对宏有点迷茫。 我想构建一个执行以下任务的宏 我不知道该怎么做。我想执行一个标量产品 两个数组,比如 x 和 y,它们具有相同的长度 N。 我要计算的结果是这样的:
z = sum_i=0^N-1 x[i] * y[i].
x
是 const
哪些元素是 0, 1, or -1
在编译时已知,
而y
的元素是在运行时确定的。由于
x
的结构,很多计算都是无用的(项乘以 0
可以从和中去掉,而1 * y[i], -1 * y[i]
形式的乘法可以分别转化为y[i], -y[i]
)。
例如,如果x = [-1, 1, 0]
,上面的标量积将是
z=-1 * y[0] + 1 * y[1] + 0 * y[2]
为了加快计算速度,我可以手动展开循环并重写
没有x[i]
,我可以将上面的公式硬编码为
z = -y[0] + y[1]
但是这个过程并不优雅,容易出错 当N变大时非常乏味。
我很确定我可以用宏来做到这一点,但我不知道在哪里 开始(我阅读的不同书籍并没有深入探讨宏和 我被卡住了)...
你们中有人知道如何(如果可能的话)使用宏来解决这个问题吗?
提前感谢您的帮助!
编辑:正如许多答案中所指出的,编译器足够聪明,可以在整数的情况下删除优化循环。我不仅使用整数,还使用浮点数(x
数组是 i32s,但通常y
是f64
s),所以编译器不够聪明(而且理所当然地)来优化循环。下面的一段代码给出了下面的 asm。
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64
X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
playground::dot_x:
xorpd %xmm0, %xmm0
movsd (%rdi), %xmm1
mulsd %xmm0, %xmm1
addsd %xmm0, %xmm1
addsd 8(%rdi), %xmm1
subsd 16(%rdi), %xmm1
movupd 24(%rdi), %xmm2
xorpd %xmm3, %xmm3
mulpd %xmm2, %xmm3
addsd %xmm3, %xmm1
unpckhpd %xmm3, %xmm3
addsd %xmm1, %xmm3
addsd 40(%rdi), %xmm3
mulsd 48(%rdi), %xmm0
addsd %xmm3, %xmm0
subsd 56(%rdi), %xmm0
retq
【问题讨论】:
写一个函数有什么问题?fn scalar_product(x: &[i64], y: &[i64]) -> i64 return x.iter().zip(y.iter()).map(|(l, r)| l * r).sum()
这个想法是让它运行得更快。你可以节省至少一半的计算,因为l
将是 0、1 和 -1。
我会首先假设编译器是智能的,通过优化来编译东西,并检查循环是否以所需的方式展开。可能是,甚至不需要宏。
感谢您的回答。正如我在 edited 帖子中指出的那样,不幸的是,编译器不够聪明,无法拯救我,因为我在计算中也使用了浮点数。
【参考方案1】:
首先,(proc) 宏根本无法查看数组x
的内部。它得到的只是你传递给它的令牌,没有任何上下文。如果你想让它知道值(0、1、-1),你需要将它们直接传递给你的宏:
let result = your_macro!(y, -1, 0, 1, -1);
但是您并不需要为此使用宏。编译器进行了很多优化,其他答案也显示了这一点。但是,正如您在编辑中已经提到的那样,它不会优化 0.0 * x[i]
,因为结果并不总是 0.0
。 (例如,它可能是-0.0
或NaN
。)我们可以在这里做的只是通过使用match
或if
来帮助优化器一点,以确保它对0.0 * y
没有任何作用案例:
const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];
fn foobar(y: [f64; 8]) -> f64
let mut sum = 0.0;
for (&x, &y) in X.iter().zip(&y)
if x != 0
sum += x as f64 * y;
sum
在发布模式下,循环展开并内联X
的值,导致大多数迭代被丢弃,因为它们什么都不做。生成的二进制文件(在 x86_64 上)中唯一剩下的是:
foobar:
xorpd xmm0, xmm0
subsd xmm0, qword, ptr, [rdi, +, 8]
addsd xmm0, qword, ptr, [rdi, +, 48]
ret
(正如@lu-zero 所建议的,这也可以使用
filter_map
来完成。看起来像这样:X.iter().zip(&y).filter_map(|(&x, &y)| match x 0 => None, _ => Some(x as f64 * y) ).sum()
,并给出完全相同的生成程序集。甚至没有match
,通过使用filter
和map
分别为:.filter(|(&x, _)| x != 0).map(|(&x, &y)| x as f64 * y).sum()
。)
相当不错!然而,这个函数计算0.0 - y[1] + y[6]
,因为sum
是从0.0
开始的,我们只对它做减法和加法。优化器再次不愿意优化掉0.0
。不是从0.0
开始,而是从None
开始,我们可以提供更多帮助:
fn foobar(y: [f64; 8]) -> f64
let mut sum = None;
for (&x, &y) in X.iter().zip(&y)
if x != 0
let p = x as f64 * y;
sum = Some(sum.map_or(p, |s| s + p));
sum.unwrap_or(0.0)
这会导致:
foobar:
movsd xmm0, qword, ptr, [rdi, +, 48]
subsd xmm0, qword, ptr, [rdi, +, 8]
ret
这只是y[6] - y[1]
。宾果游戏!
【讨论】:
非常感谢。正如@lu-zero 所指出的,它可以通过 filter_map() 变得更短。我目前正在将其放入我的代码中,看看性能是否与“手动分析优化”一样好,但我想会的! 使用filter_map().sum()
方法仍将导致0.0 + ...
计算。把它写得更像我的上一个例子会为你节省更多的指令。 (诚然,这可能并不重要。)
正确。谢谢。
如何在不使用 play.rust-lang.org 的情况下查看汇编代码,但在我直接在我的计算机上编译的 crates 上查看?
@jens1o 我使用objdump -xdC
,但这可能不是最简单的方法。这似乎很有用:github.com/gnzlbg/cargo-asm【参考方案2】:
您可以使用返回函数的宏来实现您的目标。
首先,在没有宏的情况下编写这个函数。这个采用固定数量的参数。
fn main()
println!("Hello, world!");
let func = gen_sum([1,2,3]);
println!("", func([4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
fn gen_sum(xs: [i32; 3]) -> impl Fn([i32;3]) -> i32
move |ys| ys[0]*xs[0] + ys[1]*xs[1] + ys[2]*xs[2]
现在,完全重写它,因为之前的设计不能很好地用作宏。我们不得不放弃固定大小的数组,如macros appear unable to allocate fixed-sized arrays。
Rust Playground
fn main()
let func = gen_sum!(1,2,3);
println!("", func(vec![4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
#[macro_export]
macro_rules! gen_sum
( $( $x:expr ),* ) =>
let mut xs = Vec::new();
$(
xs.push($x);
)*
move |ys:Vec<i32>|
if xs.len() != ys.len()
panic!("lengths don't match")
let mut total = 0;
for i in 0 as usize .. xs.len()
total += xs[i] * ys[i];
total
;
这是做什么的/应该做什么
在编译时,它会生成一个 lambda。此 lambda 接受数字列表并将其乘以编译时生成的 vec。我不认为这正是您所追求的,因为它不会在编译时优化零。您可以在编译时优化掉零,但在运行时必须检查零在 x 中的位置以确定要乘以 y 中的哪些元素,从而必然会在运行时产生一些成本。您甚至可以使用哈希集在恒定时间内完成此查找过程。一般来说,它仍然可能不值得(我认为 0 并不是那么常见)。计算机更擅长做一件“效率低下”的事情,而不是检测到他们将要做的事情“效率低下”然后跳过那件事。当他们所做的大部分操作“低效”时,这种抽象就会崩溃
跟进
这值得吗?它会改善运行时间吗?我没有测量,但与仅使用函数相比,理解和维护我编写的宏似乎不值得。编写一个执行您所说的零优化的宏可能会更不愉快。
【讨论】:
非常感谢您的回答。由于它没有优化 0/1 的距离,所以它不完全是我所追求的。不过本身就很有趣!【参考方案3】:在许多情况下,编译器的优化阶段会为您解决这个问题。举个例子,这个函数定义
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [i32; 8]) -> i32
X.iter().zip(y.iter()).map(|(i, j)| i * j).sum()
在 x86_64 上生成此程序集输出:
playground::dot_x:
mov eax, dword ptr [rdi + 4]
sub eax, dword ptr [rdi + 8]
add eax, dword ptr [rdi + 20]
sub eax, dword ptr [rdi + 28]
ret
您将无法获得比这更优化的版本,因此简单地以幼稚的方式编写代码是最好的解决方案。编译器是否会为更长的向量展开循环尚不清楚,它可能会随着编译器版本而改变。
对于浮点数,编译器通常无法执行上述所有优化,因为y
中的数字不能保证是有限的——它们也可能是NaN
、inf
或@987654326 @。由于这个原因,与0.0
相乘并不能保证再次得到0.0
,因此编译器需要在代码中保留乘法指令。不过,您可以使用 fmul_fast()
内在函数明确允许它假设所有数字都是有限的:
#![feature(core_intrinsics)]
use std::intrinsics::fmul_fast;
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64
X.iter().zip(y.iter()).map(|(i, j)| unsafe fmul_fast(*i as f64, *j) ).sum()
这会产生以下汇编代码:
playground::dot_x: # @playground::dot_x
# %bb.0:
xorpd xmm1, xmm1
movsd xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
addsd xmm0, xmm1
subsd xmm0, qword ptr [rdi + 16]
addsd xmm0, xmm1
addsd xmm0, qword ptr [rdi + 40]
addsd xmm0, xmm1
subsd xmm0, qword ptr [rdi + 56]
ret
这仍然会在步骤之间冗余地添加零,但我不认为这会导致实际 CFD 模拟的任何可测量开销,因为此类模拟往往受到内存带宽而不是 CPU 的限制。如果您也想避免这些添加,则需要使用 fadd_fast()
进行添加以允许编译器进一步优化:
#![feature(core_intrinsics)]
use std::intrinsics::fadd_fast, fmul_fast;
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64
let mut result = 0.0;
for (&i, &j) in X.iter().zip(y.iter())
unsafe result = fadd_fast(result, fmul_fast(i as f64, j));
result
这会产生以下汇编代码:
playground::dot_x: # @playground::dot_x
# %bb.0:
movsd xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
subsd xmm0, qword ptr [rdi + 16]
addsd xmm0, qword ptr [rdi + 40]
subsd xmm0, qword ptr [rdi + 56]
ret
与所有优化一样,您应该从代码的可读性和可维护性最高的版本开始。如果性能成为问题,您应该分析您的代码并找到瓶颈。下一步,尝试改进基本方法,例如通过使用具有更好的渐近复杂度的算法。只有这样,您才应该转向像您在问题中建议的那样进行微优化。
【讨论】:
非常感谢您的回答。正如我的 edited 问题中所解释的, y 通常由浮点数组成。因此编译器不够聪明,无法优化循环。 @Jean-PaulDax 您是否有任何迹象表明此循环会导致整个代码的性能瓶颈? 是的。这种函数在流体动力学模拟中可能被调用了数十亿次。目前,我正在手动明确地简化循环,但是在修改模型时它严重缺乏通用性。取决于模型的增益至少是一个数量级(这种计算一直在进行)。 @Jean-PaulDax 我明白了——这是有道理的,而且确实是需要进行微优化的情况之一。编译器无法针对浮点数进行优化的原因已在接受的答案中进行了解释,但我将在此处添加另一个简单的解决方案。 非常感谢您提供的替代解决方案。我们可以有两个可接受的答案吗?【参考方案4】:如果您可以节省 #[inline(always)]
可能使用显式 filter_map() 应该足以让编译器执行您想要的操作。
【讨论】:
哇!看起来真的是working!将在我的实际代码中对其进行测试并发布更新。没想到编译器这么聪明……以上是关于如何使用 rust 宏简化数学公式?的主要内容,如果未能解决你的问题,请参考以下文章