C++ 中高数的模幂运算

Posted

技术标签:

【中文标题】C++ 中高数的模幂运算【英文标题】:Modular Exponentiation for high numbers in C++ 【发布时间】:2011-01-13 12:02:18 【问题描述】:

因此,我最近一直致力于实施 Miller-Rabin 素性检验。我将它限制在所有 32 位数字的范围内,因为这是一个有趣的项目,我正在做它来熟悉 c++,并且我不想使用任何 64 位一会儿。另一个好处是该算法对所有 32 位数字都是确定性的,因此我可以显着提高效率,因为我确切知道要测试哪些见证人。

因此,对于较小的数字,该算法运行得非常好。但是,该过程的一部分依赖于模幂运算,即 (num ^ pow) % mod。所以,例如,

3 ^ 2 % 5 = 
9 % 5 = 
4

这是我用于模幂运算的代码:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)

    unsigned test;
    for(test = 1; pow; pow >>= 1)
    
        if (pow & 1)
            test = (test * num) % mod;
        num = (num * num) % mod;
    

    return test;


正如您可能已经猜到的那样,当参数都是非常大的数字时会出现问题。例如,如果我想测试数字 673109 的素性,我将不得不找到:

(2 ^ 168277) % 673109

现在 2 ^ 168277 是一个非常大的数字,在某个过程中它会溢出测试,从而导致错误的评估。

在反面,参数如

4000111222 ^ 3 % 1608

也出于同样的原因,评估不正确。

是否有人对模幂运算提出建议,以防止这种溢出和/或操纵它以产生正确的结果? (在我看来,溢出只是取模的另一​​种形式,即 num % (UINT_MAX+1))

【问题讨论】:

【参考方案1】:

两件事:

您是否使用了合适的数据类型?换句话说,UINT_MAX 是否允许您将 673109 作为参数?

不,它没有,因为在某一时刻你的代码不起作用,因为在某一时刻你有 num = 2^16num = ... 导致溢出。使用更大的数据类型来保存这个中间值。

如何在每个可能的溢出机会上取模,例如:

test = ((test % mod) * (num % mod)) % mod;

编辑:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)

    unsigned long long test;
    unsigned long long n = num;
    for(test = 1; pow; pow >>= 1)
    
        if (pow & 1)
            test = ((test % mod) * (n % mod)) % mod;
        n = ((n % mod) * (n % mod)) % mod;
    

    return test; /* note this is potentially lossy */


int main(int argc, char* argv[])


    /* (2 ^ 168277) % 673109 */
    printf("%u\n", mod_pow(2, 168277, 673109));
    return 0;

【讨论】:

【参考方案2】:

您可以使用以下身份:

(a * b) (mod m) === (a (mod m)) * (b (mod m)) (mod m)

尝试直接使用它并逐步改进。

    if (pow & 1)
        test = ((test % mod) * (num % mod)) % mod;
    num = ((num % mod) * (num % mod)) % mod;

【讨论】:

感谢两位的建议,但根据算法的性质,test 和 num 总是小于 mod,所以: (test % mod) = test 和 (num % mod ) = test 因此身份无法帮助我,因为即使 num 和 test 小于 mod,该函数也会失败。此外,无符号整数允许我将 673109 作为参数。 UINT_MAX = 4 294 967 295 对于我的计算机。【参考方案3】:

Exponentiation by squaring 仍然“有效”用于模幂运算。您的问题不在于2 ^ 168277 是一个非常大的数字,而是您的中间结果之一是一个相当大的数字(大于 2^32),因为 673109 大于 2^16。

所以我认为以下会做。可能我错过了一个细节,但基本思想是有效的,这就是“真实”加密代码可能如何进行大模幂运算(尽管不是使用 32 位和 64 位数字,而是使用永远不必变得大于2 * log(模数)):

按照您的做法,从平方开始求幂。 在 64 位无符号整数中执行实际平方。 在每个步骤中减少模 673109 以返回到 32 位范围内,就像您所做的那样。

如果你的 C++ 实现没有 64 位整数,显然这有点尴尬,尽管你总是可以伪造一个。

幻灯片 22 上有一个示例:http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf,尽管它使用的数字非常小(小于 2^16),因此它可能无法说明您不知道的任何内容。

