1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
#include <bits/stdc++.h>

// Reads the next whitespace-delimited integer (optional leading '-', then
// decimal digits) from stdin via getchar. Bytes <= ' ' act as separators.
// The character is held in an int (not a char) so EOF (-1) stays
// distinguishable from data: an empty or truncated stream yields 0 / the
// digits read so far instead of looping forever in the separator skip.
inline int getint() {
    int value = 0;
    int sign = 1;
    int ch = std::getchar();
    while (ch != EOF && ch <= ' ') ch = std::getchar();
    if (ch == '-') {
        sign = -1;
        ch = std::getchar();
    }
    // EOF is negative, so `ch > ' '` also stops at end of input.
    while (ch > ' ') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}

// For every vertex v reachable from `root`, returns
//   depth(v) - (subtreeSize(v) - 1),
// where depth(root) == 0 and subtreeSize counts v itself.
// Vertices are discovered with an explicit BFS queue rather than recursion,
// so a path-shaped tree with hundreds of thousands of vertices cannot
// overflow the call stack. Unreachable vertices are omitted from the result.
static std::vector<long long> vertexGains(const std::vector<std::vector<int>>& adj, int root) {
    const int n = static_cast<int>(adj.size());
    std::vector<int> depth(n, -1);
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);

    depth[root] = 0;
    order.push_back(root);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int v = order[head];
        for (const int to : adj[v]) {
            // Already discovered (the parent, or a repeated edge): skip so
            // malformed input cannot make the queue grow without bound.
            if (depth[to] != -1) continue;
            depth[to] = depth[v] + 1;
            parent[to] = v;
            order.push_back(to);
        }
    }

    // Every child appears after its parent in BFS order, so sweeping the
    // order backwards completes each subtree before folding it into its parent.
    std::vector<int> subtree(n, 1);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        if (parent[*it] != -1) subtree[parent[*it]] += subtree[*it];
    }

    std::vector<long long> gains;
    gains.reserve(order.size());
    for (const int v : order) {
        gains.push_back(static_cast<long long>(depth[v]) - (subtree[v] - 1));
    }
    return gains;
}

// Reads n, k and n-1 undirected edges with 1-based endpoints (assumed to form
// a tree) from stdin, roots the tree at vertex 1, marks the k vertices with
// the largest depth - subtreeSize, and prints the sum over marked vertices of
// the number of unmarked strict ancestors.
//
// A child's depth - subtreeSize is always at least 2 larger than its parent's,
// so every descendant of a marked vertex is marked too. Hence each marked v
// contributes depth(v) minus its (subtreeSize(v) - 1) marked descendants'
// shared ancestor, and the answer is simply the sum of the k largest values
// of depth(v) - (subtreeSize(v) - 1), computed here without a second pass.
int main() {
    const int n = getint();
    const int k = getint();
    if (n <= 0) {
        std::cout << 0 << '\n';
        return 0;
    }

    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        const int x = getint() - 1;
        const int y = getint() - 1;
        // Ignore endpoints outside [1, n] rather than writing out of bounds.
        if (x < 0 || x >= n || y < 0 || y >= n) continue;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    std::vector<long long> gains = vertexGains(adj, 0);
    const std::size_t take = std::min(gains.size(), static_cast<std::size_t>(std::max(k, 0)));
    // Only the multiset of the top `take` values matters, so a linear-time
    // partition replaces the original full sort.
    std::nth_element(gains.begin(), gains.begin() + take, gains.end(), std::greater<long long>());
    const long long res = std::accumulate(gains.begin(), gains.begin() + take, 0LL);

    std::cout << res << '\n';
    return 0;
}
|