运输计划 题解
2023-09-09 11:51:46
发布于:广东
8阅读
0回复
0点赞
给定一棵树以及树上的 条通路,我们可以在树上选取一条边,将其权重置为 ,目标是
20pts(m=1)
当 时,我们只需要求出树上的一条链上的权重和与权重最大值即可。
50pts
考虑一种暴力的算法,枚举将哪一条边权重置 ,然后重新在树上求解 条通路的权重。这个过程可以用 LCA 优化。任选一个结点为根,预处理出每一条通路的两个端点的 LCA,则路径长度可以通过树上差分快速计算。时间复杂度为 。
80pts(树退化为链)
当树退化为链时,这就是一个纯粹的数据结构问题。这类最小化最大值的问题可以考虑二分答案,将其转化为判定问题。给定一个权重上界 w ww 后,我们可以 O ( m ) O(m)O(m) 算出哪些通路的权重是超过这个上界的,而这些通路全部位于一条链上,因此我们可以 O ( m ) O(m)O(m) 求出它们的交集。然后在交集中找到权重最大的一条边,如果将这条边的权重置 后,所有通路的权重均不超过 ,那么 就是一个可行的上界。
100pts
上面的二分答案方法给了我们初步的思路。现在只需考虑树上给定权重上界后如何判定:首先任取一个结点作为根结点,并将边权下推为点权。使用差分维护数组 表示从 的父结点到 的这条边被经过了多少次。假设有 个权重超过上届的通路,那么我们只要考虑被经过 次的边即可。
AC代码
#include<bits/stdc++.h>
using namespace std;
const int maxn = 3e5 + 10;
struct edge
{
int v, w;
int nxt;
} e[maxn << 1];
int n, m,
ver[maxn], w[maxn], a[maxn], b[maxn], c[maxn], d[maxn], f[maxn], s[maxn], num,
top[maxn], fa[maxn], size[maxn], son[maxn], dep[maxn], dis[maxn],
l, r, mid, ans, maxw;
inline void adde(int u, int v, int w)
{
static int ed = 1;
e[++ed] = (edge){ v, w, ver[u] };
ver[u] = ed;
}
inline void dfs1(int u, int f)
{
s[++num] = u;
size[u] = 1, fa[u] = f;
for(int i = ver[u]; i; i = e[i].nxt)
{
int v = e[i].v;
if(size[v])
continue;
dep[v] = dep[u] + 1;
w[v] = e[i].w;
dis[v] = dis[u] + w[v];
dfs1(v, u);
size[u] += size[v];
if(size[son[u]] < size[v])
son[u] = v;
}
}
inline void dfs2(int u, int t)
{
top[u] = t;
if(son[u])
dfs2(son[u], t);
for(int i = ver[u]; i; i = e[i].nxt)
{
int v = e[i].v;
if(v == fa[u] || v == son[u])
continue;
dfs2(v, v);
}
}
inline int lca(int u, int v)
{
while(top[u] != top[v])
{
if(dep[top[u]] < dep[top[v]])
swap(u, v);
u = fa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
inline bool check(int k)
{
memset(f, 0, sizeof f);
int cnt = 0;
for(int i = 1; i <= m; i++)
{
if(d[i] <= k)
continue;
f[a[i]]++, f[b[i]]++, f[c[i]] -= 2;
cnt++;
}
for(int i = n; i >= 1; i--)
{
f[fa[s[i]]] += f[s[i]];
if(w[s[i]] >= maxw - k && f[s[i]] == cnt)
return true;
}
return false;
}
inline int read()
{
static int x;
static char c;
x = 0, c = getchar();
while(!isdigit(c))
c = getchar();
while(isdigit(c))
x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return x;
}
int main()
{
n = read(), m = read();
for(int i = 1; i < n; i++)
{
int u = read(), v = read(), w = read();
adde(u, v, w);
adde(v, u, w);
l = max(l, w);
}
dep[1] = 1;
dfs1(1, 0);
dfs2(1, 1);
for(int i = 1; i <= m; i++)
{
a[i] = read(), b[i] = read();
c[i] = lca(a[i], b[i]);
d[i] = dis[a[i]] + dis[b[i]] - (dis[c[i]] << 1);
r = max(r, d[i]);
}
maxw = r, l = maxw - l, r++;
while(l <= r)
{
mid = (l + r) >> 1;
if(check(mid))
ans = mid, r = mid - 1;
else
l = mid + 1;
}
printf("%d", ans);
return 0;
}
这里空空如也
有帮助,赞一个