Solution
2024-08-16 18:57:45
发布于:广东
0阅读
0回复
0点赞
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
// Capacity of the fixed-size per-node arrays below; the node count N read in main
// must not exceed it.
const int MAX_N = 50000;
// Segment tree over the half-open index range [L, R): each index holds an integer
// (initially 0). updateValue adds to one index and getSum returns the total over a
// half-open range, each visiting O(log(R - L)) nodes.
//
// A node splits at half = (L + R) / 2 into children [L, half) and [half, R); a node with
// half == L (i.e. R - L <= 1) is a leaf and has no children.
//
// Every node owns its two children and releases them in its destructor. Copying is
// disabled because a copy would share, and later double-delete, the same children.
struct partialSumTree {
    int L, R, half, sum;             // covered range [L, R), split point, total over the range
    partialSumTree* left, * right;   // children, or NULL for a leaf

    // Builds an all-zero tree covering [l, r).
    partialSumTree(int l, int r)
        : L(l), R(r), half((L + R) / 2), sum(0), left(NULL), right(NULL) {
        if (half != L) {
            left = new partialSumTree(L, half);
            right = new partialSumTree(half, R);
        }
    }

    // Frees the whole subtree below this node (delete on NULL is a no-op for leaves).
    ~partialSumTree() {
        delete left;
        delete right;
    }

    // Adds delta to the value stored at idx; indices outside [L, R) are ignored.
    void updateValue(int idx, int delta) {
        if (idx < L || idx >= R) return;
        sum += delta;
        if (half != L) {
            (idx < half ? left : right)->updateValue(idx, delta);
        }
    }

    // Returns the sum of the values at the indices of [A, B) that lie inside [L, R);
    // an empty or non-overlapping range yields 0.
    int getSum(int A, int B) const {
        if (A >= R || B <= L) return 0;     // no overlap with this node
        if (A <= L && B >= R) return sum;   // this node is fully covered
        // Partial overlap: a leaf covers at most one index, so it always hits one of the
        // two checks above; reaching here means this node has children.
        return left->getSum(A, B) + right->getSum(A, B);
    }

private:
    // Declared but never defined (pre-C++11 idiom) so copying fails to compile/link:
    // a copied node would share its children with the original.
    partialSumTree(const partialSumTree&);
    partialSumTree& operator=(const partialSumTree&);
};
// Adjacency lists of the tree; main stores nodes 0-based (it subtracts 1 from the input ids).
vector<int> neighbours[MAX_N];
// DFS interval of each node from firstPassDFS: the subtree of v is exactly the nodes u with
// startTime[v] <= startTime[u] < endTime[v]. main sets both to -1 before the first DFS,
// and -1 in startTime marks a node as not yet visited.
int startTime[MAX_N];
int endTime[MAX_N];
// First DFS: numbers the nodes in visiting order so that each subtree becomes the
// contiguous interval [startTime[v], endTime[v]).
// curTime is the next unused timestamp on entry; the return value is the next unused
// timestamp after curNode's subtree, or curTime unchanged if curNode was already visited
// (in a tree, that is how the parent is skipped).
// NOTE(review): recursion depth equals the tree height, which is N for a path-shaped
// tree; confirm the target's stack limit.
int firstPassDFS(int curNode, int curTime) {
    const bool alreadySeen = (startTime[curNode] != -1);
    if (alreadySeen) {
        return curTime;
    }
    int clock = curTime;
    startTime[curNode] = clock;
    ++clock;
    const vector<int>& adjacent = neighbours[curNode];
    for (vector<int>::size_type k = 0; k != adjacent.size(); ++k) {
        clock = firstPassDFS(adjacent[k], clock);
    }
    endTime[curNode] = clock;
    return clock;
}
// Paths bucketed by endpoint, oriented by main so that startTime[s] <= startTime[t]:
// beginPath[s] holds the t of every path starting at s,
// endPath[t] holds the s of every path ending at t.
vector<int> beginPath[MAX_N];
vector<int> endPath[MAX_N];
// Largest number of paths seen through a single node; updated by secondPassDFS.
int ans = -1;
// Segment trees over DFS start times counting the currently active paths, keyed by the
// start time of the earlier endpoint (bySource) or of the later endpoint (byDestination).
partialSumTree* bySource;
partialSumTree* byDestination;
// Second DFS over the tree, walking neighbours in the same order as firstPassDFS so that
// curTime reproduces the start times computed there. For every node it counts how many of
// the stored paths pass through it and folds the maximum into the global `ans`.
//
// A path (s, t), oriented so that startTime[s] <= startTime[t], is "active" from the moment
// s is entered until t is finished. While active it contributes +1 to bySource at
// startTime[s] and +1 to byDestination at startTime[t].
//
// Node v lies on (s, t) in exactly one of these mutually exclusive cases:
//   (a) s is outside v's subtree and t is inside it (then s was entered before v);
//   (b) s == v;
//   (c) s is inside the subtree of a child c of v and t is outside that child's subtree.
// Each case is counted once below.
//
// curTime: the DFS clock on arrival. Returns the clock after curNode's subtree, or curTime
// unchanged when curNode is skipped.
int secondPassDFS(int curNode, int curTime) {
// Only a neighbour whose first-pass start time equals the clock is an unvisited child.
// Assuming the edges form a tree, the one other neighbour (the parent) has a smaller
// start time and is skipped here.
if (startTime[curNode] != curTime) return curTime;
curTime++;
// Case (a): active paths whose later endpoint lies in this subtree. This query runs before
// the paths starting here are activated, so case (b) paths are not counted twice.
int passingThrough = byDestination->getSum(startTime[curNode], endTime[curNode]);
// Case (b): activate, and count, every path whose earlier endpoint is this node.
for (vector<int>::iterator it = beginPath[curNode].begin(); it != beginPath[curNode].end(); it++) {
bySource->updateValue(startTime[curNode], 1);
byDestination->updateValue(startTime[*it], 1);
passingThrough++;
}
// Case (c): once a child's subtree is finished, every path ending inside it has been
// deactivated. Any active path whose earlier endpoint is in that subtree therefore leaves
// it through this node. For the skipped parent, [prevTime, curTime) is empty and adds 0.
for (vector<int>::iterator it = neighbours[curNode].begin(); it != neighbours[curNode].end(); it++) {
int prevTime = curTime;
curTime = secondPassDFS(*it, curTime);
passingThrough += bySource->getSum(prevTime, curTime);
}
// The subtree is done: deactivate the paths whose later endpoint is this node.
for (vector<int>::iterator it = endPath[curNode].begin(); it != endPath[curNode].end(); it++) {
bySource->updateValue(startTime[*it], -1);
byDestination->updateValue(startTime[curNode], -1);
}
ans = max(ans, passingThrough);
return curTime;
}
int main() {
int N, K;
cin >> N >> K;
for (int i = 1; i < N; i++) {
int x, y;
cin >> x >> y;
x--; y--;
neighbours[x].push_back(y);
neighbours[y].push_back(x);
}
fill(startTime, startTime + N, -1);
fill(endTime, endTime + N, -1);
firstPassDFS(0, 0);
bySource = new partialSumTree(0, N);
byDestination = new partialSumTree(0, N);
for (int i = 0; i < K; i++) {
int s, t;
cin >> s >> t;
s--; t--;
if (startTime[s] > startTime[t]) swap(s, t);
beginPath[s].push_back(t);
endPath[t].push_back(s);
}
secondPassDFS(0, 0);
cout << ans << endl;
return 0;
}
这里空空如也
有帮助,赞一个