给定n,a,bn,a,b,求有多少个长度为nn的排列满足前缀最大值数量恰好为aa,后缀最大值数量恰好为bb

n,a,b105n, a, b\le 10^5

CF960G

Solution

不难发现,全局最大值把整个排列分成了两个部分

首先很容易想到一个O(n2)O(n^2)暴力,先枚举最大值的位置pp,然后设dp[i][j]dp[i][j]表示一个长度为ii 的排列,满足从左到右有 jj 个在前缀中为最大值的元素的方案数,答案即为p=1ndp[p1][a1]dp[np][b1](n1p1)\sum_{p=1}^{n}dp[p-1][a - 1] * dp[n - p][b - 1] * \binom{n - 1}{p - 1}

在转移dp[i][j]dp[i][j]的时候,可以发现转移式子和第一类斯特林数一模一样

那么考虑它的组合意义实际上就是把ii个数分成jj个环排列

为了满足前缀最大值的限制,我们强制从大到小把每个数放到排列中。如果要使得它成为前缀最大值就放到排列的最前面;否则每个数放进去的时候都可以加到当前排列的任何一个数的后面,并且它不会成为前缀最大值。

那么我们就不需要枚举全局最大值的位置了

我们直接考虑把除最大值外剩下的n1n-1个数分成(a1)+(b1)=a+b2(a-1)+(b-1)=a+b-2个环排列,再乘个组合数枚举每个排列放在最大值前面还是最大值后面即可

答案就是

NTT优化第一类斯特林数的计算


这里稍微写一下吧,但感觉还是会专门写一篇总结。。。

首先对于第一类斯特林数,我们有生成函数 ,或者是带符号的

其中 ,分别称为上升幂和下降幂

可以考虑 这两个多项式每项系数的dp过程,发现和第一类斯特林数是一样的

然后就可以用NTT优化多项式乘法了,可以递归分治,每次暴力用NTT合并,复杂度是O(nlog2n)O(n\log^2 n),足够通过此题


然而第一类斯特林数也存在O(nlogn)O(n\log n)的做法,也比较好理解。这里懒得写了,到时候写总结的时候再写。可以先看yyb巨佬的博客

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back

using namespace std;

typedef long long LL;
typedef pair<int, int> pii;

template <typename T> inline int Chkmax (T &a, T b) { return a < b ? a = b, 1 : 0; }
template <typename T> inline int Chkmin (T &a, T b) { return a > b ? a = b, 1 : 0; }

inline void proc_status()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) <<endl;
}

template <typename T> T read ()
{
T sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

const int Maxn = 2e5 + 100, Mod = 998244353, g = 3;

int N, a, b;
int fac[Maxn], ifac[Maxn];

inline int Pow (int a, int b)
{
int ans = 1;
for (int i = b; i; i >>= 1, a = 1ll * a * a % Mod) if (i & 1) ans = 1ll * ans * a % Mod;
return ans;
}

inline void Init (int maxn)
{
fac[0] = 1;
for (int i = 1; i <= maxn; ++i) fac[i] = 1ll * fac[i - 1] * i % Mod;
ifac[maxn] = Pow(fac[maxn], Mod - 2);
for (int i = maxn - 1; i >= 0; --i) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % Mod;
}

inline int C (int n, int m) { return 1ll * fac[n] * ifac[m] % Mod * ifac[n - m] % Mod; }

namespace Poly
{
int n, rev[Maxn << 2];

inline void DFT (int *A, int flag)
{
for (int i = 0; i < n; ++i) if (i < rev[i]) swap (A[i], A[rev[i]]);

for (int mid = 1; mid < n; mid <<= 1)
{
int Wn = Pow(g, (Mod - 1) / mid / 2);
if (flag == -1) Wn = Pow(Wn, Mod - 2);
for (int i = 0; i < n; i += (mid << 1))
{
int W = 1;
for (int j = i; j < i + mid; ++j, W = 1ll * W * Wn % Mod)
{
int x = A[j], y = 1ll * W * A[j + mid] % Mod;
/**/ A[j] = (x + y) % Mod, A[j + mid] = (x - y + Mod) % Mod;
}
}
}

int inv = Pow(n, Mod - 2);
if (flag == -1) for (int i = 0; i < n; ++i) A[i] = 1ll * A[i] * inv % Mod;
}

inline void NTT (int *A, int *B, int N, int M)
{
n = 1; while (n <= N + M) n <<= 1;
for (int i = N; i < n; ++i) A[i] = 0;
for (int i = M; i < n; ++i) B[i] = 0;
for (int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? (n >> 1) : 0);

DFT (A, 1), DFT (B, 1);
for (int i = 0; i < n; ++i) A[i] = 1ll * A[i] * B[i] % Mod;
DFT (A, -1);
}
}

int A[25][Maxn << 2], B[Maxn << 2];

inline int solve (int l, int r, int d)
{
if (l == r) { A[d][0] = l - 1, A[d][1] = 1; return 2; }
int mid = l + r >> 1;
int n = solve (l, mid, d + 1);
for (int i = 0; i < n; ++i) A[d][i] = A[d + 1][i];
int m = solve(mid + 1, r, d + 1);
for (int i = 0; i < m; ++i) B[i] = A[d + 1][i];
Poly :: NTT(A[d], B, n, m);
return n + m;
}

inline int S (int n, int m)
{
// (x + 0) * (x + 1) * (x + 2) * ... * (x + n - 1)
solve(1, n, 0);
return A[0][m];
}

inline void Solve ()
{
if (a + b - 2 > N - 1 || !a || !b) { puts("0"); return ; }
if (N == 1) { puts("1"); return ; }
Init(2e5);
cout<<1ll * S(N - 1, a + b - 2) * C(a + b - 2, a - 1) % Mod<<endl;
}

inline void Input ()
{
N = read<int>(), a = read<int>(), b = read<int>();
}

int main()
{
#ifndef ONLINE_JUDGE
freopen("permutation.in", "r", stdin);
freopen("permutation.out", "w", stdout);
#endif
Input();
Solve();
return 0;
}

Debug

  • 62L:写成mid=0。。。
  • 72L:忘记取模。。。