函数调用 题解(超详细)
2023-08-26 15:29:55
发布于:广东
50阅读
0回复
0点赞
首先,基本的函数只有两种:整体乘法和单点加法;我们将其称作“原子函数”,而其它函数都是这两种函数的复合,称作“复合函数”。
注意到任意一个复合函数 都可以表示成如下的形式
其中
现在考虑多个复合函数要如何计算,以三个函数的复合为例,设
则
一般地
从上式中可以看出,若干个复合函数的 就等于各项的 的乘积,而复合函数的 的计算比较复杂,如果暴力计算,将会有 的时间复杂度,这也是该题算法设计的瓶颈所在。
现在来考虑如何计算复合函数的加法项 。重新审视
该式中的系数为,这也相当于把的加法部分重复执行了 次。也就是说,如果被执行了 次,那么的加法就会被执行次。而又是其它函数的复合,于是这个系数可以 push_down 下去,就像线段树中加法的 lazy_tag 一样。push_down 的过程直到原子函数终止。
根据题意,所有的函数调用会构成一个 DAG。即:如果函数 调用了函数 ,那么就从 到连一条有向边
for (int j = 1; j <= m; ++j) {
scanf("%d", &T[j]); // T[j] 是函数 f_j 的类型,T[j]=1,2 为原子函数,T[j]=3 为复合函数
k[j] = 1; // k[j] 就是函数 f_j 的 k_j
if (T[j] == 1) {
scanf("%d%lld", &P[j], &V[j]);
} else if (T[j] == 2) {
scanf("%lld", &k[j]);
} else if (T[j] == 3) {
int C; scanf("%d", &C);
for (int l = 1; l <= C; ++l) {
int g; scanf("%d", &g);
G[j].push_back(g); // G 存储了这个有向图
}
}
}
对于每个函数的 ,我们可以用记忆化搜索求解
void dfs(int j) {
// vis[j]=true 表示 j 已经被访问过
if (vis[j]) return;
vis[j] = true;
int size = G[j].size();
for (int ptr = size-1; ptr >= 0; --ptr) {
int g = G[j][ptr];
++deg[g], dfs(g);
k[j] = (k[j]*k[g]) % p; // p=998244353
}
}
现在我们来考虑如何 push_down 每个函数的加法标记,首先,对于直接调用的 个函数,它们的加法标记会反向传播(我们可以在这个过程中顺便计算全局乘法)
long long w = 1;
for (int j = q; j >= 1; --j) {
K = (K*k[idx[j]]) % p; // K 是全局乘法标记
t[idx[j]] = (t[idx[j]]+w) % p; // t[j] 表示函数 f_j 的加法被执行了 t[j] 次
w = (w*k[idx[j]]) % p;
}
然后,对于函数调用构成的 DAG,可以使用拓扑排序来 push_down 加法标记
// Q 是拓扑排序的队列
for (int j = 1; j <= m; ++j)
if (deg[j] == 0) Q.push(j);
while (!Q.empty()) {
int j = Q.front(); Q.pop();
if (vis[j]) continue;
vis[j] = true;
LL w = 1;
int size = G[j].size();
// 要注意这里遍历的顺序,对于每个结点,都要根据其子结点的复合顺序,反向传播 push_down 加法标记
for (int ptr = size-1; ptr >= 0; --ptr) {
int g = G[j][ptr];
t[g] = (t[g] + (w*t[j])%p) % p;
w = (w*k[g]) % p;
deg[g] -= 1; // deg[g] 是 g 的入度
if (deg[g] == 0) Q.push(g);
}
}
完整AC代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
#define LL long long
const int N = 3e5 + 10;
const LL p = 998244353;
int n, m, q,T[N], P[N], idx[N];
LL a[N], V[N];
vector<int> G[N];
int deg[N];
LL K = 1, k[N], t[N];
bool vis[N]; // dfs 和 bfs 记录
void dfs(int j) {
if (vis[j]) return;
vis[j] = true;
int size = G[j].size();
for (int ptr = size-1; ptr >= 0; --ptr) {
int g = G[j][ptr];
++deg[g], dfs(g);
k[j] = (k[j]*k[g]) % p;
}
}
queue<int> Q;
void topSort() {
for (int j = 1; j <= m; ++j)
if (deg[j] == 0) Q.push(j);
while (!Q.empty()) {
int j = Q.front(); Q.pop();
if (vis[j]) continue;
vis[j] = true;
LL w = 1;
int size = G[j].size();
/* push_down */
for (int ptr = size-1; ptr >= 0; --ptr) {
int g = G[j][ptr];
t[g] = (t[g] + (w*t[j])%p) % p;
w = (w*k[g]) % p;
deg[g] -= 1;
if (deg[g] == 0) Q.push(g);
}
}
}
int main() {
/* 读取数据 */
cin >> n;
for (int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
cin >> m;
for (int j = 1; j <= m; ++j) {
scanf("%d", &T[j]);
k[j] = 1;
if (T[j] == 1) {
scanf("%d%lld", &P[j], &V[j]);
} else if (T[j] == 2) {
scanf("%lld", &k[j]);
} else if (T[j] == 3) {
int C; scanf("%d", &C);
for (int l = 1; l <= C; ++l) {
int g; scanf("%d", &g);
G[j].push_back(g);
}
}
}
cin >> q;
for (int j = 1; j <= q; ++j) scanf("%d", &idx[j]);
/* dfs 计算 k[j] */
memset(vis, false, sizeof(vis));
for (int j = 1; j <= m; ++j) dfs(j);
/* 计算乘法 K,打加法标记 */
LL w = 1;
for (int j = q; j >= 1; --j) {
K = (K*k[idx[j]]) % p;
t[idx[j]] = (t[idx[j]]+w) % p;
w = (w*k[idx[j]]) % p;
}
/* 拓扑排序 push_down 加法标记 */
memset(vis, false, sizeof(vis)), topSort();
/* 计算答案 */
for (int i = 1; i <= n; ++i)
a[i] = (a[i]*K) % p;
for (int j = 1; j <= m; ++j) {
if (T[j] != 1) continue;
a[P[j]] = (a[P[j]] + ((t[j]*V[j])%p)) % p;
}
for (int i = 1; i <= n; ++i) printf("%lld ", a[i]);
putchar('\n');
return 0;
}
欢迎加入团队!!!
这里空空如也
有帮助,赞一个