UESTC 2404

题目

给一颗\(n\)节点的树,1节点为根节点,求最少减去多少条边可以得到一颗\(p\)节点的树。

题解

参照洛谷P1272

基本上是原题了。

\(dp[i][j]\)表示以\(i\)为根节点时获得节点数为\(p\)的树最少需要的修建次数。状态转移方程 \[ dp[i][j]=\min\{dp[i][j-k]+dp[son][k]-1|k\in[1,j]\} \] \(son\)\(i\)的子节点。

考虑到不一定1为根节点(\(dp[1][p]\))才是最优,所以我们可以遍历每一个节点作为根节点的情况。注意到除了根节点之外,其他节点为根节点的话需要断开自己与父节点的边,所以遍历的时候+1即可。

还有一个坑是输入。每行输入的是\(u\)\(v\)表示\(u\)\(v\)之间有一条边,然而并不知道哪个是父节点,所以得专门做一次处理,我懒得搞骚操作直接暴力了,速度还行吧。

代码

AC代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// @author: hzy
//#pragma G++ optimize("O3")

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <cstring>
#include <iostream>
#include <queue>
#include <map>
#include <set>
#include <string>
#include <list>
#include <forward_list>
#include <stack>
#include <unordered_set>
#include <vector>
#include <limits.h>

#define DISPLAY_A 0

using namespace std;

const int MAX = 200 + 7;

set<int> nodes[201];

int dp[201][201];

int dfs(int root) {
int sum = 1;
for (auto son:nodes[root]) {
sum += dfs(son);
// TODO: optimize
for (int j = sum; j >= 1; j--) {
for (int k = 1; k < j; k++) {
dp[root][j] = min(dp[root][j], dp[root][j - k] + dp[son][k] - 1);
}
}
}
return sum;
}

int dset[201];

struct Edge {
int u, v;
};

int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(0);
int n, p;
cin >> n >> p;
int u, v;
dset[1] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
dp[i][j] = n;
}
}
// 构建单向边
map<int, Edge> edge;
int cnt = 0;
for (int i = 1; i < n; i++) {
cin >> u >> v;
// 干就完事了
if (u == 1) {
nodes[u].insert(v);
dset[v] = u;
} else if (v == 1) {
nodes[v].insert(u);
dset[u] = v;
} else if (dset[u]) {
nodes[u].insert(v);
dset[v] = 1;
} else if (dset[v]) {
nodes[v].insert(u);
dset[u] = 1;
} else {
// 姑且记录一下
edge[cnt].u = u;
edge[cnt].v = v;
cnt += 1;
}
}
set<int> af;
while (!edge.empty()) {
for (auto &it:edge) {
if (dset[it.second.u]) {
nodes[it.second.u].insert(it.second.v);
dset[it.second.v] = 1;
af.insert(it.first);
} else if (dset[it.second.v]) {
nodes[it.second.v].insert(it.second.u);
dset[it.second.u] = 1;
af.insert(it.first);
}
}
for (auto &it:af) {
edge.erase(it);
}
af.clear();
}
// init
for (int i = 1; i <= n; i++) {
dp[i][1] = nodes[i].size();
}
// search
dfs(1);
int ans = dp[1][p];
#if DISPLAY_A
int idx = 1;
#endif
for (int i = 2; i <= n; i++) {
#if DISPLAY_A
if (ans > dp[i][p] + 1) {
ans = dp[i][p] + 1;
idx = i;
}
#else
ans = min(ans, dp[i][p] + 1);
#endif
}
#if DISPLAY_A
cout << ans << " " << idx << endl;
#else
cout << ans << endl;
#endif
return 0;
}

测试例生成(Python)

1
2
3
4
5
6
7
8
9
10
def N(n, p, filename='in'):
n = int(n)
p = int(p)
f = open(filename, 'w')
f.write("{} {}\r\n{}\r\n".format(
n, p,
"\r\n".join([
"{} {}".format(random.randint(1, i - 1), i) for i in range(2, n + 1)
])
))