「牛客 1103B」路径计数机

是道好题

可惜卡常数

思路：暴力枚举 $(a,b)$，快速统计 $(c,d)$

定义 $f(u)$ 为以 $u$ 为 $\operatorname{LCA}$ 的 $(c,d)$ 个数

$tsf(u)=\sum f(v)$，其中 $v$ 在子树 $u$ 中

$csf(u)$ 为 $1$ 到 $u$ 路径上所有点的 $\sum f$

$(c,d)$ 总数量记为 $sq$

经过 $u$ 与 $u$ 的父节点之间的边的点对 $(c,d)$ 数量为 $vq(u)$，可用树上差分求出

考虑对于点对 $(a,b)$，记 $l=\operatorname{LCA}(a,b)$

$\operatorname{LCA}(c,d)$ 在子树 $l$ 中，数量为

$$tsf(l)-(csf(a)-csf(l))-(csf(b)-csf(l))-f(l)$$

$\operatorname{LCA}(c,d)$ 不在子树 $l$ 中，数量为

$$sq-tsf(l)-vq(l)$$

代码如下：

#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC optimize("inline","fast-math","unroll-loops","no-stack-protector")
#define rep(i, l, r) for (register int i = (l); i <= (r); ++i)
#define per(i, l, r) for (register int i = (l); i >= (r); --i)
using std::make_pair; using std::pair; typedef pair<int, int> pii;
typedef long long ll; typedef unsigned int ui; typedef unsigned long long ull;

/**
 * Buffered reader/writer over stdin/stdout.
 *
 * Input is pulled from stdin in chunks of kBufSize bytes; output is collected
 * in pbuf and written to stdout when the buffer fills up and when the object
 * is destroyed. When the DEBUG macro is non-zero every character goes straight
 * through getchar()/putchar() instead, so it shows up immediately.
 *
 * End of input is reported by gc() as the char value -1. The token readers
 * stop on that sentinel, which means a literal 0xFF byte in the input is
 * indistinguishable from end of input.
 */
struct IO {
    /// Capacity, in bytes, of the input buffer and of the output buffer.
    static constexpr int kBufSize = 1 << 20;

    char buf[kBufSize], *p1, *p2;  // input buffer, read cursor, end of valid data
    char pbuf[kBufSize], *pp;      // output buffer, write cursor

#if !DEBUG
    IO() : p1(buf), p2(buf), pp(pbuf) {}
    /// Writes whatever is still pending in the output buffer.
    ~IO() { fwrite(pbuf, 1, pp - pbuf, stdout); }
#endif

    /// True for the ASCII decimal digits '0'..'9'.
    static inline bool is_digit(char ch) { return ch >= '0' && ch <= '9'; }

    /// True when `ch` is the end-of-input sentinel produced by gc().
    static inline bool at_eof(char ch) { return ch == static_cast<char>(-1); }

    /// Returns the next input character, or -1 once the input is exhausted.
    inline char gc() {
#if DEBUG  // debug mode: unbuffered, characters are visible as they are typed
        return getchar();
#else
        if (p1 == p2) p2 = (p1 = buf) + fread(buf, 1, kBufSize, stdin);
        return p1 == p2 ? static_cast<char>(-1) : *p1++;
#endif
    }

