Sweet Snippet 之 大整数乘法

2022-01-10 10:34:14 浏览数 (1)

本文简单介绍了一种大整数乘法的实现方式

当整数范围较大时,直接使用乘法运算符(*)很容易导致数值溢出,如果开发工作中确实需要处理这种大范围的整数,那么我们便需要实现一下大(范围)整数的乘法运算(一般方法便是将大整数表达为字符串,然后基于字符串来进行乘法运算).

在实现大整数乘法之前,我们先来实现一下大整数的加法运算,朴素方法便是从低到高按位进行加法操作,并考虑进位的影响,代码大概如下(Lua):

代码语言:javascript复制
local big_int = {}

local function digit_num(v)
    return #tostring(v)
end

function big_int.add(a, b)
    a = tostring(a)
    b = tostring(b)
    
    local result_buffer = {}
    
    local cur_a_index = -1
    local cur_b_index = -1
    
    local last_carry = 0
    
    while true do
        local sub_a_str = a:sub(cur_a_index, cur_a_index)
        local sub_b_str = b:sub(cur_b_index, cur_b_index)
        
        local sub_a = tonumber(sub_a_str)
        local sub_b = tonumber(sub_b_str)
        
        if last_carry == 0 then
            if not sub_a and not sub_b then
                break
            elseif not sub_a then
                table.insert(result_buffer, 1, b:sub(1, cur_b_index))
                break
            elseif not sub_b then
                table.insert(result_buffer, 1, a:sub(1, cur_a_index))
                break
            end
        end
        
        sub_a = sub_a or 0
        sub_b = sub_b or 0
        
        local sub_result = sub_a   sub_b   last_carry
        
        if sub_result >= 10 then
            last_carry = 1
            table.insert(result_buffer, 1, tostring(sub_result):sub(2))
        else
            last_carry = 0
            table.insert(result_buffer, 1, tostring(sub_result))
        end
        
        cur_a_index = cur_a_index - 1
        cur_b_index = cur_b_index - 1
    end
    
    return table.concat(result_buffer)
end

return big_int

上述代码的基本思路便是按位做加法,但实际上我们可以将"位"的概念扩展一下,延伸为"段",之前我们总是按"位"做加法,现在我们可以按"段"来做加法,这样做的好处便是加法效率会有比较大的提升(注意,这个提升也只是单方面的,并不意味着代码整体效率一定会提升),当然也会带来一些"副作用",譬如前导零问题:

考虑数字 1100 和 2201 相加,我们按两位一段来进行相加,即 11 和 22 相加, 00 和 01 相加,其中 00 和 01 相加的结果为 1(00 和 01 转为数字之后的相加结果),直接转换回字符串(结果为字符串"1",但实际上应为字符串"01")会导致前导零丢失.

处理了相关问题的代码如下(Lua):

代码语言:javascript复制
local big_int = {}

local min_add_digit_num = 9

local function digit_num(v)
    return #tostring(v)
end

function big_int.add(a, b)
    a = tostring(a)
    b = tostring(b)
    
    local result_buffer = {}
    
    local a_start_index = -min_add_digit_num
    local a_end_index = -1
    
    local b_start_index = -min_add_digit_num
    local b_end_index = -1
    
    local last_carry = 0
    
    while true do
        local sub_a_str = a:sub(a_start_index, a_end_index)
        local sub_b_str = b:sub(b_start_index, b_end_index)
        
        local sub_a = tonumber(sub_a_str)
        local sub_b = tonumber(sub_b_str)
        
        if last_carry == 0 then
            if not sub_a and not sub_b then
                break
            elseif not sub_a then
                table.insert(result_buffer, 1, b:sub(1, b_end_index))
                break
            elseif not sub_b then
                table.insert(result_buffer, 1, a:sub(1, a_end_index))
                break
            end
        end
        
        sub_a = sub_a or 0
        sub_b = sub_b or 0
        
        local sub_result = sub_a   sub_b   last_carry
        
        local sub_result_digit_num = digit_num(sub_result)
        local sub_max_digit_num = math.max(digit_num(sub_a_str), digit_num(sub_b_str))
        
        if sub_max_digit_num >= min_add_digit_num then
            if sub_result_digit_num > sub_max_digit_num then
                last_carry = 1
                table.insert(result_buffer, 1, tostring(sub_result):sub(2))
            elseif sub_result_digit_num == sub_max_digit_num then
                last_carry = 0
                table.insert(result_buffer, 1, tostring(sub_result))
            else
                -- handling heading zeros
                last_carry = 0
                table.insert(result_buffer, 1, string.rep("0", sub_max_digit_num - sub_result_digit_num) .. tostring(sub_result))
            end
        else
            last_carry = 0
            table.insert(result_buffer, 1, tostring(sub_result))
        end
        
        a_end_index = a_start_index - 1
        a_start_index = a_start_index - min_add_digit_num
        
        b_end_index = b_start_index - 1
        b_start_index = b_start_index - min_add_digit_num
    end
    
    return table.concat(result_buffer)
end

return big_int

简单测试一下,新版本代码的效率大概是老版本的 3 倍左右.

OK,实现了大整数加法,我们接着来实现大整数乘法,实际上来讲,大整数乘法也是可以按位进行乘法然后直接运用大整数加法来解决的,但是这种实现方式效率较差,更好的方法还是运用二分求解:

考虑大整数乘法 a * b ,我们将 a 分为高位 a_h 和低位 a_l ,将 b 分为高位 b_h 和低位 b_l ,并设 a_l 的位数为 n , b_l 的位数为 m , 则有: a * b = (a_h * 10^{n} a_l) * (b_h * 10^{m} b_l) \ = a_h * b_h * 10^{n m} a_h * b_l * 10^{n} a_l * b_h * 10^{m} a_l * b_l 其中 a_h * b_h, a_h * b_l, a_l * b_h, a_l * b_l 都是相同的大整数乘法子问题,我们直接递归求解即可,代码大概如下(Lua,重复代码已省略):

代码语言:javascript复制
local min_mul_digit_num = 5

function big_int.mul(a, b)
    a = tostring(a)
    b = tostring(b)
    
    local a_digit_num = digit_num(a)
    local b_digit_num = digit_num(b)
    
    if a_digit_num   b_digit_num <= 2 * min_mul_digit_num then
        return tostring((tonumber(a) or 0) * (tonumber(b) or 0))
    else
        local a_digit_num_h = math.ceil(a_digit_num / 2)
        local a_digit_num_l = a_digit_num - a_digit_num_h
        local b_digit_num_h = math.ceil(b_digit_num / 2)
        local b_digit_num_l = b_digit_num - b_digit_num_h
        
        local ah = a:sub(1, a_digit_num_h)
        local al = a:sub(a_digit_num_h   1, -1)
        local bh = b:sub(1, b_digit_num_h)
        local bl = b:sub(b_digit_num_h   1, -1)
        
        local ah_mul_bh = big_int.mul(ah, bh)
        local ah_mul_bl = big_int.mul(ah, bl)
        local al_mul_bh = big_int.mul(al, bh)
        local al_mul_bl = big_int.mul(al, bl)
        
        local result = ah_mul_bh .. string.rep("0", a_digit_num_l   b_digit_num_l)
        result = big_int.add(result, ah_mul_bl .. string.rep("0", a_digit_num_l))
        result = big_int.add(result, al_mul_bh .. string.rep("0", b_digit_num_l))
        result = big_int.add(result, al_mul_bl)
        
        return result
    end
end

完整代码(包括一些测试)在这里.

0 人点赞