如果您在开始之前仅减少 40001112221608 为模,那么您的另一个示例 4000111222 ^ 3 % 1608 将适用于您当前的代码。 1608 足够小,您可以安全地将任意两个 mod-1608 数字乘以 32 位整数。

【讨论】:

谢谢你,成功了。只是出于好奇,您知道任何不需要使用更大内存大小的方法吗?我相信它们会派上用场。 我不知道。您需要将两个数字相乘,最大为 673108,mod 673109。显然,您可以分解并用较小的“数字”进行长乘法,例如 2^10。但是,一旦您在软件中实现乘法和除法,您最好只为将两个 32 位值相乘得到 64 位结果,然后除以提取 32 位余数的特殊情况来实现它。可能有一些核心优化可以做到你需要的最低限度,但我不知道它们并且在 C++ 中伪造 64 位 int 并不是那么难。【参考方案4】:

我最近用 C++ 为 RSA 写了一些东西,不过有点乱。

#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>

BigInteger::BigInteger() 
    digits.push_back(0);
    negative = false;


BigInteger::~BigInteger() 


void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) 
    int sum_n_carry = 0;
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size()) 
        n = b.digits.size();
    
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) 
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size()) 
            a_digit = a.digits[i];
        
        if (i < (int)b.digits.size()) 
            b_digit = b.digits[i];
        
        sum_n_carry += a_digit + b_digit;
        c.digits[i] = (sum_n_carry & 0xFFFF);
        sum_n_carry >>= 16;
    
    if (sum_n_carry != 0) 
        putCarryInfront(c, sum_n_carry);
    
    while (c.digits.size() > 1 && c.digits.back() == 0) 
        c.digits.pop_back();
    
    //std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;


void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) 
    int sub_n_borrow = 0;
    int n = a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) 
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        sub_n_borrow += a_digit - b_digit;
        if (sub_n_borrow >= 0) 
            c.digits[i] = sub_n_borrow;
            sub_n_borrow = 0;
         else 
            c.digits[i] = 0x10000 + sub_n_borrow;
            sub_n_borrow = -1;
        
    
    while (c.digits.size() > 1 && c.digits.back() == 0) 
        c.digits.pop_back();
    
    //std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;


int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) 
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    //std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
    for (int i = n-1; i >= 0; --i) 
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        if (a_digit < b_digit) 
            //std::cout << "-1" << std::endl;
            return -1;
         else if (a_digit > b_digit) 
            //std::cout << "+1" << std::endl;
            return +1;
        
    
    //std::cout << "0" << std::endl;
    return 0;


void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) 
    unsigned int mult_n_carry = 0;
    c.digits.clear();
    c.digits.resize(a.digits.size());
    for (int i = 0; i < (int)a.digits.size(); ++i) 
        unsigned short a_digit = 0;
        unsigned short b_digit = b;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        mult_n_carry += a_digit * b_digit;
        c.digits[i] = (mult_n_carry & 0xFFFF);
        mult_n_carry >>= 16;
    
    if (mult_n_carry != 0) 
        putCarryInfront(c, mult_n_carry);
    
    //std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;


void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) 
    b.digits.resize(a.digits.size() + times);
    for (int i = 0; i < times; ++i) 
        b.digits[i] = 0;
    
    for (int i = 0; i < (int)a.digits.size(); ++i) 
        b.digits[i + times] = a.digits[i];
    


void BigInteger::shiftRight(BigInteger& a) 
    //std::cout << "shr " << a.toString() << " == ";
    for (int i = 0; i < (int)a.digits.size(); ++i) 
        a.digits[i] >>= 1;
        if (i+1 < (int)a.digits.size()) 
            if ((a.digits[i+1] & 0x1) != 0) 
                a.digits[i] |= 0x8000;
            
        
    
    //std::cout << a.toString() << std::endl;


void BigInteger::shiftLeft(BigInteger& a) 
    bool lastBit = false;
    for (int i = 0; i < (int)a.digits.size(); ++i) 
        bool bit = (a.digits[i] & 0x8000) != 0;
        a.digits[i] <<= 1;
        if (lastBit)
            a.digits[i] |= 1;
        lastBit = bit;
    
    if (lastBit) 
        a.digits.push_back(1);
    


