模板函数重载和 SFINAE 实现

Posted

技术标签:

【中文标题】模板函数重载和 SFINAE 实现【英文标题】:Template function overloading and SFINAE implementations 【发布时间】:2018-08-16 17:38:38 【问题描述】:

我正在花一些时间学习如何在 C++ 中使用模板。我从来没有用过它们 之前,我并不总是确定在不同的情况下什么可以实现,什么不能实现。

作为一个练习,我包装了一些我在活动中使用的 Blas 和 Lapack 函数, 我目前正在处理?GELS 的包装(用于评估线性方程组的解)。

 A x + b = 0

?GELS 函数(仅用于实数值)有两个名称:SGELS,用于单精度向量和 DGELS 双精度。

我对接口的想法是这样一个函数solve

 const std::size_t rows = /* number of rows for A */;
 const std::size_t cols = /* number of cols for A */;
 std::array< double, rows * cols > A =  /* values */ ;
 std::array< double, ??? > b =  /* values */ ;  // ??? it can be either
                                                  // rows or cols. It depends on user
                                                  // problem, in general
                                                  // max( dim(x), dim(b) ) =
                                                  // max( cols, rows )     
 solve< double, rows, cols >(A, b);
 // the solution x is stored in b, thus b 
 // must be "large" enough to accomodate x

根据用户要求,问题可能是过度确定的或未确定的,这意味着:

如果超定dim(b) &gt; dim(x)(解是伪逆) 如果未确定dim(b) &lt; dim(x)(解决方案是 LSQ 最小化) 或者dim(b) = dim(x)的正常情况(解是A的倒数)

(不考虑奇异情况)。

由于?GELS 将结果存储在输入向量b 中,因此std::array 应该 有足够的空间来容纳解决方案,如代码 cmets (max(rows, cols)) 中所述。

我想(编译时间)确定采用哪种解决方案(这是一个参数更改 在?GELS 电话中)。我有两个功能(为了这个问题,我正在简化), 处理精度并且已经知道b 的维度和rows/cols 的数量:

namespace wrap 

template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<float, rows * cols> & A, std::array<float, dimb> & b) 
  SGELS(/* Called in the right way */);


template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<double, rows * cols> & A, std::array<double, dimb> & b) 
  DGELS(/* Called in the right way */);


; /* namespace wrap */

它们是内部包装的一部分。用户函数,确定所需尺寸 在b 向量中通过模板:

#include <type_traits>

/** This struct makes the max between rows and cols */
template < std::size_t rows, std::size_t cols >
struct biggest_dim 
  static std::size_t const value = std::conditional< rows >= cols, std::integral_constant< std::size_t, rows >,
                                                     std::integral_constant< std::size_t, cols > >::type::value;
;

/** A type for the b array is selected using "biggest_dim" */
template < typename REAL_T, std::size_t rows, std::size_t cols >
using b_array_t = std::array< REAL_T, biggest_dim< rows, cols >::value >;

/** Here we have the function that allows only the call with b of
 *  the correct size to continue */
template < typename REAL_T, std::size_t rows, std::size_t cols >
void solve(std::array< REAL_T, cols * rows > & A, b_array_t< REAL_T, cols, rows > & b) 
  static_assert(std::is_floating_point< REAL_T >::value, "Only float/double accepted");
  wrap::solve< rows, cols, biggest_dim< rows, cols >::value >(A, b);
 

以这种方式它确实有效。但我想更进一步,我真的不知道如何去做。 如果用户尝试使用大小太小的b 调用solve,编译器会引发极其难以阅读的错误。

我正在尝试插入 static_assert 帮助用户理解他的错误。但任何出现在我脑海中的方向 需要使用具有相同签名的两个函数(这就像模板重载?) 我找不到 SFINAE 策略(实际上它们根本无法编译)。

您认为对于错误b 维度编译时 不更改用户界面 的情况,是否可以提出静态断言? 我希望这个问题足够清楚。

@Caninonos:对我来说,用户界面就是用户调用求解器的方式,即:

 solve< type, number of rows, number of cols > (matrix A, vector b)

这是我对锻炼施加的限制,以提高我的技能。这意味着,我不知道是否真的可以实现解决方案。 b 的类型必须与函数调用匹配,如果我添加另一个模板参数并更改用户界面很容易,违反了我的约束。

最小的完整和工作示例

