跳转至

拉格朗日插值

例题 Luogu P4781【模板】拉格朗日插值

给出 \(n\) 个点对 \((x_i,y_i)\)\(k\),且 \(\forall i,j\)\(i\neq j \iff x_i\neq x_j\)\(f(x_i)\equiv y_i\pmod{998244353}\)\(\deg(f(x))<n\)(定义 \(\deg(0)=-\infty\)),求 \(f(k)\bmod{998244353}\)

方法 1:差分法

差分法适用于 \(x_i=i\) 的情况。

如,用差分法求某三次多项式 \(f(x)=\sum_{i=0}^{3} a_ix^i\) 的多项式形式,已知 \(f(1)\)\(f(6)\) 的值分别为 \(1, 5, 14, 30, 55, 91\)

\[ \begin{array}{cccccccccccc} 1 & & 5 & & 14 & & 30 & & 55 & & 91 & \\ & 4 & & 9 & & 16 & & 25 & & 36 & \\ & & 5 & & 7 & & 9 & & 11 & \\ & & & 2 & & 2 & & 2 & \\ \end{array} \]

第一行为 \(f(x)\) 的连续的前 \(n\) 项;之后的每一行为之前一行中对应的相邻两项之差。观察到,如果这样操作的次数足够多(前提是 \(f(x)\) 为多项式),最终总会返回一个定值,可以利用这个定值求出 \(f(x)\) 的每一项的系数,然后即可将 \(k\) 代入多项式中求解。上例中可求出 \(f(x)=\frac 1 3 x^3+\frac 1 2 x^2+\frac 1 6 x\)。时间复杂度为 \(O(n^2)\)。这种方法对给出的点的限制性较强。

方法 2:待定系数法

\(f(x)=\sum_{i=0}^{n-1} a_ix^i\) 将每个 \(x_i\) 代入 \(f(x)\),有 \(f(x_i)=y_i\),这样就可以得到一个由 \(n\)\(n\) 元一次方程所组成的方程组,然后使用 高斯消元 解该方程组求出每一项 \(a_i\),即确定了 \(f(x)\) 的表达式。

如果您不知道什么是高斯消元,请看 高斯消元

时间复杂度 \(O(n^3)\),对给出点的坐标无要求。

方法 3:拉格朗日插值法

多项式部分简介 里我们已经定义了多项式除法。

那么我们会有:

\[ f(x)\equiv f(a)\pmod{(x-a)} \]

因为 \(f(x)-f(a)=(a_0-a_0)+a_1(x^1-a^1)+a_1(x^2-a^2)+\cdots +a_n(x^n-a^n)\),显然有 \((x-a)\) 这个因式。

这样我们就可以列一个关于 \(f(x)\) 的多项式线性同余方程组:

\[ \begin{cases} f(x)\equiv y_1\pmod{(x-x_1)}\\ f(x)\equiv y_2\pmod{(x-x_2)}\\ \vdots\\ f(x)\equiv y_n\pmod{(x-x_n)} \end{cases} \]

\[ \begin{aligned} M(x)&=\prod_{i=1}^n{(x-x_i)},\\ m_i(x)&=\dfrac M{x-x_i} \end{aligned} \]

\(m_i(x)\) 在模 \((x-x_i)\) 意义下的乘法逆元为

\[ m_i(x_i)^{-1}=\prod_{j\ne i}{(x_i-x_j)^{-1}} \]

\[ \begin{aligned} f(x)&\equiv\sum_{i=1}^n{y_i\left(m_i(x)\right)\left(m_i(x_i)^{-1}\right)}&\pmod{M(x)}\\ &\equiv\sum_{i=1}^n{y_i\prod_{j\ne i}{\dfrac {x-x_j}{x_i-x_j}}}&\pmod{M(x)} \end{aligned} \]

又因为 \(\deg\left(f(x)\right)<n\) 所以在模 \(M(x)\) 意义下 \(f(x)\) 就是唯一的,即:

\[ f(x)=\sum_{i=1}^n{y_i\prod_{j\ne i}{\dfrac {x-x_j}{x_i-x_j}}} \]

这就是拉格朗日插值的表达式。

通常意义下拉格朗日插值的一种推导

由于要求构造一个函数 \(f(x)\) 过点 \(P_1(x_1, y_1), P_2(x_2,y_2),\cdots,P_n(x_n,y_n)\)。首先设第 \(i\) 个点在 \(x\) 轴上的投影为 \(P_i^{\prime}(x_i,0)\)

考虑构造 \(n\) 个函数 \(f_1(x), f_2(x), \cdots, f_n(x)\),使得对于第 \(i\) 个函数 \(f_i(x)\),其图像过 \(\begin{cases}P_j^{\prime}(x_j,0),(j\neq i)\\P_i(x_i,y_i)\end{cases}\),则可知题目所求的函数 \(f(x)=\sum\limits_{i=1}^nf_i(x)\)

