你有一个 nn 行、mm 列的、每个格子都填写着 00 的表格。你进行了下面的操作:

  • 对于每一行 ii ,选定自然数 ri(0rim)r_i (0 \le r_i \le m),将这一行最左边的 rir_i 个格子中的数 +1+1
  • 对于每一列 ii ,选定自然数 ci(0cin)c_i (0 \le c_i \le n), 将这一列最上面的 cic_i 个格子中的数 +1+1

问最终表格有多少种本质不同的方案

AGC035F

Solution

只有这一种可能会算重:ri=j,cj=i1r_i = j, c_j = i - 1,和ri=j1,cj=ir_i = j - 1, c_j = i

考虑把包含前一种情况的方案都减掉

枚举强制有几对行列是前一种情况情况,剩下的随便填,然后容斥

f(k)=(nk)×(mk)×k!×(m+1)nk×(n+1)mk f(k) = \binom{n}{k}\times\binom{m}{k}\times k!\times (m+1)^{n-k} \times (n+1)^{m-k} ans=k=0min{n,m}(1)kf(k) ans = \sum_{k=0}^{\min\{n, m\}}(-1)^{k}f(k)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <bits/stdc++.h>

#define x first
#define y second
#define y1 Y1
#define y2 Y2
#define mp make_pair
#define pb push_back
#define DEBUG(x) cout << #x << " = " << x << endl;

using namespace std;

typedef long long LL;
typedef pair <int, int> pii;

template <typename T> inline int Chkmax (T &a, T b) { return a < b ? a = b, 1 : 0; }
template <typename T> inline int Chkmin (T &a, T b) { return a > b ? a = b, 1 : 0; }
template <typename T> inline T read ()
{
T sum = 0, fl = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = (sum << 3) + (sum << 1) + ch - '0';
return sum * fl;
}

inline void proc_status ()
{
ifstream t ("/proc/self/status");
cerr << string (istreambuf_iterator <char> (t), istreambuf_iterator <char> ()) << endl;
}

const int Maxn = 5e5 + 100;
const int Mod = 998244353;

namespace MATH
{
int fac[Maxn], ifac[Maxn];

inline void Add (int &a, int b) { if ((a += b) >= Mod) a -= Mod; }

inline int Pow (int a, int b)
{
int ans = 1;
for (int i = b; i; i >>= 1, a = (LL) a * a % Mod) if (i & 1) ans = (LL) ans * a % Mod;
return ans;
}

inline void init (int n = 5e5)
{
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = (LL) fac[i - 1] * i % Mod;
ifac[n] = Pow (fac[n], Mod - 2);
for (int i = n - 1; i >= 0; --i) ifac[i] = (LL) ifac[i + 1] * (i + 1) % Mod;
}

inline int C (int n, int m) { if (n < m) return 0; return (LL) fac[n] * ifac[m] % Mod * ifac[n - m] % Mod; }
}

using namespace MATH;

int N, M;

inline void Solve ()
{
int ans = 0;
for (int k = 0; k <= min (N, M); ++k)
{
int sum = (LL) C (N, k) * C (M, k) % Mod * fac[k] % Mod * Pow (M + 1, N - k) % Mod * Pow (N + 1, M - k) % Mod;
if (k & 1) Add (ans, Mod - sum);
else Add (ans, sum);
}
cout << ans << endl;
}

inline void Input ()
{
N = read<int>(), M = read<int>();
}

int main()
{

#ifdef hk_cnyali
freopen("F.in", "r", stdin);
freopen("F.out", "w", stdout);
#endif

MATH :: init ();
Input ();
Solve ();

return 0;
}