动手打造深度学习框架:元数据结构与算法

Posted 人邮异步社区

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了动手打造深度学习框架:元数据结构与算法相关的知识,希望对你有一定的参考价值。

元程序与运行期代码要解决的问题是有共通性的。这种共通性决定了很多运行期需要的数据结构与算法在元编程中也同样需要。将这些通用的数据结构与算法总结出来并加以实现,可以集中优化,便于后期使用。

事实上,这也正是很多元程序库所做的事情。本书并不打算使用某个元程序库,但讨论一下如何实现这些元程序库中的通用算法是非常有意义的。本质上,本章所讨论的可以视为一个微型元程序库的实现。通过实现这一元程序库,一方面,我们可以进一步熟悉元函数的编写方法;另一方面,我们也会看到一些元程序在编写过程中设计上的权衡,以及随之而来所衍生出的相应技巧。

2.1 基本数据结构与算法

我们要实现的元程序库要包含哪些内容呢?这个元程序库并不需要包含非常复杂的数据结构与算法,但应该具有足够的通用性,能够为我们的深度学习框架实现提供有力的支持。STL就是此类通用函数库中的一个典范:它包含的大部分数据结构与算法都比较简单,但被广泛地应用于各种C++程序的开发过程中。当然,C++标准模板库主要被应用于运行期,而我们要实现的元程序库则会在编译期大显身手。应用场景虽有所区别,但这并不妨碍我们借鉴STL的优秀设计。

2.1.1 数据结构的表示方法

STL中的主要数据结构可以划分为两类:顺序容器与关联容器。前者通过位置来访问数据,后者通过特定类型的键来访问数据。在运行期可以使用的工具相对较多,相应的数据表示形式也多种多样。以顺序容器为例,在STL中常用的顺序容器就包括vector、list等。这些数据结构各有优劣,用户可以根据具体场景进行选择。

相比之下,在编译期我们能使用的工具就不是那么多了:编译期所处理的是常量——无法修改数据的值将对我们的工具选择造成很大限制;编译期对指针等概念的支持相对较弱,我们也无法在编译期进行动态内存分配并以类似指针的形式保存分配的空间,用于后续访问。这些都限制了我们在构造数据结构时可以选择的工具。如第1章所讨论的那样,在编译期表示容器较方便的方法就是使用变长参数模板。我们会将其作为数据结构的载体,以表示在编译期使用的顺序容器与关联容器。

  • 顺序表:一个变长参数模板实例中的元素是天然有序的。按照C++的惯例,我们将变长参数模板中的元素按照从前到后的顺序赋予相应的索引值,索引值从0开始。比如对于tuple<int, double, char>来说,int、double、char所对应的索引值分别为012。
  • 集合:变长参数模板实例也可以表示集合。比如tuple<int, double, char>同样可以视为一个包含了3个元素的集合。集合中的元素没有顺序性,也即tuple<double, char,int> 所表示的集合与tuple<int, double, char>所表示的等价。另外,通常来说集合中的元素具有互异性,即相同的元素在集合中不会出现多次。因此,对于像tuple<int, char,int>这样的变长参数模板实例来说,是否可以将其视为集合呢?显然,这个实例中存在相同的元素。我们可以拒绝将其视为一个集合,也可以采用其他的方式来解释该实例,比如:无论容器中相同的元素出现多少次,都视为仅出现了一次。采用这种解释时,上述实例也可视为一个集合。要怎么解释容器中的元素是一个选择问题。我们将会在本章的后面讨论不同的选择,以及每种选择所带来的性能差异。
  • 映射:STL中的映射容器采用键-值对存储元素,可以通过键来获取相应的值,我们的元程序库中也将引入类似的构造。我们会使用KVBinder模板来存储键-值对。KVBinder的定义如下[2]:
1    template <typename TK, typename TV>
2    struct KVBinder
3    
4        using KeyType = TK;
5        using ValueType = TV;
6        static TV apply(TK*);
7    ;

KVBinder提供了元数据域来获取键与值的类型。在此基础上,我们可以使用变长参数模板容器来表示映射,比如tuple<KVBinder<int, int*>, KVBinder<char, char*>>——这个映射将一些类型与其指针类型关联了起来。

与集合类似,映射中的键有互异性,因此这里也存在是否将具有相同键的容器视为映射的问题。我们将会在讨论映射实现时分析不同选择所带来的性能差异。

  • 多重映射(multimap):STL提供了multimap来表示多重映射,也即键可以重复的映射。在我们的深度学习框架中,某些地方需要在编译期使用多重映射,因此我们的元程序库中也引入了多重映射。我们使用如下的结构来表示多重映射中的键值关系:
1    KVBinder<Key, ValueSequence<Values...>>

ValueSequence是一个变长参数模板,用于存储某个键所对应的值序列。变长参数模板同时还会作为多重映射的容器使用。一个典型的多重映射实例形如:

1    tuple<KVBinder<int, ValueSequence<char>>,
2          KVBinder<double, ValueSequence<int, bool>>> 