那么可以设 \(f_i(x)=a\cdot\prod_{j\neq i}(x-x_j)\),将点 \(P_i(x_i,y_i)\) 代入可以知道 \(a=\dfrac{y_i}{\prod_{j\neq i} (x_i-x_j)}\),所以

\(f_i(x)=y_i\cdot\dfrac{\prod_{j\neq i} (x-x_j)}{\prod_{j\neq i} (x_i-x_j)}=y_i\cdot\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j}\)

那么我们就可以从另一个角度推导出通常意义下(而非模意义下)拉格朗日插值的式子为:

\(f(x)=\sum_{i=1}^ny_i\cdot\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j}\)

代码实现

因为在固定模 \(998244353\) 意义下运算,计算乘法逆元的时间复杂度我们在这里暂且认为是常数时间。

C++
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
#include <exception>
#include <iostream>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>

template <unsigned int Mod>
class Fp {
  static_assert(static_cast<int>(Mod) > 1);

 public:
  Fp() : v_() {}

  Fp(int v) : v_(safe_mod(v)) {}

  static unsigned int safe_mod(int v) {
    v %= static_cast<int>(Mod);
    return v < 0 ? v + static_cast<int>(Mod) : v;
  }

  unsigned int value() const { return v_; }

  Fp operator-() const { return Fp(Mod - v_); }

  Fp pow(int e) const {
    if (e < 0) return inv().pow(-e);
    for (Fp x(*this), res(1);; x *= x) {
      if (e & 1) res *= x;
      if ((e >>= 1) == 0) return res;
    }
  }

  Fp inv() const {
    int x1 = 1, x3 = 0, a = v_, b = Mod;
    while (b != 0) {
      int q = a / b, x1_old = x1, a_old = a;
      x1 = x3, x3 = x1_old - x3 * q, a = b, b = a_old - b * q;
    }
    return Fp(x1);
  }

  Fp &operator+=(const Fp &rhs) {
    if ((v_ += rhs.v_) >= Mod) v_ -= Mod;
    return *this;
  }

  Fp &operator-=(const Fp &rhs) {
    if ((v_ += Mod - rhs.v_) >= Mod) v_ -= Mod;
    return *this;
  }

  Fp &operator*=(const Fp &rhs) {
    v_ = static_cast<unsigned long long>(v_) * rhs.v_ % Mod;
    return *this;
  }

  Fp &operator/=(const Fp &rhs) { return operator*=(rhs.inv()); }

  void swap(Fp &rhs) {
    unsigned int v = v_;
    v_ = rhs.v_, rhs.v_ = v;
  }

  friend Fp operator+(const Fp &lhs, const Fp &rhs) { return Fp(lhs) += rhs; }

  friend Fp operator-(const Fp &lhs, const Fp &rhs) { return Fp(lhs) -= rhs; }

  friend Fp operator*(const Fp &lhs, const Fp &rhs) { return Fp(lhs) *= rhs; }

  friend Fp operator/(const Fp &lhs, const Fp &rhs) { return Fp(lhs) /= rhs; }

  friend bool operator==(const Fp &lhs, const Fp &rhs) {
    return lhs.v_ == rhs.v_;
  }

  friend bool operator!=(const Fp &lhs, const Fp &rhs) {
    return lhs.v_ != rhs.v_;
  }

  friend std::istream &operator>>(std::istream &lhs, Fp &rhs) {
    int v;
    lhs >> v;
    rhs = Fp(v);
    return lhs;
  }

  friend std::ostream &operator<<(std::ostream &lhs, const Fp &rhs) {
    return lhs << rhs.v_;
  }

 private:
  unsigned int v_;
};

template <typename T>
class Poly : public std::vector<T> {
 public:
  using std::vector<T>::vector;  // 使用继承的构造函数

  bool is_zero() const { return deg() == -1; }

  void shrink() { this->resize(std::max(deg() + 1, 1)); }

  int deg()
      const {  // 多项式的次数,当多项式为零时度数为 -1 而不是一般定义的负无穷
    int d = static_cast<int>(this->size()) - 1;
    const T z;
    while (d >= 0 && this->operator[](d) == z) --d;
    return d;
  }

  T leading_coeff() const {
    int d = deg();
    return d == -1 ? T() : this->operator[](d);
  }

  Poly operator-() const {
    Poly res;
    res.reserve(this->size());
    for (auto &&i : *this) res.emplace_back(-i);
    res.shrink();
    return res;
  }

