作者:手机用户2502891053 | 来源:互联网 | 2024-11-22 13:11
本文探讨了使用普通生成函数和指数生成函数解决组合与排列问题的方法,特别是在处理特定路径计数问题时的应用。文章通过详细分析和代码实现,展示了如何高效地计算在给定条件下不相邻相同元素的排列数量。
HDU 6116 路径计数优化
普通生成函数(OGF)主要用于解决组合问题,而指数生成函数(EGF)则更多应用于排列问题。这两种工具在解决特定类型的数学问题时非常有效,尤其是当涉及到计数问题时。
以一个问题为例,假设需要将 $a$ 个 $A$ 类型的对象分成多堆,每堆可以包含任意数量的对象,但不允许出现相邻两堆完全相同的情况。对于这样的分配,其方案数可以通过组合数 $C_{a-1}^{i-1}$ 来表示,其中 $i$ 是某一特定堆的数量。
进一步地,如果我们将每一堆视为一个独立的单元进行排列,则总的排列方式数量可以通过公式 $\frac{(i+j+k+l)!}{i!j!k!l!}$ 来计算,这里 $i, j, k, l$ 分别代表不同类型的堆的数量。值得注意的是,在这些排列中,即使所有的单个堆被视为不同的单位,仍然可能存在相邻的堆含有相同类型对象的情况,至少会有 $n-i-j-k-l$ 对相邻的相同堆。
为了准确计算不包含任何相邻相同堆的排列数,我们可以采用容斥原理。具体来说,通过枚举所有可能的 $i+j+k+l$ 的值,直接计算每个情况下的排列数,最终结果可以通过求和得到。数学表达式如下:
$ ans = \displaystyle \sum_{x=1}^{n} (-1)^{n-x} x! \sum_{i+j+k+l=x} \frac{C_{a-1}^{i-1} C_{b-1}^{j-1} C_{c-1}^{k-1} C_{d-1}^{l-1}}{i ! j ! k ! l !} $
上述公式实际上是四个指数生成函数的乘积形式,这为我们提供了一个有效的计算框架。
#include
using namespace std;
#define MAXN 200010
#define MOD 998244353
#define clr(a) memset(a, 0, sizeof a)
typedef long long ll;
// 快速幂算法
int quick_pow(int base, int exp) {
int result = 1;
while (exp) {
if (exp & 1) result = result * (ll)base % MOD;
base = base * (ll)base % MOD;
exp >>= 1;
}
return result;
}
// 计算组合数
int comb(int n, int m) {
if (m > n) return 0;
return (ll)J[n] * invJ[m] % MOD * invJ[n - m] % MOD;
}
int a, b, c, d;
int J[MAXN], invJ[MAXN], inv[MAXN];
int A[MAXN], B[MAXN], C[MAXN], D[MAXN];
int main() {
// 预处理阶乘和逆元
J[0] = inv[1] = invJ[0] = J[1] = invJ[1] = 1;
for (int i = 2; i inv[i] = (ll)(MOD - MOD / i) * inv[MOD % i] % MOD;
J[i] = (ll)J[i - 1] * i % MOD;
invJ[i] = (ll)invJ[i - 1] * inv[i] % MOD;
}
while (cin >> a >> b >> c >> d) {
clr(A), clr(B), clr(C), clr(D);
int n = a + b + c + d;
for (int i = 1; i <= a; ++i) A[i] = (ll)comb(a - 1, i - 1) * invJ[i] % MOD;
for (int i = 1; i <= b; ++i) B[i] = (ll)comb(b - 1, i - 1) * invJ[i] % MOD;
for (int i = 1; i <= c; ++i) C[i] = (ll)comb(c - 1, i - 1) * invJ[i] % MOD;
for (int i = 1; i <= d; ++i) D[i] = (ll)comb(d - 1, i - 1) * invJ[i] % MOD;
int len = 1, l = 0;
while (len <= n) len <<= 1, ++l;
// NTT预处理
int wn[2][MAXN];
for (int i = 1; i <(1 < int w0 = quick_pow(3, (MOD - 1) / (i <<1)), w1 = quick_pow(3, MOD - 1 - (MOD - 1) / (i <<1));
wn[0][i] = wn[1][i] = 1;
for (int j = 1; j wn[0][i + j] = (ll)wn[0][i + j - 1] * w0 % MOD,
wn[1][i + j] = (ll)wn[1][i + j - 1] * w1 % MOD;
}
int rev[MAXN];
for (int i = 1; i <(1 <> 1] >> 1) | ((i & 1) <
// 执行NTT变换
auto NTT = [&](int *A, int len, int flag) {
for (int i = 0; i for (int l = 1; l for (int i = 0; i for (int k = 0; k int t1 = A[i + k], t2 = (ll)A[i + l + k] * wn[flag][l + k] % MOD;
A[i + k] = (t1 + t2) % MOD;
A[i + l + k] = (t1 - t2 + MOD) % MOD;
}
if (flag == 1) {
int inv_len = quick_pow(len, MOD - 2);
for (int i = 0; i }
};
NTT(A, len, 0), NTT(B, len, 0), NTT(C, len, 0), NTT(D, len, 0);
for (int i = 0; i NTT(A, len, 1);
ll result = 0;
for (int i = 1; i <= n; ++i)
result += ((n - i & 1) ? -1ll : 1ll) * J[i] * A[i] % MOD,
result += MOD, result %= MOD;
cout < }
}