这是一个最小的完整且有效的示例。根据要求,我删除了对线性代数概念的任何引用。这是个数的问题。案例是:

N1 = 2, N2 =2。由于N3 = max(N1, N2) = 2 一切正常 N1 = 2, N2 =1。由于N3 = max(N1, N2) = N1 = 2 一切正常 N1 = 1, N2 =2。由于N3 = max(N1, N2) = N2 = 2 一切正常 N1 = 1, N2 =2。由于N3 = N1 = 1 &lt; N2 它正确地引发了编译错误。我想用一个静态断言来拦截编译错误,该断言解释了N3 的维度是错误的事实。就目前而言,这个错误很难阅读和理解。

你可以view and test it online here

【问题讨论】:

我不太确定我是否理解您所说的“在编译时更改用户界面”是什么意思,但也许您只是想在您的第一个版本的solve 中添加一个static_assert(dimb == biggest_dim&lt; rows, cols &gt;::value, "msg") ?跨度> 是你的colsrows constexpr 吗? @Caninonos 起初我认为它不起作用,但这是一个好主意。我只需要改变我看待问题的方式...... @W.F.是的,他们是 你能把这个问题减少到最小的模板问题吗?现在它似乎陷入了所有这些线性代数的泥潭 【参考方案1】:

首先进行一些简化设计并提高可读性的改进:

不需要biggest_dimstd::max 从 C++14 开始就是 constexpr。你应该改用它。

不需要b_array_t。你可以写 std::array&lt; REAL_T, std::max(N1, N2)&gt;

现在解决您的问题。 C++17 中的一种好方法是:

template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) 

    if constexpr (N3 == std::max(N1, N2))
        wrap::internal< N1, N2, N3 >(A, b);
    else
        static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");

        // don't write static_assert(false)
        // this would make the program ill-formed (*)
 

或者,正如@max66 所指出的那样

template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) 

    static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");

    if constexpr (N3 == std::max(N1, N2))
        wrap::internal< N1, N2, N3 >(A, b);

 

Tadaa!! 简单、优雅、漂亮的错误信息。

constexpr if 版本与 static_assert 即版本之间的区别:

void solve(...)

   static_assert(...);
   wrap::internal(...);

仅使用static_assert 编译器将尝试实例化wrap::internal,即使static_assert 失败,也会污染错误输出。如果对 wrap::internal 的调用不是主体的一部分,则使用 constexpr 条件失败,因此错误输出是干净的。


(*) 我不只是写static_asert(false, "error msg) 的原因是因为那会使程序格式错误,不需要诊断。见constexpr if and static_assert


如果需要,您还可以通过将模板参数移到不可扣除的参数之后,使 float / double 可扣除:

template < std::size_t N1, std::size_t N2, std::size_t N3,  typename REAL_T>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) 

所以调用变成:

solve< n1_3, n2_3>(A_3, b_3);

【讨论】:

或更简单的:wrap::internal&lt; N1, N2, N3 &gt;(A, b); @max66 是的,确实如此。谢谢。 为了避免static_assert() UB 风险问题(并使解决方案符合C++14),您可以避免if 本身:函数的主体可以简单地为 static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension"); wrap::internal&lt; N1, N2, N3 &gt;(A, b); 是的...是一个很好的理由...但是您可以仅为wrap::internal维护if constexpr并从中释放static_assert() static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension"); if constexpr (N3 == std::max(N1, N3)) wrap::internal&lt; N1, N2, N3 &gt;(A, b); 非常好的答案。但是你改变了求解函数的模板,所以用户界面也改变了,不是吗? (我不是快速推理模板......可能我错了......)【参考方案2】:

您为什么不尝试将tag dispatch 与一些static_asserts 结合起来?我希望,以下是实现您想要解决的问题的一种方法。我的意思是,所有三个正确的案例都正确地传递给正确的blas 调用,处理了不同的类型和尺寸不匹配,并且还处理了关于floatdoubles 的违规行为,所有这些都以用户友好的方式进行,感谢static_assert

编辑。我不确定您的C++ 版本要求,但下面是C++11 友好。

#include <algorithm>
#include <iostream>
#include <type_traits>

template <class value_t, int nrows, int ncols> struct Matrix ;
template <class value_t, int rows> struct Vector ;

template <class value_t> struct blas;