  Poly &operator+=(const Poly &rhs) {
    if (this->size() < rhs.size()) this->resize(rhs.size());
    for (int i = 0, e = static_cast<int>(rhs.size()); i != e; ++i)
      this->operator[](i) += rhs[i];
    shrink();
    return *this;
  }

  Poly &operator-=(const Poly &rhs) {
    if (this->size() < rhs.size()) this->resize(rhs.size());
    for (int i = 0, e = static_cast<int>(rhs.size()); i != e; ++i)
      this->operator[](i) -= rhs[i];
    shrink();
    return *this;
  }

  Poly &operator*=(const Poly &rhs) {
    int n = deg(), m = rhs.deg();
    if (n == -1 || m == -1) return operator=(Poly{0});
    Poly res(n + m + 1);
    for (int i = 0; i <= n; ++i)
      for (int j = 0; j <= m; ++j) res[i + j] += this->operator[](i) * rhs[j];
    return operator=(res);
  }

  Poly &operator/=(const Poly &rhs) {
    int n = deg(), m = rhs.deg(), q = n - m;
    if (m == -1) throw std::runtime_error("Division by zero");
    if (q <= -1) return operator=(Poly{0});
    Poly res(q + 1);
    const T iv = 1 / rhs.leading_coeff();
    for (int i = q; i >= 0; --i)
      if ((res[i] = this->operator[](n--) * iv) != T())
        for (int j = 0; j != m; ++j) this->operator[](i + j) -= res[i] * rhs[j];
    return operator=(res);
  }

  Poly &operator%=(const Poly &rhs) {
    int n = deg(), m = rhs.deg(), q = n - m;
    if (m == -1) throw std::runtime_error("Division by zero");
    const T iv = 1 / rhs.leading_coeff();
    for (int i = q; i >= 0; --i)
      if (T res = this->operator[](n--) * iv; res != T())
        for (int j = 0; j <= m; ++j) this->operator[](i + j) -= res * rhs[j];
    shrink();
    return *this;
  }

  std::pair<Poly, Poly> div_mod(const Poly &rhs) const {
    int n = deg(), m = rhs.deg(), q = n - m;
    if (m == -1) throw std::runtime_error("Division by zero");
    if (q <= -1) return std::make_pair(Poly{0}, Poly(*this));
    const T iv = 1 / rhs.leading_coeff();
    Poly quo(q + 1), rem(*this);
    for (int i = q; i >= 0; --i)
      if ((quo[i] = rem[n--] * iv) != T())
        for (int j = 0; j <= m; ++j) rem[i + j] -= quo[i] * rhs[j];
    rem.shrink();
    return std::make_pair(quo, rem);  // (quotient, remainder)
  }

  T eval(const T &pt) const {
    T res;
    for (int i = deg(); i >= 0; --i) res = res * pt + this->operator[](i);
    return res;
  }

  friend Poly operator+(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) += rhs;
  }

  friend Poly operator-(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) -= rhs;
  }

  friend Poly operator*(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) *= rhs;
  }

  friend Poly operator/(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) /= rhs;
  }

  friend Poly operator%(const Poly &lhs, const Poly &rhs) {
    return Poly(lhs) %= rhs;
  }

  friend bool operator==(const Poly &lhs, const Poly &rhs) {
    int d = lhs.deg();
    if (d != rhs.deg()) return false;
    for (; d >= 0; --d)
      if (lhs[d] != rhs[d]) return false;
    return true;
  }

  friend bool operator!=(const Poly &lhs, const Poly &rhs) {
    return !(lhs == rhs);
  }

  friend std::ostream &operator<<(std::ostream &lhs, const Poly &rhs) {
    int s = 0, e = static_cast<int>(rhs.size());
    lhs << '[';
    for (auto &&i : rhs) {
      lhs << i;
      if (s >= 1) lhs << 'x';
      if (s > 1) lhs << '^' << s;
      if (++s != e) lhs << " + ";
    }
    return lhs << ']';
  }
};

template <typename T>
Poly<T> lagrange_interpolation(const std::vector<T> &x,
                               const std::vector<T> &y) {
  if (x.size() != y.size()) throw std::runtime_error("x.size() != y.size()");
  const int n = static_cast<int>(x.size());
  Poly<T> M = {T(1)}, f;
  for (int i = 0; i != n; ++i) M *= Poly<T>{-x[i], T(1)};
  for (int i = 0; i != n; ++i) {
    auto m = M / Poly<T>{-x[i], T(1)};
    f += Poly<T>{y[i] / m.eval(x[i])} * m;
  }
  return f;
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  using Z = Fp<998244353>;
  int n;
  Z k;
  std::cin >> n >> k;
  std::vector<Z> x(n), y(n);
  for (int i = 0; i != n; ++i) std::cin >> x[i] >> y[i];
  std::cout << lagrange_interpolation(x, y).eval(k) << std::endl;
  return 0;
}