它包含了3个键-值对:int-char、double-int与double-bool。

  • 数值容器:细心的读者可能发现了,前面所列出的容器中存储的元素都是类型。这是因为在我们将要实现的深度学习框架中,类型处理占据元程序的主要部分。除此之外,我们也会在某些地方用到与数值相关的元数据结构与算法。但它们与类型容器的处理方式非常相似,因此本章也就不详细讨论了。

可以看出,变长参数模板在我们的元程序库中占据了重要的地位,所有的元数据结构都是以它为载体来实现的。这种设计的缺点在于:给定一个变长参数模板容器,我们很难判断出它所表示的具体含义(序列、集合,还是映射……)。但它也有优点:容器的实例可以自由转换其角色,选择适当的算法。比如,映射可以看成集合(只需要将键-值对看成一个键),因此可以将集合相关的算法应用到映射上;集合又可以看成序列,因此可以将序列相关的算法应用到集合上。我们可以灵活地选择算法达到目的。

还有一点要说明的是:我们使用变长参数模板作为元数据结构的载体,但并不限制变长参数模板的具体类型。在前文中,我们使用了tuple作为示例,但我们也可以采用其他的变长参数模板。比如完全可以自定义一个变长参数模板容器,并使用它来表示序列、集合或映射。

以上就是我们所使用的元数据结构。在此基础上就可以引入一些算法来实现相关的操作了。让我们首先从一些简单的算法开始讨论。

2.1.2 基本算法

很多算法都是非常基础且易于实现的。比如获取顺序表尺寸(其中包含的元素个数)的算法:

1    template <typename TArray>
2    struct Size_;
3
4    template <template <typename...> class Cont, typename...T>
5    struct Size_<Cont<T...>>
6    
7        constexpr static size_t value = sizeof...(T);
8    ;
9
10    template <typename TArray>
11    constexpr static size_t Size = Size_<RemConstRef<TArray>>::value; 

这个算法的核心在第7行,它使用C++11中引入的关键字sizeof... 来获取一个类型序列的长度。我们基于这个关键字构造出了元函数Size_。注意,第1~2行是这个元函数模板的声明,而第4~8行是相应元函数的特化实现。正是这个特化实现限定了该元函数只能作用于变长参数模板容器。

在Size_元函数的基础上,我们引入了Size元函数。像第1章讨论的那样,调用Size元函数时,我们不再需要 ::value这样的依赖名称。同时,Size元函数还调用了RemConstRef对输入参数进行变换,使得元函数可以接收常量或引用类型。

RemConstRef的定义如下:

1    template <typename T>
2    using RemConstRef = std::remove_cv_t<std::remove_reference_t<T>>;

其中调用了type_traits中的元函数,去掉了输入参数中的引用与常量限定符(如果有)。

因此,我们可以这样调用Size元函数:

1    using Cont = std::tuple<char, double, int>;
2    constexpr size_t Res1 = Size<Cont>;
3    constexpr size_t Res2 = Size<Cont&>;

其中Res1与Res2的值均为3。注意,Res2之所以能被求值,是因为RemConstRef去掉了输入参数中的引用限定符。

基本算法的另外两个例子是元函数Head与Tail,它们分别用于获取输入序列的首个元素与去除首个元素的子序列。与Size元函数类似,这两个元函数也分别调用了Head_与Tail_来实现各自的逻辑。Head_与Tail_的定义如下:

1     template <typename TSeqCont>
2     struct Head_;
3
4     template <template <typename...> class Container, typename TH,
5                typename...TCases>
6     struct Head_<Container<TH, TCases...>>
7     
8         using type = TH;
9     ;
10
11    template <typename TSeqCont>
12    struct Tail_;
13
14    template <template <typename...> class Container, typename TH,
15             typename...TCases>
16    struct Tail_<Container<TH, TCases...>> 17 
17    
18        using type = Container<TCases...>;
19    ;

类似算法的实现都非常直观。这里就不一一列举了。

2.1.3 算法的复杂度

理论上,使用第1章讨论的顺序、分支、循环代码的编写方法,我们可以实现大部分与容器相关的算法。但在实现其他算法之前,让我们首先以Size为例,分析其实现的复杂度。

读者可能会问:我们为什么要关心这些算法的复杂度?事实上,这些算法所对应的代码是在编译期被执行的,也就是说,它们的执行效率基本上不会对代码的运行期造成影响。既然如此,我们真的需要关心它们的实现复杂与否吗?

答案是肯定的。这里需要着重指出一点:即使是在编译期执行的代码,也是需要执行的。这些代码的执行者,实际上是编译器!

我们可以从另一个角度来审视代码的编译过程:我们的源程序就好似一段脚本,而编译器正如脚本的执行者,编译结果则类似脚本的执行结果。从这个角度上来说,编译一段C++代码的过程,与执行一段Python代码没有什么区别,都是需要占用系统资源与运行时间的。如果元函数的复杂度比较高,反复调用就会导致编译用时较长、编译所需内存较多。