template <> struct blas<float> 
  static void overdet(...)  std::cout << __PRETTY_FUNCTION__ << std::endl; 
  static void underdet(...)  std::cout << __PRETTY_FUNCTION__ << std::endl; 
  static void normal(...)  std::cout << __PRETTY_FUNCTION__ << std::endl; 
;

template <> struct blas<double> 
  static void overdet(...)  std::cout << __PRETTY_FUNCTION__ << std::endl; 
  static void underdet(...)  std::cout << __PRETTY_FUNCTION__ << std::endl; 
  static void normal(...)  std::cout << __PRETTY_FUNCTION__ << std::endl; 
;

class overdet ;
class underdet ;
class normal ;

template <class T1, class T2, int nrows, int ncols, int dim>
void solve(const Matrix<T1, nrows, ncols> &lhs, Vector<T2, dim> &rhs) 
  static_assert(std::is_same<T1, T2>::value,
                "lhs and rhs must have the same value types");
  static_assert(dim >= nrows && dim >= ncols,
                "rhs does not have enough space");
  static_assert(std::is_same<T1, float>::value ||
                std::is_same<T1, double>::value,
                "Only float or double are accepted");
  solve_impl(lhs, rhs,
             typename std::conditional<(nrows < ncols), underdet,
             typename std::conditional<(nrows > ncols), overdet,
                                                        normal>::type>::type);


template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, underdet) 
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::underdet(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);


template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, overdet) 
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::overdet(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);


template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, normal) 
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::normal(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);


int main() 
  /* valid types */
  Matrix<float, 2, 4> A1;
  Matrix<float, 4, 4> A2;
  Matrix<float, 5, 4> A3;
  Vector<float, 4> b1;
  Vector<float, 5> b2;
  solve(A1, b1);
  solve(A2, b1);
  solve(A3, b2);

  Matrix<int, 4, 4> A4;
  Vector<int, 4> b3;
  // solve(A4, b3); // static_assert for float & double

  Matrix<float, 4, 4> A5;
  Vector<int, 4> b4;
  // solve(A5, b4); // static_assert for different types

  // solve(A3, b1); // static_assert for dimension problem

  return 0;

【讨论】:

这是一个很酷的答案,而且非常复杂(至少对我而言)。它超出了我的范围(目前),我需要几个小时才能完全理解它。但不幸的是,@bolov 答案是最容易适应我当前代码库的答案......虽然我赞成你的答案,但很抱歉我不能奖励你更多...... 谢谢@MatteoRagni 的支持 :) 上面的代码摘录,除了MatrixVector 类型,是我用于我的项目的。我需要包装blaslapack 电话。在static_assertions 之后的代码依赖于零大小的类型来分派到相应的函数调用。 solve_impl 末尾的类型只是在编译时确定了正确的函数调用,而零大小的对象被优化掉了。当然,您不应该将lhsrhs 直接通过管道传递给blas&lt;T&gt;::* 函数,因为您不希望为每种类型编译这些函数【参考方案3】:

你必须考虑为什么接口提供了这个(复杂的)参数。作者想到了几件事。首先,您可以在一个函数中解决A x + b == 0A^T x + b == 0 形式的问题。其次,给定的Ab 实际上可以指向比算法所需的矩阵更大的内存。这可以通过LDALDB 参数看出。

正是子寻址使事情变得复杂。如果您想要一个简单但可能足够有用的 API,您可以选择忽略该部分:

using ::std::size_t;
using ::std::array;

template<typename T, size_t rows, size_t cols>
using matrix = array<T, rows * cols>;

enum class TransposeMode : bool 
  None = false, Transposed = true
;

// See https://***.com/questions/14637356/
template<typename T> struct always_false_t : std::false_type ;
template<typename T> constexpr bool always_false_v = always_false_t<T>::value;

template < typename T, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
    , TransposeMode mode = TransposeMode::None >
void solve(matrix<T, rowsA, colsA>& A, matrix<T, rowsB, colsB>& B)

  // Since the algorithm works in place, b needs to be able to store
  // both input and output
  static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
  // LDA = rowsA, LDB = rowsB
  if constexpr (::std::is_same_v<T, float>) 
    // SGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
   else if constexpr (::std::is_same_v<T, double>) 
    // DGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
   else 
    static_assert(always_false_v<T>, "Unknown type");
  


