理解这种计算方式 (x^e mod n)

Posted

技术标签:

【中文标题】理解这种计算方式 (x^e mod n)【英文标题】:Understanding this way of computing ( x^e mod n) 【发布时间】:2020-10-09 05:03:42 【问题描述】:

在 c++ 中进行 RSA 加密,发现 cmath 中的 pow() 函数给了我不正确的结果。

在网上查看后,我发现了一些可以为我完成上述过程的代码,但我很难理解。

这是代码:

long long int modpow(long long int base, long long int exp, long long int modulus) 
    base %= modulus;
    long long int result = 1;
  while (exp > 0) 
    if (exp & 1) 
    result = (result * base) % modulus; 
    
    base = (base * base) % modulus;
    exp >>= 1; 
  
  return result;

(此代码不是原始代码)

我很难理解这个功能。

我知道exp >>=1; 是左移 1 位,(exp & 1) 根据最低有效位返回 1 或 0,但我不明白这对最终答案有何影响。

例如:

if (exp & 1) 
    result = (result * base) % modulus; 
        

如果 exp 是奇数,(result * base) % modulus 的用途是什么?

希望有人可以向我解释这个功能,因为我不想只是复制它。

【问题讨论】:

当您不理解它时,一个有用的练习是使用调试器或使用老式笔和纸逐步完成它。当您将来遇到这样的代码时,这也将有所帮助。 en.wikipedia.org/wiki/Exponentiation_by_squaring 请注意,result * basebase * base 都受到有符号整数溢出 (UB) 的影响。完整范围alternatives 存在。 【参考方案1】:

代码被编写为“聪明”而不是清晰。这种神秘的风格通常不应该在核心库之外进行(性能至关重要),即使使用它,让 cmets 解释正在发生的事情也会很好。 这是代码的注释版本。

long long int modpow(long long int base, long long int exp, long long int modulus)

  base %= modulus;                          // Eliminate factors; keeps intermediate results smaller
  long long int result = 1;                 // Start with the multiplicative unit (a.k.a. one)
                                            // Throughout, we are calculating:
                                            // `(result*base^exp)%modulus`
  while (exp > 0)                          // While the exponent has not been exhausted.
    if (exp & 1)                           // If the exponent is odd
      result = (result * base) % modulus;   // Consume one application of the base, logically
                                            // (but not actually) reducing the exponent by one.
                                            // That is: result * base^exp == (result*base)*base^(exp-1)
    
    base = (base * base) % modulus;         // The exponent is logically even. Apply B^(2n) == (B^2)^n.
    exp >>= 1;                              // The base is squared and the exponent is divided by 2 
  
  return result;

现在更有意义了吗?


对于那些想知道如何使这个神秘代码更清晰的人,我提供以下版本。主要有三项改进。

首先,按位运算已被等效的算术运算取代。如果要证明该算法有效,则将使用算术,而不是按位运算符。事实上,无论数字如何表示,该算法都有效——不需要“位”的概念,更不用说“位运算符”了。因此,实现该算法的自然方法是使用算术。使用按位运算符会消除清晰度,几乎没有任何好处。编译器足够聪明,可以生成相同的机器代码,但有一个例外。由于exp 被声明为long long int 而不是long long unsigned,因此与exp >>= 1 相比,计算exp /= 2 时有一个额外的步骤。 (我不知道为什么 exp 被签名;对于负指数,该函数在概念上和技术上都是不正确的。)另见 premature-optimization。

其次,我创建了一个帮助函数以提高可读性。虽然改进很小,但它没有性能成本。我希望任何有价值的编译器都能内联该函数。

// Wrapper for detecting if an integer is odd.
bool is_odd(long long int n)

    return n % 2 != 0; 

第三,添加了 cmets 来解释发生了什么。虽然有些人(不是我)可能认为“标准的从右到左模块化二进制求幂算法”是每个 C++ 编码人员都需要的知识,但我更愿意对可能阅读的人做更少的假设我将来的代码。尤其是如果那个人是我,在远离它多年后回到代码。

现在是代码,因为我希望看到编写的当前功能:

// Returns `(base**exp) % modulus`, where `**` denotes exponentiation.
// Assumes `exp` is non-negative.
// Assumes `modulus` is non-zero.
// If `exp` is zero, assumes `modulus` is neither 1 nor -1.
long long int modpow(long long int base, long long int exp, long long int modulus)

    // NOTE: This algorithm is known as the "right-to-left binary method" of
    // "modular exponentiation".

    // Throughout, we'll keep numbers smallish by using `(A*B) % C == ((A%C)*B) % C`.
    // The first application of this principle is to the base.
    base %= modulus;
    // Intermediate results will be stored modulo `modulus`.
    long long int result = 1;

    // Loop invariant:
    //     The value to return is `(result * base**exp) % modulus`.
    // Loop goal:
    //     Reduce `exp` to the point where `base**exp` is 1.
    while (exp > 0) 
        if ( is_odd(exp) ) 
            // Shift one factor of `base` to `result`:
            // `result * base^exp == (result*base) * base^(exp-1)`
            result = (result * base) % modulus;
            //--exp;  // logically happens, but optimized out.
            // We are now in the "`exp` is even" case.
        
        // Reduce the exponent by increasing the base: `B**(2n) == (B**2)**n`.
        base = (base * base) % modulus;
        exp /= 2;
    

    return result;

生成的机器代码几乎相同。如果性能真的很关键,我可以回到 exp >>= 1,但前提是不允许更改 exp 的类型。

【讨论】:

对我来说看起来像是标准的从右到左的模块化二进制求幂算法。有什么不清楚的地方,或者更确切地说,如何才能更清楚? @PresidentJamesK.Polk 经过考虑,我决定采纳您的建议。我补充了我的回答,描述了不清楚的地方以及如何使其更清楚。 n % 2 != 0 是一个很好的奇数测试 - 很容易被好的编译器优化。当n < 0 时,n & 1 是一个不正确的 odd 测试,用于长期丢失的补码机器。 角细节:modpow(base, 0, 1),错误返回1。可以使用long long int result = 1 % modulus; @chux-ReinstateMonica 我故意没有对问题中的代码进行更正,只是澄清了该代码的作用。当我决定展示一种更清晰的编码风格时,我扩展了我的目标,包括记录该代码没有做什么,即不正确的情况。我相信您指出的情况 modpow(base, 0, 1) 包含在该文档中(“如果 exp 为零,则假定 modulus 既不是 1 也不是 -1。”)。

以上是关于理解这种计算方式 (x^e mod n)的主要内容,如果未能解决你的问题,请参考以下文章

使用递归和 mod 计算指数的算法

有限域F_p上的椭圆曲线E:y^2=x^3+ax+b(mod p)的Mordell-Weil群的计算

LayerNorm的理解

LayerNorm的理解

RSA加解密原理以及三种填充模式

深入理解计算机系统笔记