另外,将编译的过程与一般脚本的执行过程进行类比并不完全公平。二者虽然有相似之处,但应用场景不同,它们面临的问题也不同。一般的脚本可能会被反复执行,处理的数据量可能较大(可能要以大量的数据作为输入并产生大量的输出),这就对脚本的执行速度产生了相对较高的要求。源代码文件相对较短,同时编译操作的执行频率相对较低(除了开发场景外,一般编译成功之后就不需要再次编译源代码文件了)。因此我们可以对编译器的执行效率有更大的容忍。

但编译器也有编译器的问题,正如我们在第1章所讨论的那样,编译器可能并没有针对元编程引入足够的优化。元函数在执行过程中所产生的实例可能都会保存在编译器的内存中,在整个编译过程中都不会被释放。因此,如果元函数的复杂度较高,可能导致编译器内存超限而编译失败。

对于老式的计算机或32位编译程序来说,这可能是个大问题(32位编译程序能够使用的最大内存容量为4GB,编译复杂的元程序很可能导致内存不足)。当前,主流的计算机是64位的,同时计算机中的内存容量也得到了很大的提高,这能在一定程度上缓解内存不足的问题。但我们依旧需要关注元函数的复杂度,以防在元函数过于复杂、编译项目较大的情况下,编译用时较长或占用内存较多而导致编译失败。

那么,我们要如何衡量元函数的复杂度呢?作为一个普通的C++ 开发者,我们可能对编译器内部的实现原理并不清楚,因此无法做出很精确的估计。但我们至少可以估计出在一个元函数的执行过程中,编译器可能会构造出的实例数,并以此作为元函数复杂度的一种度量:当然,我们希望元函数执行过程中所构造出的实例数越少越好,实例数越多,说明算法越复杂。

让我们回顾一下之前讨论的Size,对于以下的语句:

1    Size<tuple<double, int, char>>

编译器会在执行过程中接收并产生如下的实例:

1    tuple<double, int, char>
2    RemConsRef<tuple<double, int, char>>
3    Size_<RemConsRef<tuple<double, int, char>>>
4    Size_<RemConsRef<tuple<double, int, char>>>::value
5    Size<tuple<double, int, char>>

这些实例可能会被一一构造出来并保存在编译器的内存中。不同的实例对应的构造与存储成本并不相同[3]。但我们在这里并不会考虑这种成本差异的细节,只是对算法的复杂度进行粗略的估计。

现在让我们来看一个相对复杂的算法:数组索引,即给定一个数组,获取其中的第N个元素。

读者可能会感到诧异:这是复杂的算法吗?事实上,可能出乎读者的意料,这可能是我们将要实现的最复杂的算法之一了。对运行期数组进行索引非常简单,这是因为从硬件到软件层面上都对其提供了很好的支持。但在编译期,语言规范对这种操作并没有提供足够的支持,这就可能导致相应算法(或者说相应操作)的复杂度非常高。

让我们首先实现一个基础版本,再来分析一下这个版本的复杂度高在何处。利用第1章讨论的顺序、分支、循环代码的编写方法,我们可以相对容易地实现数组索引,算法如下:

1     template <typename TCont, size_t ID>
2     struct At_;
3
4     template <template<typename...> class TCont,
5               typename TCurType, typename... TTypes, size_t ID>
6     struct At_<TCont<TCurType, TTypes...>, ID>
7     
8         using type = typename At_<TCont<TTypes...>, ID-1>::type;
9     ;
10
11    template <template<typename...> class TCont,
12              typename TCurType, typename... TTypes>
13    struct At_<TCont<TCurType, TTypes...>, 0>
14    
15        using type = TCurType;
16    ;

At_元函数的实现包含了一个声明与两个模板特化。第1~2行的声明表明该元函数接收两个参数,分别对应输入序列与索引值。后两个特化则形成了一个循环逻辑:第一个特化用于匹配索引值不为0的情况——此时系统会将索引值减1,继续下一步循环;第二个特化匹配索引值为0的情况,此时返回当前类型。这个元函数的使用方式很简单,比如typename At_<tuple<double, int, char>, 2>::type的结果为char。

现在让我们粗略地估计一下该元函数的复杂度。以typename At_<tuple<double, int, char >, 2>::type为例,看一下元函数在执行过程中可能产生的实例个数。不难看出,此时编译器会产生如下的一些实例:

1    At_<tuple<double, int, char>, 2>
2    At_<tuple<int, char>, 1>
3    At_<tuple<char>, 0>

读者可能意识到了:编译器所产生的实例个数与输入的索引值成正比。这并不是一个好现象。显然,当输入的索引值比较大时,编译器就会产生大量的实例,这同时意味着更长的编译时间,以及更多的内存占用。

事实上,这种实现还存在另一个问题。通常来说,如果将信息保存成一个数组,那么我们往往需要访问数组不同位置处的元素。考虑tuple<double, int, char> 这个数组,在刚刚获取了索引值为2的元素之后,如果我们希望再次调用该元函数获取索引值为1的元素,那么编译器会产生如下的实例:

1    At_<tuple<double, int, char>, 1>
2    At_<tuple<int, char>, 0>

