T6
2024-12-11 11:04:48
发布于:浙江
11阅读
0回复
0点赞
本题首先应该不难想到要用 来维护答案,我们令 表示 以第 个数为结尾能得到的最大价值。那么
我们可以不断的枚举 前面的下标 ,令,则有三种情况:
- ,
- ,
- ,
答案就是我们的
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 500010;
int n;
int dp[N], a[N], sum[N];
signed main(){
cin >> n;
for(int i = 1; i <= n; i ++ ){
cin >> a[i];
sum[i] = sum[i - 1] + a[i];
dp[i] = -9e18;
}
for(int i = 1; i <= n; i ++ ){
for(int j = 1; j <= i; j ++ ){
int len = i - j + 1;
if(sum[i] - sum[j - 1] > 0) dp[i] = max(dp[i], dp[j - 1] - j + i + 1);
else if(sum[i] - sum[j - 1] < 0) dp[i] = max(dp[i], dp[j - 1] + j - i - 1);
else dp[i] = dp[j - 1];
}
}
cout << dp[n];
return 0;
}
但是很明显,这样会超时,我们可以仔细观察那 个式子,把 和 所在式子分离出来:
- ,
- ,
- ,
所以首先我们只需要将所有的前缀和当作下标,维护出 的最大值即可,维护这个最大值可以用树状数组或者线段树等数据结构,然后每次求 的时候,只需要在对应的区间上求对应的最大值即可,假设当前前缀和离散后的下标为 :
- 考虑 的情况,只需要在 中间找最大的 ,
- 考虑 的情况,只需要在 中间找 的最大值
- 考虑 的情况,只需要在 中间找 的最大值
其次前缀会比较大,所以需要离散化。
#include <bits/stdc++.h>
#define int long long
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
const int N = 500010;
int a[N], dp[N], n, sum[N];
/*
维护 dp[j] - j, dp[j] + j, dp[j]的区间最大值
*/
struct Node{
int maxn;
}seg[4][N * 4];
vector<int>q;
int find(int x){
return lower_bound(q.begin(), q.end(), x) - q.begin() + 1;
}
void pushup(int id, int u){
seg[id][u].maxn = max(seg[id][ls].maxn, seg[id][rs].maxn);
}
void build(int id, int u, int l, int r){
if(l == r){
seg[id][u].maxn = -9e18;
return;
}
int mid = l + r >> 1;
build(id, ls, l, mid); build(id, rs, mid + 1, r);
pushup(id, u);
}
void update(int id, int u, int l, int r, int pos, int val){
if(l == r && l == pos){
seg[id][u].maxn = max(seg[id][u].maxn, val);
return;
}
int mid = l + r >> 1;
if(pos <= mid) update(id, ls, l, mid, pos, val);
else update(id, rs, mid + 1, r, pos, val);
pushup(id, u);
}
int query(int id, int u, int l, int r, int ql, int qr){
if(l == ql && r == qr) return seg[id][u].maxn;
int mid = l + r >> 1;
if(qr <= mid) return query(id, ls, l, mid, ql, qr);
else if(ql > mid) return query(id, rs, mid + 1, r, ql, qr);
else return max(query(id, ls, l, mid, ql, mid), query(id, rs, mid + 1, r, mid + 1, qr));
}
signed main(){
cin >> n;
for(int i = 1; i <= n; i ++ ){
cin >> a[i];
sum[i] = sum[i - 1] + a[i];
dp[i] = -9e18;
}
for(int i = 1; i <= n; i ++ ){
q.push_back(sum[i]);
}
sort(q.begin(), q.end());
q.erase(unique(q.begin(), q.end()), q.end());
for(int i = 1; i <= 3; i ++ ){//初始化
build(i, 1, 1, n + 1);
update(i, 1, 1, n + 1, find(0), 0);
}
for(int i = 1; i <= n; i ++ ){
int sum_idx = find(sum[i]);
int x = query(1, 1, 1, n + 1, sum_idx, sum_idx);//区间和等于0的情况
int y = -9e18;
if(sum_idx - 1 >= 1) y = query(2, 1, 1, n + 1, 1, sum_idx - 1);//区间和大于0
int z = -9e18;
if(sum_idx + 1 <= n) z = query(3, 1, 1, n + 1, sum_idx + 1, n + 1);//区间和小于0
dp[i] = max({dp[i], x, y + i, z - i});
update(1, 1, 1, n + 1, sum_idx, dp[i]);
update(2, 1, 1, n + 1, sum_idx, dp[i] - i);
update(3, 1, 1, n + 1, sum_idx, dp[i] + i);
}
cout << dp[n];
return 0;
}
这里空空如也
有帮助,赞一个