void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) 
    BigInteger b;
    b.negative = a.negative;
    b.digits.resize(a.digits.size() + 1);
    b.digits[a.digits.size()] = carry;
    for (int i = 0; i < (int)a.digits.size(); ++i) 
        b.digits[i] = a.digits[i];
    
    a.digits.swap(b.digits);


void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) 
    c.digits.clear();
    c.digits.push_back(0);
    BigInteger two("2");
    BigInteger e = b;
    BigInteger f("1");
    BigInteger g = a;
    BigInteger one("1");
    while (cmpWithoutSign(g, e) >= 0) 
        shiftLeft(e);
        shiftLeft(f);
    
    shiftRight(e);
    shiftRight(f);
    while (cmpWithoutSign(g, b) >= 0) 
        g -= e;
        c += f;
        while (cmpWithoutSign(g, e) < 0) 
            shiftRight(e);
            shiftRight(f);
        
    
    e = c;
    e *= b;
    f = a;
    f -= e;
    d = f;


BigInteger::BigInteger(const BigInteger& other) 
    digits = other.digits;
    negative = other.negative;


BigInteger::BigInteger(const char* other) 
    digits.push_back(0);
    negative = false;
    BigInteger ten;
    ten.digits[0] = 10;
    const char* c = other;
    bool make_negative = false;
    if (*c == '-') 
        make_negative = true;
        ++c;
    
    while (*c != 0) 
        BigInteger digit;
        digit.digits[0] = *c - '0';
        *this *= ten;
        *this += digit;
        ++c;
    
    negative = make_negative;


bool BigInteger::isOdd() const 
    return (digits[0] & 0x1) != 0;


BigInteger& BigInteger::operator=(const BigInteger& other) 
    if (this == &other) // handle self assignment
        return *this;
    digits = other.digits;
    negative = other.negative;
    return *this;


BigInteger& BigInteger::operator+=(const BigInteger& other) 
    BigInteger result;
    if (negative) 
        if (other.negative) 
            result.negative = true;
            addWithoutSign(result, *this, other);
         else 
            int a = cmpWithoutSign(*this, other);
            if (a < 0) 
                result.negative = false;
                subWithoutSign(result, other, *this);
             else if (a > 0) 
                result.negative = true;
                subWithoutSign(result, *this, other);
             else 
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            
        
     else 
        if (other.negative) 
            int a = cmpWithoutSign(*this, other);
            if (a < 0) 
                result.negative = true;
                subWithoutSign(result, other, *this);
             else if (a > 0) 
                result.negative = false;
                subWithoutSign(result, *this, other);
             else 
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            
         else 
            result.negative = false;
            addWithoutSign(result, *this, other);
        
    
    negative = result.negative;
    digits.swap(result.digits);
    return *this;


BigInteger& BigInteger::operator-=(const BigInteger& other) 
    BigInteger neg_other = other;
    neg_other.negative = !neg_other.negative;
    return *this += neg_other;


BigInteger& BigInteger::operator*=(const BigInteger& other) 
    BigInteger result;
    for (int i = 0; i < (int)digits.size(); ++i) 
        BigInteger mult;
        multByDigitWithoutSign(mult, other, digits[i]);
        BigInteger shift;
        shiftLeftByBase(shift, mult, i);
        BigInteger add;
        addWithoutSign(add, result, shift);
        result = add;
    
    if (negative != other.negative) 
        result.negative = true;
     else 
        result.negative = false;
    
    //std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
    negative = result.negative;
    digits.swap(result.digits);
    return *this;


BigInteger& BigInteger::operator/=(const BigInteger& other) 
    BigInteger result, tmp;
    divideWithoutSign(result, tmp, *this, other);
    result.negative = (negative != other.negative);
    negative = result.negative;
    digits.swap(result.digits);
    return *this;


BigInteger& BigInteger::operator%=(const BigInteger& other) 
    BigInteger c, d;
    divideWithoutSign(c, d, *this, other);
    *this = d;
    return *this;


bool BigInteger::operator>(const BigInteger& other) const 
    if (negative) 
        if (other.negative) 
            return cmpWithoutSign(*this, other) < 0;
         else 
            return false;
        
     else 
        if (other.negative) 
            return true;
         else 
            return cmpWithoutSign(*this, other) > 0;
        
    


BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) 
    BigInteger zero("0");
    BigInteger one("1");
    BigInteger e = exponent;
    BigInteger base = *this;
    *this = one;
    while (cmpWithoutSign(e, zero) != 0) 
        //std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
        if (e.isOdd()) 
            *this *= base;
            *this %= modulus;
        
        shiftRight(e);
        base *= BigInteger(base);
        base %= modulus;
    
    return *this;


std::string BigInteger::toString() const 
    std::ostringstream os;
    if (negative)
        os << "-";
    BigInteger tmp = *this;
    BigInteger zero("0");
    BigInteger ten("10");
    tmp.negative = false;
    std::stack<char> s;
    while (cmpWithoutSign(tmp, zero) != 0) 
        BigInteger tmp2, tmp3;
        divideWithoutSign(tmp2, tmp3, tmp, ten);
        s.push((char)(tmp3.digits[0] + '0'));
        tmp = tmp2;
    
    while (!s.empty()) 
        os << s.top();
        s.pop();
    
    /*
    for (int i = digits.size()-1; i >= 0; --i) 
        os << digits[i];
        if (i != 0) 
            os << ",";
        
    
    */
    return os.str();

还有一个用法示例。

BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344")

// Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);

它也很快,并且有无限位数。

【讨论】:

感谢分享!一个问题, digit 是 std::vector 吗? 是的,但是在引擎盖下使用 base 65536,而不是 base 10。【参考方案5】:
    package playTime;

    public class play 

        public static long count = 0; 
        public static long binSlots = 10; 
        public static long y = 645; 
        public static long finalValue = 1; 
        public static long x = 11; 

        public static void main(String[] args)

            int[] binArray = new int[]0,0,1,0,0,0,0,1,0,1;  

            x = BME(x, count, binArray); 

            System.out.print("\nfinal value:"+finalValue);

        

        public static long BME(long x, long count, int[] binArray)

            if(count == binSlots)
                return finalValue; 
            

            if(binArray[(int) count] == 1)
                finalValue = finalValue*x%y; 
            

            x = (x*x)%y; 
            System.out.print("Array("+binArray[(int) count]+") "
                            +"x("+x+")" +" finalVal("+              finalValue + ")\n");

            count++; 


            return BME(x, count,binArray); 
        

    

【讨论】:

那是我很快用java写的代码。我使用的示例是 11^644mod 645。= 1。我们知道 645 的二进制是 1010000100。我有点作弊并对变量进行硬编码,但它工作正常。 输出为 Array(0) x(121) finalVal(1) Array(0) x(451) finalVal(1) Array(1) x(226) finalVal(451) Array(0) x(121) finalVal(451) 数组(0) x(451) finalVal(451) 数组(0) x(226) finalVal(451) 数组(0) x(121) finalVal(451) 数组(1) x( 451) finalVal(391) Array(0) x(226) finalVal(391) Array(1) x(121) finalVal(1) 最终值:1【参考方案6】:

LL 代表long long int

LL power_mod(LL a, LL k) 
    if (k == 0)
        return 1;
    LL temp = power(a, k/2);
    LL res;

    res = ( ( temp % P ) * (temp % P) ) % P;
    if (k % 2 == 1)
        res = ((a % P) * (res % P)) % P;
    return res;

使用上面的递归函数来找到数字的 mod exp。这不会导致溢出,因为它以自下而上的方式计算。

示例测试运行: a = 2k = 168277 显示输出为 518358,这是正确的,函数在 O(log(k)) 时间运行;

【讨论】:

以上是关于C++ 中高数的模幂运算的主要内容,如果未能解决你的问题,请参考以下文章

c_cpp 模幂运算

模幂运算

模幂运算问题,使用朴素算法和重复-平方算法(快速幂+C#计算程序运行时间)

模幂运算问题,使用朴素算法和重复-平方算法(快速幂+C#计算程序运行时间)

欧几里得算法解决 RR' - NN' = 1. 使用蒙哥马利算法进行模幂运算以在 python 或 Petite Chez 方案中实现费马检验

二进制计数器的模值是啥意思?呀