读者可能已经发现了,这些实例化的结果与之前实例化的结果完全不同!这就意味着虽然编译器可能在内存中保存了之前的实例化结果,但我们无法从之前的实例化结果中获益。进一步,编译器可能会将这些新的实例保存在内存中,进一步增加编译负担。

希望这个示例能让读者体会到一个实现相对较差的元函数可能对编译器产生的不良影响。一个好的元函数实现应该使得实例化的次数尽量少,同时能尽量地复用之前实例化的结果。如果我们仅仅采用第1章所学习的顺序、分支与循环代码的编写方法,显然无法达到这个目的。要想降低元函数的复杂度,就需要求助于一些特别的技巧。我们将在本章的后续部分讨论一些降低复杂度的技巧。同时,我们将在本章的结尾给出一个低复杂度的序列索引算法实现,但本着从易到难的原则,我们将首先讨论一些相对容易掌握的技巧。首先,让我们来看第一类技巧:基于包展开与折叠表达式的优化。

2.2 基于包展开与折叠表达式的优化

我们在第1章讨论循环逻辑的编写方法时,就介绍过包展开与折叠表达式。这两种技巧不仅能简化循环逻辑的编写,同时也能在一些场景中减少编译器所构造的实例个数,也即降低元函数的复杂度。

2.2.1 基于包展开的优化

包展开的一个经典应用就是实现编译期的transform逻辑。transform接收一个序列与某个元函数,对序列中的每个元素调用该元函数进行变换,变换后的结果保存在一个新的序列中返回。

我们当然可以使用基本的循环代码来实现相应的逻辑。但如果采用这种方式,元函数的执行过程中将产生很多实例——实例的个数与序列中元素的个数成正比。比如,假定输入列表为Cont<X1, X2, ...,Xn> , 元函数为F,那么使用基本的循环代码,我们可能构造出以下的实例:

1    Cont<F<X1>>
2    Cont<F<X1>, F<X2>>
3    ...
4    Cont<F<X1>, F<X2>,..., F<Xn>>

如果采用包展开,产生的实例个数就能大大减少。以下给出了基于包展开的transform元函数实现:

1     template <typename TInCont, template <typename> typename F,
2               template<typename...> typename TOutCont>
3     struct Transform_;
4
5     template <template <typename...> typename TInCont,
6               typename... TInputs,
7               template <typename> typename F,
8               template<typename...> typename TOutCont>
9     struct Transform_<TInCont<TInputs...>, F, TOutCont>
10    
11         using type = TOutCont<typename F<TInputs>::type ...>;
12    ;
13
14    template <typename TInCont,
15              template <typename> typename F,
16              template<typename...> typename TOutCont>
17    using Transform = typename Transform_<TInCont, F, TOutCont>::type;

整段代码的核心是第11行,不难发现,这一行通过包展开一次性对序列中的所有元素调用了元函数F。这会减少很多不必要的中间结果。

2.2.2 基于折叠表达式的优化

在一些场景下,使用折叠表达式也会减少实例个数。比如,我们希望实现一个元函数,来判断两个集合是否相等。注意,集合中的元素顺序可以存在差异,所以下面的实现是错误的:

1    template <typename Set1, typename Set2>
2    constexpr bool IsEqual = std::is_same_v<Set1, Set2>;

如果Set1与Set2中元素的顺序不同,那么即使两个集合相等,系统也会返回false。

那么该如何判断两个集合相等呢?我们可以判断每个集合中的任意元素是否属于另一个集合。如果该条件满足,那么两个集合是相等的。假定我们已经实现了一个高效的算法HasKey,用于判断某个元素是否在集合中出现过。这个算法能达到的效果是:多次调用时,只要测试的集合相同,那么所引入的额外的实例化就会非常少[4]:

1    // 首次调用产生一些实例
2    HasKey<tuple<double, char, int>, int>;
3    // 再次调用,测试集相同,只会产生少量的实例
4    HasKey<tuple<double, char, int>, float>;

基于HasKey的实现,我们希望实现一个元函数,来判断两个集合是否相等:

1     template <typename TFirstSet, typename TSecondSet>
2     struct IsEqual_;
3
4     template <template <typename...> class Cont1,
5               template <typename...> class Cont2,
6               typename... Params1, typename... Params2>
7     struct IsEqual_<Cont1<Params1...>, Cont2<Params2...>>
8     
9         constexpr static bool value1
10            = (HasKey<Cont1<Params1...>, Params2> && ...);
11        constexpr static bool value2
12            = (HasKey<Cont2<Params2...>, Params1> && ...);
13        constexpr static bool value = value1 && value2;
14    ;
15
16    template <typename TFirstSet, typename TSecondSet>
17    constexpr bool IsEqual = IsEqual_<TFirstSet, TSecondSet>::value;

上述代码的核心是第9~12行:第9~10行使用了折叠表达式来判断第二个集合中的所有元素都在第一个集合中;第11~12行则判断了第一个集合中的所有元素都在第二个集合中。由于HasKey具有之前所讨论的特性,因此虽然折叠表达式会引入很多的HasKey调用,但根据前文的讨论,由此引入的实例会相对较少。总体来说,这还是一个比较好的实现。

