处理多种类型和数组时如何编写“好”的 Julia 代码(多重分派)
Posted
技术标签:
【中文标题】处理多种类型和数组时如何编写“好”的 Julia 代码(多重分派)【英文标题】:How to write "good" Julia code when dealing with multiple types and arrays (multiple dispatch) 【发布时间】:2014-09-20 10:57:55 【问题描述】:OP 更新:请注意,在最新版本的 Julia (v0.5) 中,回答此问题的惯用方法是定义 mysquare(x::Number) = x^2
。使用自动广播覆盖矢量化案例,即x = randn(5) ; mysquare.(x)
。另请参阅更详细地解释点语法的新答案。
我是 Julia 新手,鉴于我的 Matlab 出身,我很难确定如何编写利用多重调度和 Julia 类型系统的“好的”Julia 代码。
考虑我有一个提供Float64
平方的函数的情况。我可以这样写:
function mysquare(x::Float64)
return(x^2);
end
有时,我想将所有Float64
s 平方在一维数组中,但不想每次都在mysquare
上写一个循环,所以我使用多重调度并添加以下内容:
function mysquare(x::ArrayFloat64, 1)
y = Array(Float64, length(x));
for k = 1:length(x)
y[k] = x[k]^2;
end
return(y);
end
但现在我有时会使用Int64
,所以我写了另外两个利用多重调度的函数:
function mysquare(x::Int64)
return(x^2);
end
function mysquare(x::ArrayInt64, 1)
y = Array(Float64, length(x));
for k = 1:length(x)
y[k] = x[k]^2;
end
return(y);
end
这是对的吗?还是有更符合意识形态的方法来处理这种情况?我应该使用这样的类型参数吗?
function mysquareT<:Number(x::T)
return(x^2);
end
function mysquareT<:Number(x::ArrayT, 1)
y = Array(Float64, length(x));
for k = 1:length(x)
y[k] = x[k]^2;
end
return(y);
end
这感觉很明智,但我的代码会像我避免使用参数类型的情况一样快速运行吗?
总的来说,我的问题有两个部分:
如果快速代码对我很重要,我应该使用上述参数类型,还是应该为不同的具体类型编写多个版本?还是我应该完全做其他事情?
当我想要一个对数组和标量进行操作的函数时,编写两个版本的函数,一个用于标量,一个用于数组,是否是一种好习惯?还是我应该完全做其他事情?
最后,请指出您在上面的代码中能想到的任何其他问题,因为我在这里的最终目标是编写好的 Julia 代码。
【问题讨论】:
【参考方案1】:从 Julia 0.6(约 2017 年 6 月)开始,“dot syntax”提供了一种将函数应用于标量或数组的简单且惯用的方式。
您只需要提供函数的标量版本,以正常方式编写。
function mysquarex::Number)
return(x^2)
end
将.
附加到函数名称(或将其预先添加到运算符)以在数组的每个元素上调用它:
x = [1 2 3 4]
x2 = mysquare(2) # 4
xs = mysquare.(x) # [1,4,9,16]
xs = mysquare.(x*x') # [1 4 9 16; 4 16 36 64; 9 36 81 144; 16 64 144 256]
y = x .+ 1 # [2 3 4 5]
请注意,点调用将处理广播,如上一个示例所示。
如果您在同一个表达式中有多个点调用,它们将被融合,以便y = sqrt.(sin.(x))
进行一次传递/分配,而不是创建一个包含 sin(x) 的临时表达式并将其转发给 sqrt()功能。 (这与 Matlab/Numpy/Octave/Python/R 不同,它们不做这样的保证)。
宏@.
将一行上的所有内容向量化,因此@. y=sqrt(sin(x))
与y = sqrt.(sin.(x))
相同。这对于多项式特别方便,其中重复的点可能会令人困惑......
【讨论】:
【参考方案2】:Julia 会根据需要为每组输入编译特定版本的函数。因此,回答第 1 部分,没有性能差异。参数化的方式是要走的路。
至于第 2 部分,在某些情况下编写单独的版本可能是个好主意(有时出于性能原因,例如,避免复制)。但是,在您的情况下,您可以使用内置宏 @vectorize_1arg
自动生成数组版本,例如:
function mysquareT<:Number(x::T)
return(x^2)
end
@vectorize_1arg Number mysquare
println(mysquare([1,2,3]))
至于一般的样式,不要使用分号,mysquare(x::Number) = x^2
要短很多。
对于您的矢量化mysquare
,请考虑T
是BigFloat
的情况。但是,您的输出数组是Float64
。处理此问题的一种方法是将其更改为
function mysquareT<:Number(x::ArrayT,1)
n = length(x)
y = Array(T, n)
for k = 1:n
@inbounds y[k] = x[k]^2
end
return y
end
我在其中添加了 @inbounds
宏来提高速度,因为我们不需要每次都检查边界违规——我们知道长度。如果x[k]^2
的类型不是T
,此函数仍可能存在问题。一个更具防御性的版本可能是
function mysquareT<:Number(x::ArrayT,1)
n = length(x)
y = Array(typeof(one(T)^2), n)
for k = 1:n
@inbounds y[k] = x[k]^2
end
return y
end
如果T
是Int
,one(T)
将给出1
,如果T
是Float64
,则1.0
,等等。仅当您想制作超健壮的库代码时,这些考虑因素才重要。如果您真的只会处理Float64
s 或可以提升为Float64
s 的事情,那么这不是问题。看似辛苦,但力量却是惊人的。您总是可以满足于类似于 Python 的性能而忽略所有类型信息。
【讨论】:
为什么是mysquareT<:Number(x::T)
而不是mysquare(x::Number)
?
没有理由,可能只是为了使其在视觉上与矢量化版本一致。一般不会这样写。
您可能希望针对矢量化更改进行更新。这会在侧边栏弹出,因为它的投票率很高,所以搜索的人可能会找到 @vectorize_1arg
,而现在这是 .
符号。以上是关于处理多种类型和数组时如何编写“好”的 Julia 代码(多重分派)的主要内容,如果未能解决你的问题,请参考以下文章
在 Julia 中切片和广播多维数组:meshgrid 示例
指定 Julia 函数只能采用其内容为特定类型的字典/数组的正确方法是啥?
如何在 Julia 中编写和读取包含日期时间列的 DataFrame