现在,使用LDALDB 寻址可能的子地址。我建议您将其作为数据类型的一部分,而不是直接作为模板签名的一部分。您希望拥有自己的矩阵类型,可以引用矩阵中的存储。也许是这样的:

// Since we store elements in a column-major order, we can always 
// pretend that our matrix has less columns than it actually has
// less rows than allocated. We can not equally pretend less rows
// otherwise the addressing into the array is off.
// Thus, we'd only four total parameters:
// offset = columnSkipped * actualRows + rowSkipped), actualRows, rows, cols
// We store the offset implicitly by adjusting our begin pointer
template<typename T, size_t rows, size_t cols, size_t actualRows>
class matrix_view  // Name derived from string_view :)
  static_assert(actualRows >= rows);
  T* start;
  matrix_view(T* start) : start(start) 
  template<typename U, size_t r, size_t c, size_t ac>
  friend class matrix_view;
public:
  template<typename U>
  matrix_view(matrix<U, rows, cols>& ref)
  : start(ref.data())  

  template<size_t rowSkipped, size_t colSkipped, size_t newRows, size_t newCols>
  auto submat() 
    static_assert(colSkipped + newCols <= cols, "can only shrink");
    static_assert(rowSkipped + newRows <= rows, "can only shrink");
    auto newStart = start + colSkipped * actualRows + rowSkipped;
    using newType = matrix_view<T, newRows, newCols, actualRows>
    return newType newStart ;
  
  T* data() 
    return start;
  
;

现在,您需要调整您的界面以适应这种新的数据类型,这基本上只是引入了一些新参数。支票基本保持不变。

// Using this instead of just type-defing allows us to use deducation guides
// Replaces: using matrix = std::array from above
template<typename T, size_t rows, size_t cols>
class matrix 
public:
    std::array<T, rows * cols> storage;
    auto data()  return storage.data(); 
    auto data() const  return storage.data(); 
;

extern void dgels(char TRANS
  , integer M, integer N , integer NRHS
  , double* A, integer LDA
  , double* B, integer LDB); // Mock, missing a few parameters at the end
// Replaces the solve method from above
template < typename T, size_t rowsA, size_t colsA, size_t actualRowsA
    , size_t rowsB, size_t colsB, size_t actualRowsB
    , TransposeMode mode = TransposeMode::None >
void solve(matrix_view<T, rowsA, colsA, actualRowsA> A, matrix_view<T, rowsB, colsB, actualRowsB> B)

    static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
    char transMode = mode == TransposeMode::None ? 'N' : 'T';
    // LDA = rowsA, LDB = rowsB
    if constexpr (::std::is_same_v<T, float>) 
      fgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
     else if constexpr (::std::is_same_v<T, double>) 
      dgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
    // DGELS(, ....);
     else 
    static_assert(always_false_v<T>, "Unknown type");
    

示例用法:

int main() 
  matrix<float, 5, 5> A;
  matrix<float, 4, 1> b;

  auto viewA = matrix_viewA.submat<1, 1, 4, 4>();
  auto viewb = matrix_viewb;
  solve(viewA, viewb);
  // solve(viewA, viewb.submat<1, 0, 2, 1>()); // Error: b is too small
  // solve(matrix_viewA, viewb.submat<0, 0, 5, 1>()); // Error: can only shrink (b is 4x1 and can not be viewed as 5x1)

【讨论】:

我要谢谢你。这是一个很好的答案,解释得很好,而且非常完整。不幸的是,@bolov 给出了一个非常适合我的情况的答案,并且几乎不需要零努力即可实施。不过,你有我的赞成票。对不起,我不能奖励你更多...... @MatteoRagni 我很感激。我敢肯定,从长远来看,博洛夫的回答会帮助更多的人,至少如果他们遇到这个问题的话。我的回答更像是我自己想出一种方法来调整相关界面 @MatteoRagni 我添加了一个示例用法,如果您仍然感兴趣,还可以更改一些代码。如果您需要更多幕后提供的功能,您可能需要重新考虑 谢谢!它实际上增加了答案的清晰度!

以上是关于模板函数重载和 SFINAE 实现的主要内容,如果未能解决你的问题,请参考以下文章

SFINAE 与 type_traits

类模板构造函数中的 SFINAE

专门化模板成员函数 [SFINAE]

具有模板化类成员函数的 SFINAE

实验2:函数重载函数模板简单类的定义和实现

C++模板进阶指南:SFINAE