当然,这个实现还是有一些可以优化的空间的。我们在第1章讨论了AndValue元函数,用于实现编译期的短路逻辑,可以将其引入该元函数中:如果value1为false,就不需要再计算value2的值了。这个修改就交给读者完成。

包展开与折叠表达式可以说是较直观易用的元函数优化方法了。在一些场景下,使用包展开与折叠表达式确实能够优化元函数。但正如我们在第1章讨论的那样,包展开与折叠表达式的使用场景非常受限,只能在特殊的场景中使用。同时,使用折叠表达式时还需要小心,如果使用不当,反而会造成元函数执行过程中的实例“爆炸”[5]。同样以“判断两个集合是否相等”为例,如果HasKey在每次调用都会产生较多的实例,那么我们的实现会出现很大问题。因此使用时要多加小心。

2.3 基于操作合并的优化

基于包展开与折叠表达式进行的优化,其本质就是使用C++ 的新语法在一条语句中同时执行多条指令。以包展开为例,语句如下:

1    using type = TOutCont<typename F<TInputs>::type ...>;

在一条语句中遍历TInputs中的每个元素,以其作为输入调用F后将结果一次性放到TOutCont容器中。正是这种在一条语句中处理多条指令的方法帮助我们减少了元函数执行过程中的实例个数。

包展开与折叠表达式的使用场景毕竟有限,但这并不妨碍我们将“一条语句中同时执行多条指令”这一思想应用到无法使用包展开与折叠表达式的场景之中。笔者将采用这种思想进行的优化称为“基于操作合并的优化”。

让我们看一个应用该思想的元函数示例:折叠。说到折叠,相信阅读至此的读者可能会想到折叠表达式。折叠表达式是一种折叠,但就像第1章讨论的那样,折叠表达式只能处理数值,无法处理类型。我们在这里要实现的折叠函数则是针对类型的,它的输入是一个类型序列,输出则是一个类型。

我们的折叠函数将接收如下的参数。

  • TInputCont<TInput1, TInput2, ...,TInputN>:包含了输入序列的容器。
  • TInitState:初始状态。
  • F:元函数,接收两个类型输入,返回一个类型结果。

折叠函数会首先调用F<TInitState, TInput1>::type来产生一个中间结果TRes1,之后调用F<TRes1, TInput2>::type来产生中间结果TRes2,以此类推,最终元函数返回TResN。

基于一般的循环语句编写方法,可以按照如下的方式实现折叠函数:

1     template <typename TState,
2               template <typename, typename> typename F,
3               typename... TRemain>
4     struct imp_
5     
6         using type = TState;
7     ;
8
9     template <typename TState,
10              template <typename, typename> typename F,
11              typename T0, typename... TRemain>
12    struct imp_<TState, F, T0, TRemain...>
13    
14        using type = typename imp_<F<TState, T0>, F, TRemain...>::type;
15    ;
16
17    template <typename TInitState, typename TInputCont, 
18              template <typename, typename> typename F>
19    struct Fold_;
20
21    template <typename TInitState, template<typename...> typename TCont,
22              typename... TParams,
23              template <typename, typename> typename F>
24    struct Fold_<TInitState, TCont<TParams...>, F>
25    
26        template <typename S, typename I>
27        using FF = typename F<S, I>::type;
28 
29        using type = typename imp_<TInitState, FF, TParams...>::type;
30    ;
31
32    template <typename TInitState, typename TInputCont,
33              template <typename, typename> typename F>
34    using Fold = typename Fold_<TInitState, TInputCont, F>::type;

这段代码相对较长,但逻辑并不复杂。让我们按照从外到内的顺序来看。

第32~34行定义了Fold元函数,它本质上是将运算逻辑代理给Fold_元函数执行。第17~30行定义了Fold_元函数。其中第17~19行是该元函数的声明,而第21~30行通过特化表明元函数的第二个参数是一个序列。在此基础上,第26~27行将F进行了转换,这样在后面的调用中,我们就不需要再写::type这样的后缀了。同时第29行调用了imp_来实现计算逻辑。与Fold_相比,imp_不再包含容器模板TCont,这使得代码的编写更加容易。

核心的计算位于imp_之中。imp_包含了一个基本模板与一个特化版本,二者放在一起构成了循环逻辑。特化版本会调用F,输入当前状态与待处理的元素,获取相应的返回值,并以该返回值作为输入,再次调用imp_实现循环(第14行)。imp_的基本模板则会匹配终止循环的情形:如果输入已经全部处理完毕,那么直接返回当前状态TState(第6行)。

不难看出,当调用Fold元函数时,如果输入序列较长,那么imp_的特化版本(第9~15行)会被反复调用,相应地产生多个实例。为了减少实例的产生,我们可以引入操作合并,比如,增加一个新的特化版本:

1    template <typename TState,
2              template <typename, typename> typename F,
3              typename T0, typename T1>
4    struct imp_<TState, F, T0, T1>
5    
6        using type = F<F<TState, T0>, T1>;
7    ;

