题解
2023-08-25 14:11:53
发布于:广东
21阅读
0回复
0点赞
#include<bits/stdc++.h>
typedef long long ll;
// Sentinel for "no valid value": states start at inf, and lca() prints -1
// when the final answer is not below it. inf+inf still fits in a signed
// 64-bit integer, but inf+inf+inf does not.
const ll inf=0x3f3f3f3f3f3f3f3f;
using namespace std;
// Capacity of every per-vertex array (vertices are indexed from 1 in main()).
const int N=100010;
// Adjacency lists in "linked forward" form: tot counts stored edges,
// head[x] is the index of the most recently added edge leaving x (0 = none).
ll tot,head[N];
// n, m    : the first two integers read in main(); n sizes the vertex arrays,
//           m is the number of queries.
// a[x]    : value read for vertex x.
// fa[i][x]: ancestor of x that is 2^i levels up (fa[0][x] is the parent, 0 if none).
// deep[x] : depth of x, with the root (vertex 1) at depth 1.
// Log[i]  : filled in main() as Log[i>>1]+1, used as a loop upper bound for lifting.
ll n,m,a[N],fa[20][N],deep[N],Log[N];
// dp[s][x]      : subtree value of x in state s (s is 0 or 1), computed by dfs().
// f[s][t][i][x] : transition table built by pre(): x in state s, its 2^i-th
//                 ancestor in state t.
// NOTE(review): only indices 0 and 1 of the leading dimensions are used in the
// visible code; the extra slot looks like unused padding - confirm before shrinking.
ll dp[3][N],f[3][3][20][N];
// One directed edge stored in the adjacency pool.
struct edge{
    int ver; // vertex this edge points to
    int to;  // index of the next edge in the same adjacency list (0 ends the list)
};
// Edge pool; add() fills slots starting at index 1. Sized for two directed
// edges per input edge.
edge e[N*2];
// Appends the directed edge x -> y to the front of x's adjacency list.
void add(ll x,ll y){
    const ll id=++tot;      // slot 0 is never used, so 0 can mean "end of list"
    e[id].ver=y;
    e[id].to=head[x];       // chain to the previous first edge of x
    head[x]=id;
}
// Depth-first pass from x that fills deep[], the ancestor table fa[][],
// and the subtree values dp[0][x] / dp[1][x].
//   dp[1][x] = a[x] + sum over children c of min(dp[0][c], dp[1][c])
//   dp[0][x] =        sum over children c of dp[1][c]
// fa[0][x] must already be set by the caller (it is 0 for the root).
void dfs(ll x){
    const ll parent=fa[0][x];
    deep[x]=deep[parent]+1;
    dp[1][x]=a[x];
    // State 0 for both x and its direct parent is marked as impossible.
    f[0][0][0][x]=inf;
    // Fill the higher ancestors for as long as the previous level is still
    // strictly shorter than the depth of x.
    for(ll k=1; (1<<(k-1))<deep[x]; ++k){
        fa[k][x]=fa[k-1][fa[k-1][x]];
    }
    for(ll id=head[x]; id!=0; id=e[id].to){
        const ll child=e[id].ver;
        if(child==parent) continue; // do not walk back up the tree
        fa[0][child]=x;
        dfs(child);
        dp[0][x]+=dp[1][child];
        dp[1][x]+=min(dp[0][child],dp[1][child]);
    }
}
// cur[i] == 2^i for i = 0..17; used to compare jump lengths against depths.
const ll cur[18]={1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072};
// Second pass over the tree (after dfs) that builds the transition table
// f[s][t][k][x] for x and, recursively, for every vertex below it.
// Level 0 is derived from dp[] of x and its parent; level k combines two
// level k-1 tables through the intermediate ancestor fa[k-1][x].
void pre(ll x){
    const ll parent=fa[0][x];

    // Level 0. f[0][0][0][x] was already set to inf in dfs().
    f[1][0][0][x]=dp[0][parent]-dp[1][x];
    const ll parentOneDelta=dp[1][parent]-min(dp[0][x],dp[1][x]);
    f[1][1][0][x]=parentOneDelta;
    f[0][1][0][x]=parentOneDelta;

    // Doubling: only levels whose jump stays inside the tree (2^k < deep[x]).
    for(ll k=1; cur[k]<deep[x]; ++k){
        const ll mid=fa[k-1][x];
        for(int from=0; from<2; ++from){
            for(int to=0; to<2; ++to){
                // Try both states of the intermediate ancestor and keep the cheaper one.
                f[from][to][k][x]=min(f[from][0][k-1][x]+f[0][to][k-1][mid],
                                      f[from][1][k-1][x]+f[1][to][k-1][mid]);
            }
        }
    }

    for(ll id=head[x]; id!=0; id=e[id].to){
        const ll child=e[id].ver;
        if(child!=parent) pre(child);
    }
}
void lca(ll u,ll x,ll v,ll y){
if(deep[u]<deep[v]){
swap(u,v);
swap(x,y);
}
ll L,u0=inf,u1=inf,v0=inf,v1=inf,l0=inf,l1=inf,ans;
if(x)u1=dp[1][u];
else u0=dp[0][u];
if(y)v1=dp[1][v];
else v0=dp[0][v];
for(ll i=Log[deep[u]-deep[v]]; i>=0; i--){
if(deep[u]-cur[i]>=deep[v]){
ll t0=u0,t1=u1;
u0=min(t0+f[0][0][i][u],t1+f[1][0][i][u]);
u1=min(t0+f[0][1][i][u],t1+f[1][1][i][u]);
u=fa[i][u];
}
}
if(u==v){
L=u;
if(y)l1=u1;
else l0=u0;
}
else{
for(ll i=Log[deep[u]-1]; i>=0; i--){
if(fa[i][u]!=fa[i][v]){
ll t0=u0,t1=u1,p0=v0,p1=v1;
u0=min(t0+f[0][0][i][u],t1+f[1][0][i][u]);
u1=min(t0+f[0][1][i][u],t1+f[1][1][i][u]);
v0=min(p0+f[0][0][i][v],p1+f[1][0][i][v]);
v1=min(p0+f[0][1][i][v],p1+f[1][1][i][v]);
u=fa[i][u];
v=fa[i][v];
}
}
L=fa[0][u];
l0=dp[0][L]-dp[1][u]-dp[1][v]+u1+v1;
l1=dp[1][L]-min(dp[0][u],dp[1][u])-min(dp[0][v],dp[1][v])+min(u0,u1)+min(v0,v1);
}
if(L==1) ans=min(l0,l1);
else{
for(ll i=Log[deep[L]-2]; i>=0; i--){
if(deep[L]-cur[i]>1){
ll t0=l0,t1=l1;
l0=min(t0+f[0][0][i][L],t1+f[1][0][i][L]);
l1=min(t0+f[0][1][i][L],t1+f[1][1][i][L]);
L=fa[i][L];
}
}
ans=min(dp[0][1]-dp[1][L]+l1,dp[1][1]-min(dp[0][L],dp[1][L])+min(l0,l1));
}
cout<<(ans<inf?ans:-1)<<'\n';
}
// Reads the next decimal integer from stdin with getchar().
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if the input ends before any digit is found.
// The return type is the file's `ll` (long long).
long long read(){
    long long value=0,sign=1;
    // getchar() returns int so that EOF can be told apart from real characters.
    int ch=getchar();
    // Skip to the first digit, but stop at end of input instead of spinning forever.
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-') sign=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
int main(){
string s;
n=read();
m=read();
cin>>s;
for(ll i=1; i<=n; i++) a[i]=read();
for(ll i=1; i<n; i++){
ll x,y;
cin>>x>>y;
add(x,y);
add(y,x);
Log[i]=Log[i>>1]+1;
}
dfs(1);
pre(1);
while(m--){
ll u,x,v,y;
u=read();
x=read();
v=read();
y=read();
lca(u,x,v,y);
}
return 0;
}
内存最优!
这里空空如也
有帮助,赞一个