题解
2023-08-09 10:20:39
发布于:浙江
27阅读
0回复
0点赞
基本树的应用:二分答案 + 树上倍增 + 贪心匹配
#include<bits/stdc++.h>
using namespace std;
// Fills the whole array a with byte value x; relies on sizeof(a), so a must be a real array.
#define Mem(a, x) memset(a, x, sizeof(a))
// Walks the adjacency list of node x: i is the current edge index in e[], y the neighbour.
// The loop ends when i becomes -1 (~i == 0) or when the stored neighbour id is 0.
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
// Shorthand used for every distance / weight sum in this file.
typedef long long ll;
const int maxn = 5e4 + 5;  // capacity of every per-node array (node ids are used as indices)
const int N = 16;          // highest power-of-two level walked by the ancestor tables
int n, m, a[maxn];         // filled in main(): node count, army count, army start nodes a[1..m]

// One directed edge of the linked adjacency list.
struct Edge
{
    int nxt;  // index in e[] of the next edge leaving the same node (previous head[] value)
    int to;   // node this edge points to
    int w;    // edge weight
} e[maxn << 1];  // main() adds two directed entries per input edge

// head[x] is the index in e[] of the most recently added edge leaving x.
// ecnt is the index of the last used slot in e[] (-1 while the list is empty).
int head[maxn], ecnt = -1;

// Appends the directed edge x -> y with weight w to the front of x's adjacency list.
// No capacity check is performed: the caller must not add more than (maxn << 1) edges.
void addEdge(int x, int y, int w)
{
    // Standard aggregate initialisation; the former `(Edge){...}` compound literal
    // is only accepted in C++ as a compiler extension.
    e[++ecnt] = Edge{head[x], y, w};
    head[x] = ecnt;
}
// Tables filled by dfs1:
//   fa[i][u]  - node reached from u after 2^i parent steps (stays 0 when never assigned)
//   dep[u]    - depth of u; the start node's depth is preset by the caller of dfs1
//   p[u]      - for nodes below the DFS start node: the start node's child whose subtree holds u
//   dis1[u]   - sum of edge weights on the path from the DFS start node down to u
//   dis[i][u] - sum of edge weights on the 2^i-step upward path that starts at u
int fa[N + 2][maxn], dep[maxn], p[maxn];
ll dis1[maxn], dis[N + 2][maxn];
void dfs1(int u, int _f)
{
for(int i = 1; (1 << i) <= dep[u]; ++i)
{
fa[i][u] = fa[i - 1][fa[i - 1][u]];
dis[i][u] = dis[i - 1][u] + dis[i - 1][fa[i - 1][u]];
}
forE(i, u, v)
{
if(v == _f) continue;
dep[v] = dep[u] + 1;
fa[0][v] = u, dis[0][v] = e[i].w;
dis1[v] = dis1[u] + e[i].w;
p[v] = p[u] ? p[u] : v;
dfs1(v, u);
}
}
// Scratch state rebuilt on every call to check():
//   vis[x]  - set by jump() on the node where an army that cannot reach the root ends up
//   cap     - start nodes of the armies that can reach the root within the limit
//   res     - leftover distance budgets of root-reaching armies that check() did not pin
//             to their own subtree
//   res_son - root-to-child edge weights of the root children still marked in ned[]
bool vis[maxn];
vector<int> cap;
vector<ll> res, res_son;
// Places one army that starts at node x and may travel a total distance of mid.
// If the root is within reach (dis1[x] <= mid) the start node is recorded in cap;
// otherwise the army climbs as far as the budget allows and vis[] is set on the
// node where it stops.
void jump(int x, ll mid)
{
    if (dis1[x] <= mid)
    {
        cap.push_back(x);
        return;
    }

    // Try the longest jumps first; each accepted jump consumes part of the budget.
    for (int level = N; level >= 0; --level)
    {
        if (fa[level][x] == 0 || dis[level][x] > mid) continue;
        mid -= dis[level][x];
        x = fa[level][x];
    }
    vis[x] = 1;
}
// ned[v]: set in check() for a child v of node 1 when dfs2(v, 1) returns true;
// cleared again when an army from that subtree is assigned to it.
bool ned[maxn];
bool dfs2(int u, int _f)
{
if(vis[u]) return 0;
bool flg = 0, flg_leaf = 1;
forE(i, u, v) if(v != _f) flg |= dfs2(v, u), flg_leaf = 0;
return flg_leaf | flg;
}
// Decides whether a distance limit of mid per army is enough.
// Steps: reset the scratch state, place every army with jump(), mark the children
// of node 1 whose subtrees dfs2() still reports as open, let root-reaching armies
// with too little leftover budget stay in their own subtree, then greedily match
// the remaining budgets against the remaining root-child edge weights.
// Returns true when every marked child can be matched.
bool check(ll mid)
{
    memset(vis, 0, sizeof(vis));
    memset(ned, 0, sizeof(ned));
    cap.clear();
    res.clear();
    res_son.clear();

    for (int k = 1; k <= m; ++k)
        jump(a[k], mid);

    forE(id, 1, child)
    {
        if (dfs2(child, 1)) ned[child] = 1;
    }

    // Ascending leftover budget (mid - dis1) is the same order as descending dis1.
    sort(cap.begin(), cap.end(), [](int lhs, int rhs) { return dis1[lhs] > dis1[rhs]; });

    for (const int army : cap)
    {
        const int own = p[army];
        const ll spare = mid - dis1[army];
        // An army that could not afford to walk back down to its own open subtree
        // covers that subtree instead of being offered to others.
        if (ned[own] && spare < dis[0][own])
            ned[own] = 0;
        else
            res.push_back(spare);
    }
    sort(res.begin(), res.end());

    forE(id, 1, child)
    {
        if (ned[child]) res_son.push_back(dis[0][child]);
    }
    sort(res_son.begin(), res_son.end());

    // Two-pointer matching: the cheapest open child takes the smallest budget that fits.
    size_t next = 0;
    for (const ll cost : res_son)
    {
        while (next < res.size() && res[next] < cost) ++next;
        if (next == res.size()) return false;
        ++next;
    }
    return true;
}
// Binary-searches the smallest limit in [0, M] accepted by check().
// Returns that limit, or -1 when check() rejects every value up to and including M.
ll solve(ll M)
{
    const ll none = M + 1;  // sentinel: one past the largest candidate
    ll lo = 0;
    ll hi = none;
    while (lo < hi)
    {
        const ll mid = lo + (hi - lo) / 2;
        if (!check(mid))
            lo = mid + 1;
        else
            hi = mid;
    }
    return (lo == none) ? -1 : lo;
}
int main()
{
Mem(head, -1), ecnt = -1;
scanf("%d", &n);
ll Sum = 0;
for(int i = 1, u, v, w; i < n; ++i)
{
scanf("%d%d%d", &u, &v, &w);
addEdge(u, v, w), addEdge(v, u, w);
Sum += w;
}
scanf("%d", &m);
for(int i = 1; i <= m; ++i) scanf("%d", &a[i]);
dep[1] = 1, dfs1(1, 0);
printf("%lld\n", solve(Sum));
return 0;
}
这里空空如也
有帮助,赞一个