其中当imp_中待处理的元素个数为2时,系统会选择这个分支——可以看出,这个元函数相当于将两步操作进行了合并,可以减少一些实例的产生。

我们可以进一步使用这种技巧,引入更多的特化版本:

1    template <typename TState,
2              template <typename, typename> typename F,
3              typename T0, typename T1, typename T2>
4    struct imp_<TState, F, T0, T1, T2> ...;

这样,当待处理的元素个数为3时,系统会选择这个特化版本,一次性完成处理。

我们还可以引入更多类似的特化版本,笔者引入了可以同时处理6个元素的特化版本。这样,如果序列中的元素个数小于或等于6,那么就可以一次性完成处理了。

对于序列中的元素个数大于6的情形,我们引入了如下的特化版本:

1     template <typename TState,
2               template <typename, typename> typename F,
3               typename T0, typename T1, typename T2,
4               typename T3, typename T4, typename T5,
5               typename T6, typename... TRemain>
6     struct imp_<TState, F, T0, T1, T2, T3, T4, T5, T6, TRemain...>
7     
8         using type = typename imp_<F<F<F<F<F<F<F<TState, T0>,
9                                                  T1>, T2>, T3>,
10                                                 T4>, T5>, T6>,
11                                   F, TRemain...>::type;
12    ;

也即当序列中的元素个数大于或等于7时,系统会匹配这个特化版本,一次处理7个元素。这相当于将每7步处理合并到一步进行。

与运行期序列相比,元函数所接收的编译期序列都相对较短。如果序列的长度小于或等于7,那么对imp_的一次实例化就可以满足需求。即使序列长度大于7,我们也很有可能仅需要几次对imp_的实例化就可以完成整个序列的处理。

2.4 基于函数重载的索引算法

无论是序列还是集合、映射,它们都包含获取其中的元素的操作(索引操作)。不同之处在于:序列使用整数值作为索引,而映射使用键作为索引。事实上,集合的查找也可以被视为一种索引,输入是键,输出是一个bool值,表示是否找到。无论建立哪种容器,我们都需要通过索引算法获取其中的元素。

我们在小节通过一个示例说明了:索引算法性能较差时会产生大量的编译期实例。因此,一个好的索引算法对元程序库的性能优化至关重要。本节将讨论一种基于函数重载的索引算法,它可以用于为序列、集合与映射建立低复杂度的索引函数。

2.4.1 分摊复杂度

对于一个容器,如果我们只进行一次索引操作,那么通常来说没有什么好的优化方法。对于一个一般的容器来说,我们能做的往往只有依次处理数组中的每个元素,此时算法的复杂度与数组中元素的个数成正比。

幸运的是,通常来说,当建立了一个容器后,我们往往需要多次获取其中的元素。虽然每次访问的索引可能不同,但由于访问的容器是相同的,因此我们可以在首次索引时就构造一些数据结构,以降低后续索引该容器的其他元素时的复杂度。

因此,在为容器设计索引算法时,我们往往考虑的不是某次访问的复杂度,而是要将多次访问同一容器的总体消耗除以访问的次数,估计一个复杂度的平均值。这个平均值也被称为分摊复杂度,而我们的目标是使分摊复杂度尽量小。

在小节中,我们给出了一个朴素的索引算法。该算法的分摊复杂度非常高,因为我们在首次访问容器时,并没有针对容器的特性建立任何可以帮助后续访问的数据结构。相应地,每次访问都需要很高的成本。那么,该如何在索引操作中有效地利用容器本身的信息呢?

2.4.2 容器的重载结构映射

给定一个容器,我们需要将其转换成一种全新的数据结构,建立索引与相应元素的关系,便于根据索引快速获取相应的元素。我们假定元函数要处理的索引与元素都是类型,而要建立类型之间的关系,一种方法就是使用函数:将函数的参数与返回值分别设置为索引类型与元素类型,这样从索引到元素的查找过程就可以映射为一个函数重载的过程。

让我们以映射为例进行讨论。考虑容器tuple<KVBinder<int, unsigned int>, KVBinder< char, unsigned char>>,如果我们能构造出如下的函数声明[6]:

1    unsigned int apply(int*);
2    unsigned char apply(char*);

那么以下的调用:

1    decltype(apply((int*)nullptr)))

将返回unsigned int类型。

decltype关键字用于返回表达式的类型。这里的表达式是apply((int*)nullptr)),它是一个函数调用。编译器在解析到这个语句时,就会触发重载解析机制,寻找与该调用相匹配的函数声明——unsigned int apply(int*),从而确定出该函数会返回一个unsigned int类型的结果。

我们将容器的键-值对转换成了函数的参数-返回值对,并利用重载解析获取了索引所对应的值。这种处理方法本质上是将一个编译器不擅长的问题(获取容器中的元素)转换成一个编译器擅长的问题(重载解析)。重载解析在C++ 的首个标准出现之时就存在了,是C++ 经常会被用到的功能之一,因此几乎每个C++ 编译器都能对其提供高效的支持。