本题中只用求出 \(f(k)\) 的值,所以在计算上式的过程中直接将 \(k\) 代入即可。

\[ f(k)=\sum_{i=1}^{n}y_i\prod_{j\neq i }\frac{k-x_j}{x_i-x_j} \]

本题中,还需要求解逆元。如果先分别计算出分子和分母,再将分子乘进分母的逆元,累加进最后的答案,时间复杂度的瓶颈就不会在求逆元上,时间复杂度为 \(O(n^2)\)

横坐标是连续整数的拉格朗日插值

如果已知点的横坐标是连续整数,我们可以做到 \(O(n)\) 插值。

设要求 \(n\) 次多项式为 \(f(x)\),我们已知 \(f(1),\cdots,f(n+1)\)\(1\le i\le n+1\)),考虑代入上面的插值公式:

\[ \begin{aligned} f(x)&=\sum\limits_{i=1}^{n+1}y_i\prod\limits_{j\ne i}\frac{x-x_j}{x_i-x_j}\\ &=\sum\limits_{i=1}^{n+1}y_i\prod\limits_{j\ne i}\frac{x-j}{i-j} \end{aligned} \]

后面的累乘可以分子分母分别考虑,不难得到分子为:

\[ \dfrac{\prod\limits_{j=1}^{n+1}(x-j)}{x-i} \]

分母的 \(i-j\) 累乘可以拆成两段阶乘来算:

\[ (-1)^{n+1-i}\cdot(i-1)!\cdot(n+1-i)! \]

于是横坐标为 \(1,\cdots,n+1\) 的插值公式:

\[ f(x)=\sum\limits_{i=1}^{n+1}y_i\cdot\frac{\prod\limits_{j=1}^{n+1}(x-j)}{(x-i)\cdot(-1)^{n+1-i}\cdot(i-1)!\cdot(n+1-i)!} \]

预处理 \((x-i)\) 前后缀积、阶乘阶乘逆,然后代入这个式子,复杂度为 \(O(n)\)

例题 CF622F The Sum of the k-th Powers

给出 \(n,k\),求 \(\sum\limits_{i=1}^ni^k\)\(10^9+7\) 取模的值。

本题中,答案是一个 \(k+1\) 次多项式,因此我们可以线性筛出 \(1^i,\cdots,(k+2)^i\) 的值然后进行 \(O(n)\) 插值。

代码实现
C++
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// By: Luogu@rui_er(122461)
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 5, mod = 1e9 + 7;

int n, k, tab[N], p[N], pcnt, f[N], pre[N], suf[N], fac[N], inv[N], ans;

int qpow(int x, int y) {
  int ans = 1;
  for (; y; y >>= 1, x = 1LL * x * x % mod)
    if (y & 1) ans = 1LL * ans * x % mod;
  return ans;
}

void sieve(int lim) {
  f[1] = 1;
  for (int i = 2; i <= lim; i++) {
    if (!tab[i]) {
      p[++pcnt] = i;
      f[i] = qpow(i, k);
    }
    for (int j = 1; j <= pcnt && 1LL * i * p[j] <= lim; j++) {
      tab[i * p[j]] = 1;
      f[i * p[j]] = 1LL * f[i] * f[p[j]] % mod;
      if (!(i % p[j])) break;
    }
  }
  for (int i = 2; i <= lim; i++) f[i] = (f[i - 1] + f[i]) % mod;
}

int main() {
  scanf("%d%d", &n, &k);
  sieve(k + 2);
  if (n <= k + 2) return printf("%d\n", f[n]) & 0;
  pre[0] = suf[k + 3] = 1;
  for (int i = 1; i <= k + 2; i++) pre[i] = 1LL * pre[i - 1] * (n - i) % mod;
  for (int i = k + 2; i >= 1; i--) suf[i] = 1LL * suf[i + 1] * (n - i) % mod;
  fac[0] = inv[0] = fac[1] = inv[1] = 1;
  for (int i = 2; i <= k + 2; i++) {
    fac[i] = 1LL * fac[i - 1] * i % mod;
    inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
  }
  for (int i = 2; i <= k + 2; i++) inv[i] = 1LL * inv[i - 1] * inv[i] % mod;
  for (int i = 1; i <= k + 2; i++) {
    int P = 1LL * pre[i - 1] * suf[i + 1] % mod;
    int Q = 1LL * inv[i - 1] * inv[k + 2 - i] % mod;
    int mul = ((k + 2 - i) & 1) ? -1 : 1;
    ans = (ans + 1LL * (Q * mul + mod) % mod * P % mod * f[i] % mod) % mod;
  }
  printf("%d\n", ans);
  return 0;
}