中国剩余定理学习笔记

2022-09-23 11:56:12 浏览数 (1)

中国剩余定理

begin{cases} x equiv a_1 pmod {p_1} x equiv a_2 pmod {p_2} ldots x equiv a_i pmod {p_i} end{cases}

中国剩余定理用于求解模数两两互质的线性同余方程组。

对于这种 x equiv a_i pmod{p_i} 的线性同余方程组,在 bmod prod limits _{i=1}^n p_i 的意义下有唯一解。


M = prod limits _{i=1}^n p_i ,而 M_i = dfrac{M}{p_i} .

则解 x = sum limits _{i=1}^{n} M_i times K_i ,其中 K_i 满足:

M_i times K_i equiv a_i pmod{p_i}

对于任意的 M_i times K_i ,由于 M_i 是所有剩余模数的积,则 M_i times K_i equiv 0 pmod{p_j} (j neq i) ,因此在 x 中加上 M_i times K_i 对满足 x equiv a_j pmod{p_j} (j neq i) 无影响。

而每项 M_i times K_i 都能保证第 i 个线性同余方程得到满足。

只需要解出 K_i 即可。

M_i times K_i p_i times t = a_i

利用扩展欧几里得定理求出 K_i .

由于模数两两互素, gcd(M_i,p_i) = 1 ,即解:

M_i times k_i p_i times t = 1k_i ,然后 K_i = a_i times k_i .

可以发现, k_i 就是 M_ibmod p_i 意义下的逆元。

那么就可以得到一个解: x = sum limits _{i=1}^n M_i times k_i times a_i


而通解为: x = k times M sum limits _{i=1}^n M_i times k_i times a_i

证明:设两个解为 x_1,x_2 ,有:

forall i in {1,2,ldots,n},x_1 - x_2 equiv 0 pmod{p_i}

p_i 两两互素,则 M mid x_1 - x_2 ,那么 x_1 - x_2 = k times M ,则在 bmod M 意义下有唯一解。


代码语言:javascript复制
inline ll CRT()
{
    ll prod = 1,ans = 0;
    for (int i = 1; i <= n;   i) prod *= a[i];
    for (int i = 1; i <= n;   i)
    {
        ll x,y;
        exgcd(prod / a[i], a[i], x, y);
        ans = (ans   prod / a[i] * b[i] * x) % prod;
    }
    return (ans   prod) % prod;
}


例题:[TJOI2009]猜数字

给出 a_i,b_i ,要求最小的 n 使得 b_i mid (n - a_i) .

改写为 n - a_i equiv 0 pmod {b_i} .

n equiv a_i pmod{b_i}

n 的最小整数解。

这题会爆 long long,需要用到龟速乘的技巧。

代码语言:javascript复制
#include <cstdio>
#define ll long long
const int maxk = 20;
ll k, a[maxk], b[maxk];
void exgcd(ll a, ll b, ll& x, ll& y)
{
    if (b == 0) return (void)(x = 1, y = 0);
    exgcd(b, a % b, y, x), y -= (a / b) * x;
}
ll mul(ll a,ll b,ll mod)
{
    ll res = 0;
    while(a)
    {
        if(a & 1) res = (res   b) % mod;
        b = (b   b) % mod;
        a >>= 1;
    }
    return res % mod;
}
ll crt()
{
    ll prod = 1, ans = 0;
    for (int i = 1; i <= k;   i) prod *= b[i];
    for (int i = 1; i <= k;   i)
    {
        ll x, y, mi = prod / b[i];
        exgcd(mi, b[i], x, y);
        x = ((x % b[i])   b[i]) % b[i];
        ans = (ans   mul(mul(mi, x, prod), a[i], prod)) % prod;
    }
    return (ans   prod) % prod;
}
int main()
{
    scanf("%lld",&k);
    for (int i = 1; i <= k;   i) scanf("%lld", a   i);
    for (int i = 1; i <= k;   i) scanf("%lld", b   i);
    printf("%lldn", crt());
    return 0;
}

拓展中国剩余定理

然而如果不保证模数互质,就需要使用拓展中国剩余定理。

模数不互质的话,之前的巧妙构造行不通了。尝试更平凡的列式子考虑:

begin{cases}x equiv r_1 pmod{p_1} x equiv r_2 pmod{p_2}end{cases}

可以写成:

begin{cases} x = r_1 p_1 times k_1 x = r_2 p_2 times k_2 end{cases}

即:

r_1 p_1 times k_1 = r_2 p_2 times k_2

移项得到:

p_1 times k_1 - p_2 times k_2 = r_2 - r_1

等价于解 a = p_1b = -p_2c = r_2 - r_1ax by = c 二元一次不定方程。

可以用拓展欧几里得定理解出可行的 k_1,k_2 .

解出一组特解 k_1',k_2' 后,利用前文知识,求出通解:

d=gcd(p_1,-p_2)

begin{cases} k_1 = k_1' t times dfrac{-p_2}{d} k_2 = k_2' - t times dfrac{p_1}{d}end{cases}

x = r_1 p_1 times k_1 可以写为:

x = r_1 p_1 times k_1' p_1 times t times dfrac{-p_2}{d}

其中只有 t 是自变量,该式子等价于:

x equiv r_1 p_1 times k_1' pmod{|p_1 times dfrac{-p_2}{d}|}

其中 |p_1 times dfrac{-p_2}{d}| = operatorname{lcm}(p_1,p_2)

通过一系列的变换,我们成功将两个线性同余方程合并为一个线性同余方程。

依次类推,即可求出最后的一个线性同余方程,随后求出原线性同余方程组的解。

在计算过程中注意取模,将可取模的都取模以避免溢出。


代码语言:javascript复制
#include <cstdio>
#define ll long long
const int maxn = 1e6   100;
int n;
ll p[maxn], r[maxn];
ll exgcd(ll a, ll b, ll& x, ll& y)
{
    if (b == 0)
    {
        x = 1, y = 0;
        return a;
    }
    ll g = exgcd(b, a % b, y, x);
    y -= (a / b) * x;
    return g;
}
ll mul(ll a, ll b, ll mod)
{
    ll res = 0;
    while (b > 0)
    {
        if (b & 1) res = (res   a) % mod;
        a = (a   a) % mod;
        b >>= 1;
    }
    return ((res % mod)   mod) % mod;
}
inline ll exCRT()
{
    ll lcm = p[1], lastr = r[1], k1, k2;
    for (int i = 2; i <= n;   i)
    {
        ll a = lcm, b = p[i], c = (((r[i] - lastr) % p[i])   p[i]) % p[i];
        ll g = exgcd(a, b, k1, k2);
        ll mod = p[i] / g;
        k1 = ((k1 % mod)   mod) % mod;
        k1 = mul(k1, c / g, mod);

        lastr  = k1 * lcm;
        lcm = lcm / g * p[i];
        lastr = ((lastr % lcm)   lcm) % lcm;
    }
    return lastr;
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n;   i) scanf("%lld %lld", p   i, r   i);
    printf("%lldn", exCRT());
    return 0;
}
sum

0 人点赞