如果我们在容器的首次索引操作时构造出上述重载函数的声明,那么在下一次索引操作时,就可以直接使用该声明获取对应的元素:这几乎不需要引入任何额外的成本。比如,后续想获取char类型对应的元素,那么只需要调用decltype(apply((char*)nullptr)))。

2.4.3 构造重载结构

这里以映射为例讨论如何基于容器构造重载结构。

一个映射可以表示为KVBinder元素序列的形式,而一个KVBinder中又包含了两个类型,分别表示键与值。在本小节中,我们假定容器中每个KVBinder的键都是唯一的,即不会出现类似tuple<KVBinder<int, int>, KVBinder<int, char>>的情况。

让我们回顾一下KVBinder的定义:

1    template <typename TK, typename TV>
2    struct KVBinder
3    
4        using KeyType = TK;
5        using ValueType = TV;
6        static TV apply(TK*);
7    ;

可以看到,其中已经包含了一个函数声明apply,而这个函数声明正是构造重载结构的关键。在此基础上,可以通过如下的元函数将一个映射容器转换为相应的重载结构:

1    template <typename TCon, typename TDefault>
2    struct map_;
3
4     template <template <typename... > typename TCon, typename...TItem,
5               typename TDefault>
6     struct map_<TCon<TItem...>, TDefault> : TItem...
7     
8         using TItem::apply ...;
9         static TDefault apply(...);
10    ; 

这里使用了C++17所提供的语法:第6行表明map_结构体模板派生自容器中的每个元素(每个KVBinder),而第8行则表示将每个KVBinder中的apply声明添加到map_的接口中。这样,map_中就相当于包含了一组名为apply的函数,每个函数的参数都对应KVBinder中的键,而每个函数的返回值都对应KVBinder中的值类型。由于我们假设KVBinder中的键没有重复,因此这组函数声明是合法的。

代码第9行则声明了一个额外的apply函数,用于匹配搜索键为空的情况:此时对应的值类型为TDefault。

以容器tuple<KVBinder<int, unsigned int>, KVBinder<char, unsigned char>>为例:在以该容器为map_结构体模板的输入参数时,系统将构造出如下的实例:

1    struct map_ : KVBinder<int, unsigned int>,
2                  KVBinder<char, unsigned char>
3    
4        static unsigned int apply(int*);
5        static unsigned char apply(char*);
6        static TDefault      apply(...);
7    ;

2.4.4 索引元函数

在引入了map_模板的基础上,我们就可以构造元函数实现索引的功能了:

1    template <typename TCon, typename TKey, typename TDefault>
2    struct Find_
3    
4         using type = decltype(map_<TCon, TDefault>::apply((TKey*)nullptr));
5    ;
6
7    template <typename TCon, typename TKey, typename TDefault = void>
8    using Find = typename Find_<TCon, TKey, TDefault>::type; 

其中的第4行在首次被调用时会构造map_<TCon, TDefault>,并通过其apply成员获取相应的值类型。再次调用时,由于map_<TCon, TDefault>已经被构造过了,因此不会再次构造:相应的查询只需要一次重载解析即可。可以说,只要完成了对某个映射的首次查询,再次查询的成本是非常低的。对同一映射多次查询,其分摊复杂度就会很低了。

2.4.5 允许重复键

前文讨论的方法有一个假设前提:KVBinder中的键没有重复。如果这个前提不成立,那么系统的运行会出问题。如果对容器tuple<KVBinder<char, unsigned int>, KVBinder <char, unsigned char>>采用前文讨论的方法,那么构造出的map_会出现具有相同签名的函数:

1    struct map_ : KVBinder<int, unsigned int>,
2                  KVBinder<char, unsigned char>
3    
4        static unsigned int apply(char*);
5        static unsigned char apply(char*);
6        static void          apply(...);
7    ;

显然,此时的重载解析会出现错误。

要求容器中的键没有重复这一条件实际上是比较苛刻的。可以想象,为了满足这一条件,我们需要在映射的插入操作中引入额外的检测逻辑。那么,能否放宽相应的限制呢?答案是肯定的。从概念的角度上来说,映射的键应当是没有重复的。但在数据结构的表示上,我们可以允许容器中的键存在重复,如果出现重复的键,那么第一个出现的键是有效的。比如tuple<KVBinder<char, unsigned int>, KVBinder<char, unsigned char>> 实际上等价于tuple<KVBinder<char, unsigned int>>,也即当容器中存在两个KVBinder相同的键时,只有第一个键会起作用。

采用这种设计时,如果我们希望向映射中插入新的键-值对,只需要在映射的开头添加一个KVBinder的实例,不需要关注相同的键是否出现过,因此插入操作的效率会提升不少。

但有利必有弊,由于容器中可能存在重复的键,因此我们需要引入更复杂的索引函数在键出现重复时能够进行选择:

