经过改善的代码如下：
2024-01-21 11:33:22
发布于:北京
27阅读
0回复
0点赞
#include <bits/stdc++.h>
using namespace std;
// Short type aliases and helper macros used throughout the file.
using ll = long long;
using pi = pair<int, int>;
#define mp make_pair
// NOTE(review): `f` and `s` are object-like macros, so every later identifier
// spelled exactly `f` or `s` is rewritten to `first` / `second`; avoid those
// names for variables or functions anywhere below.
#define f first
#define s second
using vi = vector<int>;
// Size of a container as a signed int.
#define sz(x) int((x).size())
// Expands to the begin/end iterator pair of a container.
#define all(x) begin(x), end(x)
// Sorts an entire container in ascending order.
#define sor(x) sort(all(x))
#define pb push_back
#define bk back()
// Modulus used by all `mi` arithmetic (1e9 + 7).
const int MOD = 1e9 + 7;
// Sorts `v` in ascending order and erases duplicate elements in place, leaving
// each distinct value exactly once.
template <class T> void remDup(vector<T> &v) {
    std::sort(v.begin(), v.end());
    const auto firstDup = std::unique(v.begin(), v.end());
    v.erase(firstDup, v.end());
}
struct mi {
int v;
explicit operator int() const { return v; }
mi() : v(0) {}
mi(ll _v) : v(int(_v % MOD)) { v += (v < 0) * MOD; }
};
// Adds rhs into lhs modulo MOD and returns lhs. Assumes both operands are
// canonical residues in [0, MOD), so one conditional subtraction suffices.
mi &operator+=(mi &lhs, mi rhs) {
    lhs.v += rhs.v;
    if (lhs.v >= MOD) {
        lhs.v -= MOD;
    }
    return lhs;
}
// Returns the modular sum of two residues (implemented via operator+=).
mi operator+(mi lhs, mi rhs) {
    lhs += rhs;
    return lhs;
}
const int HASHMAX = 16 * 1000;
struct Table {
mi vals[HASHMAX];
bitset<HASHMAX> visited;
vi keys;
void addTo(int key, mi v) {
if (visited[key] == 0) {
visited[key] = 1;
keys.pb(key);
}
vals[key] += v;
}
void reset() {
for (const auto &u : keys) {
vals[u] = 0;
visited[u] = 0;
}
keys.clear();
}
};
vector<vi> good_subs;
bool is_good_new_set[100005];
void genGoodSubs() {
for (int i = 1; i <= 9; i++) {
good_subs.pb({i});
}
for (int i = 1; i + 3 <= 9; i++) {
good_subs.pb({i, i + 3});
}
for (int i = 1; i <= 2; i++) {
for (int j = 0; j <= 6; j += 3) {
int first_val = i + j;
good_subs.pb({first_val, first_val + 1});
}
}
for (int i = 1; i <= 2; i++) {
for (int j = 0; j <= 3; j += 3) {
vi v;
for (int k = 0; k <= 1; k++) {
for (int l = 0; l <= 3; l += 3) {
v.pb(i + j + k + l);
}
}
sor(v);
good_subs.pb(v);
}
}
for (auto u : good_subs) {
int mask = 0;
for (auto x : u) {
mask += 1 << x;
}
is_good_new_set[mask] = 1;
}
}
void solve() {
string S_inp;
cin >> S_inp;
vi S{-100};
for (auto u : S_inp) {
S.pb(u - '0');
}
int N = sz(S) - 1;
vector<vi> S_masks = vector<vi>(N + 1, vi(5));
for (int i = 1; i <= N; i++) {
for (int j = 0; j <= 4; j++) {
for (int k = 0; k <= j; k++) {
S_masks[i][j] |= (1 << S[i + k]);
}
}
}
Table *dp = new Table();
Table *ndp = new Table();
dp->addTo(1 + 111 * 16, mi(1));
for (int i = 1; i <= N; i++) {
vi cand_new_digs;
for (int j = -3; j <= 3; j++) {
if (i + j >= 1 && i + j <= N) {
cand_new_digs.pb(S[i + j]);
}
}
remDup(cand_new_digs);
ndp->reset();
for (auto u : dp->keys) {
int bars = u % 16;
int nums = u / 16;
int max_bars = 0;
for (int j = 0; j < 4; j++) {
if ((bars >> j) & 1) {
max_bars = j;
}
}
mi ways = dp->vals[u];
for (int new_dig : cand_new_digs) {
array<int, 4> all_nums_arr{new_dig, nums % 10, (nums / 10) % 10,
(nums / 100) % 10};
int new_bars = 0;
int bar_2_set = 0;
for (int old_bar = 0; old_bar <= max_bars; old_bar++) {
int bar_2_set_dig = all_nums_arr[old_bar];
if ((bar_2_set >> bar_2_set_dig) & 1) {
break;
} else {
bar_2_set |= 1 << bar_2_set_dig;
if ((bars >> old_bar) & 1) {
if ((bar_2_set & S_masks[i - old_bar][3]) == bar_2_set) {
if (is_good_new_set[bar_2_set] &&
bar_2_set == S_masks[i - old_bar][old_bar]) {
new_bars |= 1;
}
if (old_bar < 3) {
new_bars |= 1 << (old_bar + 1);
}
}
}
}
}
if (new_bars == 0)
continue;
for (int j = 3; j; --j) {
if (new_bars & (1 << j)) {
break;
}
all_nums_arr[j - 1] = 0;
}
int new_nums =
all_nums_arr[0] + 10 * all_nums_arr[1] + 100 * all_nums_arr[2];
ndp->addTo(new_bars + 16 * new_nums, ways);
}
}
swap(dp, ndp);
}
mi ans = 0;
for (int u : dp->keys) {
if (((u % 16) >> 0) & 1) {
ans += dp->vals[u];
}
}
cout << ans.v << "\n";
}
// Entry point: builds the table of valid digit groups once, then reads the
// number of test cases T and solves each one.
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    genGoodSubs();
    // Initialised so that a failed read (e.g. empty input, where the extractor
    // leaves the variable untouched) runs zero test cases instead of looping
    // over an indeterminate count.
    int T = 0;
    cin >> T;
    for (int t = 1; t <= T; t++) {
        solve();
    }
}
这里空空如也
有帮助,赞一个