    /// True for space, newline, carriage return and tab.
    inline bool blank(char ch) { return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t'; }

    /**
     * Parses a decimal number into `x`.
     *
     * Skips everything up to the first digit; a '-' anywhere in the skipped
     * text makes the result negative. An optional ".digits" fraction is
     * accumulated as well. The character that terminates the number is
     * consumed. If the input ends before any digit is found, `x` is left at 0
     * (the original looped forever in that situation).
     */
    template <class T>
    inline void read(T &x) {
        double scale = 1;
        bool negative = false;
        x = 0;
        char ch = gc();
        for (; !is_digit(ch) && !at_eof(ch); ch = gc())
            if (ch == '-') negative = true;
        for (; is_digit(ch); ch = gc()) x = x * 10 + (ch - '0');
        if (ch == '.')
            for (ch = gc(); is_digit(ch); ch = gc()) scale /= 10.0, x += scale * (ch - '0');
        if (negative) x = -x;
    }

    /**
     * Reads one whitespace-delimited token into `s` and NUL-terminates it.
     * The caller must provide enough room; no length check is performed.
     * End of input terminates the token (the original kept writing -1 bytes
     * past the end of the buffer forever).
     */
    inline void read(char *s) {
        char ch = gc();
        for (; blank(ch); ch = gc()) {}
        for (; !blank(ch) && !at_eof(ch); ch = gc()) *s++ = ch;
        *s = 0;
    }

    /// Reads the next non-blank character into `c` (-1 at end of input).
    inline void read(char &c) {
        for (c = gc(); blank(c); c = gc()) {}
    }

    /// Appends one character to the output, flushing the buffer first when it is full.
    inline void push(const char &c) {
#if DEBUG  // debug mode: unbuffered, characters are visible immediately
        putchar(c);
#else
        if (pp - pbuf == kBufSize) fwrite(pbuf, 1, kBufSize, stdout), pp = pbuf;
        *pp++ = c;
#endif
    }

    /**
     * Writes an integer in decimal.
     *
     * Digits are taken from the value without negating it first, so the most
     * negative value of a signed type is printed correctly (negating it
     * overflows).
     */
    template <class T>
    inline void write(T x) {
        if (x < 0) push('-');
        char digits[40];  // enough for a 128-bit value
        int top = 0;
        do {
            const int d = static_cast<int>(x % 10);  // carries the sign of x
            digits[top++] = static_cast<char>('0' + (d < 0 ? -d : d));
            x /= 10;
        } while (x);
        while (top) push(digits[--top]);
    }

    /// Writes a NUL-terminated string.
    inline void write(const char *s) {
        while (*s != '\0') push(*(s++));
    }

    /// Writes `x` followed by the single character `lastChar`.
    template <class T>
    inline void write(T x, char lastChar) {
        write(x), push(lastChar);
    }
} io;

const int N = 3010;

int n, p, q, lc[N][N];
std::vector<int> g[N];

int dep[N], fa[13][N];
// Fills dep[] and the direct-parent row fa[0][] for every vertex reachable
// from `u` in g, treating `u` as the root of that part of the graph.
// dep[x] is dep[parent of x] + 1; for the starting vertex the parent is
// whatever fa[0][u] already holds when the function is called.
// Uses an explicit stack; a vertex's parent is always taken off the stack
// before the vertex itself is pushed, so dep[parent] is ready when needed.
// NOTE(review): assumes g is a tree (no cycles) - confirm against the input format.
void dfs_lca(int u){
    std::vector<int> pending(1, u);
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        const int parent = fa[0][cur];
        dep[cur] = dep[parent] + 1;
        for (std::size_t idx = 0; idx < g[cur].size(); ++idx) {
            const int nxt = g[cur][idx];
            if (nxt == parent) continue;  // do not walk back up the tree
            fa[0][nxt] = cur;
            pending.push_back(nxt);
        }
    }
}

// Returns the lowest common ancestor of u and v by binary lifting over
// fa[k][x] (read here as the 2^k-th ancestor of x) and the depths in dep[].
int lca(int u, int v){
    // Make u the deeper endpoint, then lift it to v's depth bit by bit.
    if (dep[u] < dep[v]) std::swap(u, v);
    const int gap = dep[u] - dep[v];
    for (int bit = 0; bit <= 12; ++bit) {
        if ((gap >> bit) & 1) u = fa[bit][u];
    }
    if (u == v) return u;

    // Lift both vertices together for as long as their ancestors differ;
    // afterwards they are the two distinct children of the answer.
    for (int bit = 12; bit >= 0; --bit) {
        if (fa[bit][u] == fa[bit][v]) continue;
        u = fa[bit][u];
        v = fa[bit][v];
    }
    return fa[0][u];
}

ll f[N], tsf[N], csf[N], dq[N], vq[N];

// Recursive pass over the subtree of `u` that derives three aggregates:
//   csf[u] - f[] accumulated from above: csf[parent of u] + f[u];
//   tsf[u] - sum of f[] over the subtree of u;
//   vq[u]  - sum of dq[] over the subtree of u.
void calc(int u){
    const int parent = fa[0][u];
    csf[u] = csf[parent] + f[u];
    tsf[u] = f[u];
    vq[u] = dq[u];
    for (std::size_t idx = 0; idx < g[u].size(); ++idx) {
        const int child = g[u][idx];
        if (child == parent) continue;
        calc(child);
        // Children are complete at this point; fold their totals into u.
        tsf[u] += tsf[child];
        vq[u] += vq[child];
    }
}

int main() {
#ifdef LOCAL
freopen("input", "r", stdin);
#endif
io.read(n), io.read(p), io.read(q);
int u, v;
rep(i, 2, n){
io.read(u), io.read(v);
g[u].push_back(v), g[v].push_back(u);
}
dfs_lca(1);
rep(k, 1, 12) rep(i, 1, n) fa[k][i] = fa[k-1][fa[k-1][i]];
ll sq = 0;
rep(i, 1, n){
rep(j, 1, n){
int l = lca(i, j);
lc[i][j] = l;
int d = dep[i] + dep[j] - dep[l] * 2;
if (d == q){
sq++;
f[l]++;
dq[i]++, dq[j]++, dq[l] -= 2;
}
}
}
calc(1);
ll ans = 0;
rep(i, 1, n){
rep(j, 1, n){
int l = lc[i][j];
int d = dep[i] + dep[j] - dep[l] * 2;
if (d == p){
ans += tsf[l]-(csf[i]-csf[l])-(csf[j]-csf[l])-f[l];
ans += sq - tsf[l] - vq[l];
}
}
}
io.write(ans);
return 0;
}

「牛客 1103B」路径计数机

https://gesrua.xyz/archives/题解/nowcoder/1103B

作者

Gesrua

发布于

2019-11-12

更新于

2020-11-21

许可协议