1     template <typename TCont, typename TDefault>
2     struct map_;
3
4     template <template<typename...> class TCont, typename TDefault,
5               typename TCurItem, typename... TRemainItems>
6     struct map_<TCont<TCurItem, TRemainItems...>, TDefault>
7         : TCurItem, map_<TCont<TRemainItems...>, TDefault>
8     
9         using TCurItem::apply;
10
11        template <typename T>
12        static auto apply(T ptr)
13        
14            return map_<TCont<TRemainItems...>>::apply(ptr);
15        
16    ;
17
18    template <template<typename...> class TCont, typename TDefault>
19    struct map_<TCont<>, TDefault>
20    
21        static TDefault apply(...);
22    ; 

可以看出,这个map_实现相较之前的版本复杂了不少。

这个map_的实现本质上引入了一个继承体系(第6~7行)。举例如下。

  • map_<tuple<KVBinder<K, V1>, KVBinder<K, V2>>, TDefault>继承自KVBinder<K, V1>与map_<tuple<KVBinder<K, V2>>, TDefault>。
  • map_<tuple<KVBinder<K, V2>>, TDefault> 继承自KVBinder<K, V2> 与map_<tuple<>, TDefault>。

在几乎每个map_的实现内部都包含了两个apply函数。其中一个函数的声明与KVBinder实例相关;另一个函数则是函数模板。根据匹配规则,如果同时存在模板与非模板的匹配函数,那么编译器会首先选择非模板的匹配函数版本,我们正是靠这一规则实现了相同键时值的选择。

我们还是通过一些示例来理解上述代码。考虑如下的调用:

1    using CheckMap = map_<tuple<KVBinder<int, unsigned int>,
2                                KVBinder<int, char>>, void>;
3    using Res = decltype(CheckMap::apply((int*)nullptr));

系统首先会在map_<tuple<KVBinder<int, unsigned int>, KVBinder<int, char>>, void>所提供的两个apply函数,也即unsigned int apply(int*)与apply模板之间进行选择。由于模板的优先级较低,因此编译器会选择普通函数的版本。相应地,Res所对应的值为unsigned int。

现在换一个调用:

1    using CheckMap = map_<tuple<KVBinder<int, unsigned int>,
2                                KVBinder<int, char>>, void>;
3    using Res = decltype(CheckMap::apply((double*)nullptr));

系统首先在map_<tuple<KVBinder<int, unsigned int>, KVBinder<int, char>>, void>所提供的两个apply函数,也即unsigned int apply(int*)与apply模板之间进行选择。由于非函数模板的参数类型不匹配,因此系统会选择函数模板。而函数模板的返回类型为auto,因此系统会根据其内部语句来确定其返回值。

这个内部语句会调用map_<tuple<KVBinder<int, char>>, void>的apply函数,这又引入了两个选择:char apply(int*)与apply模板二选一。由于非函数模板的类型参数不匹配,因此系统会选择函数模板。而函数模板的返回类型为auto,因此系统会根据其内部语句来确定其返回值。

此时,内部语句相当于选择了map_<tuple<>, void> 的apply函数。这个函数返回void,因此整个求值过程将返回void。

可以看出,整个选择过程是相对复杂的。同时,在执行map_的过程中所构造的实例数也要多于上一个map_版本所构造的实例数(至少map_<TCont<TRemainItems...>, TDefault> 在上一个版本中是不会被产生的)。因此,从本质上来说,我们相当于牺牲了查询操作的复杂度,但降低了插入操作的复杂度。

事实上,如果采用这种方式,删除操作的复杂度也可以被降低。根据我们的实现,如果一个键不存在,那么应当返回默认值TDefault。因此,我们可以将删除键“Key”的操作简单实现为一个插入KVBinder<Key, TDefault>的操作即可。

以上,我们讨论了两种映射的实现方式。这两种映射一种应用于查询较多,而插入、删除较少的情形;一种应用于插入、删除较多,而查询相对较少的情形。具体选择哪种方式,要根据实际应用而定。在我们的深度学习框架中采用了前一种实现方式。

本文截选自《动手打造深度学习框架》

本书基于C++编写,旨在带领读者动手打造出一个深度学习框架。本书首先介绍C++模板元编程的基础技术,然后在此基础上剖析深度学习框架的内部结构,逐一实现深度学习框架中的各个组件和功能,包括基本数据结构、运算与表达模板、基本层、复合层、循环层、求值与优化等,最终打造出一个深度学习框架。本书将深度学习框架与C++模板元编程有机结合,更利于读者学习和掌握使用C++开发大型项目的方法。

本书适合对C++有一定了解,希望深入了解深度学习框架内部实现细节,以及提升C++程序设计水平的读者阅读。

 

以上是关于动手打造深度学习框架:元数据结构与算法的主要内容,如果未能解决你的问题,请参考以下文章

自己动手实现深度学习框架-3 自动分批训练, 缓解过拟合

动手学深度学习v2学习笔记02:线性代数矩阵计算自动求导

对比《动手学深度学习》 PDF代码+《神经网络与深度学习 》PDF

《动手学深度学习》线性回归从零开始(linear-regression-scratch)

《动手学深度学习》线性回归从零开始(linear-regression-scratch)

元宇宙企业大比拼:云宇宙数据中台:iwemeta.com