CSP-S 模拟赛 17
T1 zzy 的金牌
1.30pts1.30pts1.30pts 做法
直接爆搜,枚举每个点放几个,用“结构体 + setsetset”去重。
代码:
#include <bits/stdc++.h>#define mkpr make_pair
#define fir first
#define sec secondusing namespace std;typedef pair<int, int> pii;
typedef unsigned long long ull;
typedef long long ll;
typedef long double ld;
typedef double db;const int maxn = 3e2 + 7;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;int n, K, a[maxn];
namespace Sub1 {struct node {int x[maxn];bool operator < (const node& y) const {for (int i = 1; i < n; ++i)if (x[i] != y.x[i]) return x[i] < y.x[i];return x[n] < y.x[n];}} tmp;int ans;set<node> uni;void dfs(int now) {if (now == K + 1) {for (int i = 1; i <= n; ++i) tmp.x[i] = a[i];sort(tmp.x + 1, tmp.x + n + 1);if (!uni.count(tmp)) {
// puts("tmp:");
// for (int i = 1; i <= n; ++i)
// printf("%d ", tmp.x[i]);
// puts("");++ans, ans %= mod;uni.insert(tmp);}return ;}for (int i = 1; i <= n; ++i)++a[i], dfs(now + 1), --a[i];}void Main() {dfs(1);printf("%d\n", ans);exit(0);}
}int main() {scanf("%d%d", &n, &K);for (int i = 1; i <= n; ++i) scanf("%d", a + i);if (n <= 7 && K <= 7) Sub1::Main();return 0;
}
2.60pts2.60pts2.60pts 做法
考虑怎样得到的最终序列是合法的:
设最后得到的序列是 sis_isi,我们给 aaa 和 sss 都升序排序(这样就可以避免出现重复的答案),只有 ∀si≥ai\forall s_i \geq a_i∀si≥ai 且 ∑i=1nsi−∑i=1nai=k\sum_{i = 1}^n s_i - \sum_{i = 1}^n a_i = k∑i=1nsi−∑i=1nai=k 才满足条件。
那么我们令 bi=si−aib_i = s_i - a_ibi=si−ai,我们实际上就要求 bbb 的合法个数,那么就有:
- bi+ai≤bi+1+ai+1b_i + a_i \leq b_{i + 1} + a_{i + 1}bi+ai≤bi+1+ai+1;
- ∑i=1n=k.\sum_{i = 1}^n = k.∑i=1n=k.
设 dpi,j,kdp_{i, j, k}dpi,j,k 表示前 iii 个数中,∑x=1ibx=j\sum_{x = 1}^i b_x = j∑x=1ibx=j 且 bi=kb_i = kbi=k 时的方案数,那么枚举上一个数 bi−1=pb_{i - 1} = pbi−1=p 就有转移方程:
dpi,j,k=∑p=0j−kdpi−1,j−k,p, 其中 ai−1+p≤ai+k.
dp_{i, j, k} = \sum_{p = 0}^{j - k} dp_{i - 1, j - k, p}, \space 其中 \space a_{i - 1} + p \leq a_i + k.
dpi,j,k=p=0∑j−kdpi−1,j−k,p, 其中 ai−1+p≤ai+k.
时间复杂度 O(nk3)O(nk^3)O(nk3)。
代码:
#include <bits/stdc++.h>#define mkpr make_pair
#define fir first
#define sec secondusing namespace std;typedef pair<int, int> pii;
typedef unsigned long long ull;
typedef long long ll;
typedef long double ld;
typedef double db;const int maxn = 3e2 + 7;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;int n, m, a[maxn];int dp[maxn][maxn][maxn];
int main() {scanf("%d%d", &n, &m);for (int i = 1; i <= n; ++i) scanf("%d", a + i);sort(a + 1, a + n + 1);for (int i = 0; i <= m; ++i) dp[1][i][i] = 1;for (int i = 2; i <= n; ++i) {for (int j = 0; j <= m; ++j) {for (int k = 0; k <= j; ++k) {for (int p = 0; p <= min(j - k, k + a[i] - a[i - 1]); ++p)(dp[i][j][k] += dp[i - 1][j - k][p]) %= mod;}}}int ans = 0;for (int i = 0; i <= m; ++i)(ans += dp[n][m][i]) %= mod;printf("%d\n", ans);return 0;
}
3.100pts3.100pts3.100pts 做法
发现最内层转移就是求前缀和,所以直接用 gi,j,kg_{i, j, k}gi,j,k 记录 dpi,j,kdp_{i, j, k}dpi,j,k 的前缀和,时间复杂度 O(nk2)O(nk^2)O(nk2)。
代码:
#include <bits/stdc++.h>#define mkpr make_pair
#define fir first
#define sec secondusing namespace std;typedef pair<int, int> pii;
typedef unsigned long long ull;
typedef long long ll;
typedef long double ld;
typedef double db;const int maxn = 3e2 + 7;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;int n, m, a[maxn];int dp[maxn][maxn][maxn];
int g[maxn][maxn][maxn];
int main() {scanf("%d%d", &n, &m);for (int i = 1; i <= n; ++i) scanf("%d", a + i);sort(a + 1, a + n + 1);for (int i = 0; i <= m; ++i)dp[1][i][i] = 1, g[1][i][i] = (g[1][i][i - 1] + dp[1][i][i]) % mod;for (int i = 2; i <= n; ++i) {for (int j = 0; j <= m; ++j) {for (int k = 0; k <= j; ++k) {
// for (int p = 0; p <= min(j - k, k + a[i] - a[i - 1]); ++p)
// (dp[i][j][k] += dp[i - 1][j - k][p]) %= mod;dp[i][j][k] = g[i - 1][j - k][min(j - k, k + a[i] - a[i - 1])];g[i][j][k] = (g[i][j][k - 1] + dp[i][j][k]) % mod;}}}int ans = 0;for (int i = 0; i <= m; ++i)(ans += dp[n][m][i]) %= mod;printf("%d\n", ans);return 0;
}
T3 作弊
1.8pts1.8pts1.8pts 做法
直接爆搜将小朋友划分成若干个不相交的区间,然后统计能满足几个小朋友,取其中的 maxmaxmax。
代码:
namespace Sub1 {int b[maxn]; // 枚举每个区间的结束位置int ans;void dfs(int now, int start) {if (start == n + 1) {int L = 1, R = 0, res = 0;for (int i = 1; i <= now; ++i) {R = b[i]; int mx = 0;for (int j = L; j <= R; ++j) mx = max(mx, a[j]);for (int j = L; j <= R; ++j)if (max(a[j], mx) >= l[j] && max(a[j], mx) <= r[j])++res;L = b[i] + 1;}ans = max(res, ans);return ;}for (int i = start; i <= n; ++i) {b[now] = i;dfs(now + 1, i + 1);}}void Main() {dfs(1, 1);printf("%d\n", ans);exit(0);}
}
2.20pts2.20pts2.20pts 做法
首先可以看出这是一个 dpdpdp,说所以设 dpidp_idpi 表示前 iii 个同学中最多可以满足的个数。
那么就有转移:
dpi=maxj=0i−1{dpj+cnt(j+1,i)}
dp_i = max_{j = 0}^{i - 1} \left \{dp_j + cnt(j + 1, i) \right \}
dpi=maxj=0i−1{dpj+cnt(j+1,i)}
其中 cnt(l,r)cnt(l, r)cnt(l,r) 表示区间 [l,r][l, r][l,r] 内赋成最大值后有多少个同学可以被满足。
时间复杂度 O(n3)O(n^3)O(n3),代码如下:
namespace Sub2 {int dp[107];void Main() {for (int i = 1; i <= n; ++i) {for (int j = 0; j < i; ++j) {int mx = 0, val = 0;for (int k = j + 1; k <= i; ++k) mx = max(mx, a[k]);for (int k = j + 1; k <= i; ++k)if (mx >= l[k] && mx <= r[k]) ++val;dp[i] = max(dp[i], dp[j] + val);}}printf("%d\n", dp[n]);}
}
3.40pts3.40pts3.40pts 做法
20pts20pts20pts 做法的时间瓶颈在求区间最大值和统计区间内被满足的个数。
对于前者,我们可以用 STSTST 表维护;对于后者我们可以开 500050005000 颗树状数组,biti.ask(x)bit_i.ask(x)biti.ask(x) 统计当最大值为 iii 时,前 xxx 个小朋友有多少个可以被满足。
时间复杂度 O(n2log(n))O(n^2log(n))O(n2log(n))。(TMD 维护最大值用的线段树,常数太大没过)
代码:
#include <bits/stdc++.h>#define mkpr make_pair
#define fir first
#define sec secondusing namespace std;typedef pair<int, int> pii;
typedef unsigned long long ull;
typedef long long ll;
typedef long double ld;
typedef double db;const int maxn = 1e5 + 7;
const int inf = 0x3f3f3f3f;int n, a[maxn], L[maxn], R[maxn];int dp[maxn], rt;
int st[maxn][20], lg2[maxn];
void init() {for (int i = 2; i <= n; ++i) lg2[i] = lg2[i >> 1] + 1;for (int i = 1; i <= n; ++i) st[i][0] = a[i];for (int j = 1; j <= lg2[n]; ++j)for (int i = 1; i + (1 << j) - 1 <= n; ++i)st[i][j] = max(st[i][j - 1], st[i + (1 << j - 1)][j - 1]);
}
int ask(int l, int r) {int pw = lg2[r - l + 1];return max(st[l][pw], st[r - (1 << pw) + 1][pw]);
}
struct BIT {int s[5007];void mdf(int p, int x) {for (; p <= n; p += (p & -p)) s[p] += x;}int ask(int p) {int res = 0;for (; p; p -= (p & -p)) res += s[p];return res;}
} bit[5007];int main() {scanf("%d", &n);for (int i = 1; i <= n; ++i) scanf("%d", a + i); for (int i = 1; i <= n; ++i) scanf("%d%d", L + i, R + i);init();for (int i = 1; i <= n; ++i)for (int j = 1; j <= n; ++j)if (i >= L[j] && i <= R[j]) bit[i].mdf(j, 1);for (int i = 1; i <= n; ++i) {for (int j = 0; j < i; ++j) {int mx = ask(j + 1, i);dp[i] = max(dp[i], dp[j] + bit[mx].ask(i) - bit[mx].ask(j));}}printf("%d\n", dp[n]);return 0;
}
4.100pts4.100pts4.100pts 做法
发现现在唯一的时间瓶颈就是求 max{dpj−1+cst(j,i)}max \left \{ dp_{j - 1} + cst(j, i) \right\}max{dpj−1+cst(j,i)},这和基站选址有点像,可以用线段数优化。
对于 iii 而言,我们把条件 li≤max(l,...,i,...,r)≤ril_i \leq max(l, ..., i, ...,r) \leq r_ili≤max(l,...,i,...,r)≤ri 变为【max(l,...,i)max(l, ..., i)max(l,...,i) 和 max(i,...,r)max(i, ..., r)max(i,...,r) 都 ≤ri\leq r_i≤ri 且至少有一个 ≥li\geq l_i≥li】。我们可以用二分求出左右端点的取值范围分别记为 lli,lri,rli,rrill_i, lr_i, rl_i, rr_illi,lri,rli,rri。
现在分类讨论:
- 如果 lrilr_ilri 存在,那么对于所有的右端点 ∈[i,rri]\in [i, rr_i]∈[i,rri] 且左端点 ∈[lli,lri]\in [ll_i, lr_i]∈[lli,lri],iii 都有贡献;
- 如果 rlirl_irli 存在,那么对于所有右端点 ∈[rli,rri]\in [rl_i, rr_i]∈[rli,rri] 且左端点 ∈[lli,i]\in [ll_i,i]∈[lli,i],iii 都有贡献。
代码:
#include <bits/stdc++.h>#define il inline#define mkpr make_pair
#define fir first
#define sec secondusing namespace std;typedef pair<int, int> pii;
typedef unsigned long long ull;
typedef long double ld;
typedef double db;const int maxn = 1e5 + 7;
const int inf = 0x3f3f3f3f;int n, a[maxn], L[maxn], R[maxn];int dp[maxn], rt;
int st[maxn][20], lg2[maxn];
void init() {for (int i = 2; i <= n; ++i) lg2[i] = lg2[i >> 1] + 1;for (int i = 1; i <= n; ++i) st[i][0] = a[i];for (int j = 1; j <= lg2[n]; ++j)for (int i = 1; i + (1 << j) - 1 <= n; ++i)st[i][j] = max(st[i][j - 1], st[i + (1 << j - 1)][j - 1]);
}
int ask(int l, int r) {int pw = lg2[r - l + 1];return max(st[l][pw], st[r - (1 << pw) + 1][pw]);
}
int ll[maxn], lr[maxn], rl[maxn], rr[maxn];
void init2() {for (int i = 1, l, r; i <= n; ++i) {// 求 ll[i], 即当左端点在 [ll[i], i] 时最大值 <= R[i]l = 1, r = i;while (l <= r) {int mid = l + r >> 1;if (ask(mid, i) <= R[i])ll[i] = mid, r = mid - 1;else l = mid + 1;}// 求 lr[i], 即当左端点 <= lr[i] 时最大值 >= L[i]l = 1, r = i;while (l <= r) {int mid = l + r >> 1;if (ask(mid, i) >= L[i])lr[i] = mid, l = mid + 1;else r = mid - 1;}// 求 rl[i], 即当右端点 >= rl[i] 时最大值 >= L[i]l = i, r = n;while (l <= r) {int mid = l + r >> 1;if (ask(i, mid) >= L[i])rl[i] = mid, r = mid - 1;else l = mid + 1;}// 求 rr[i], 即当右端点在 [i, rr[i]] 时最大值 <= R[i]l = i, r = n;while (l <= r) {int mid = l + r >> 1;if (ask(i, mid) <= R[i])rr[i] = mid, l = mid + 1;else r = mid - 1;}}
}
#define ls (p << 1)
#define rs (p << 1 | 1)
#define mid (l + r >> 1)
struct SegmentTree {int mx[maxn << 2], tg[maxn << 2];il void pushup(int p) {mx[p] = max(mx[ls], mx[rs]);}il void pushdown(int p) {mx[ls] += tg[p], tg[ls] += tg[p];mx[rs] += tg[p], tg[rs] += tg[p];tg[p] = 0;}void mdf(int p, int l, int r, int ql, int qr, int v) {if (ql > qr || qr < l || r < ql) return ;if (ql <= l && r <= qr) {mx[p] += v, tg[p] += v;return ;}pushdown(p);mdf(ls, l, mid, ql, qr, v);mdf(rs, mid + 1, r, ql, qr, v);pushup(p);}int ask(int p, int l, int r, int ql, int qr) {if (ql > qr || qr < l || r < ql) return -inf;if (ql <= l && r <= qr) return mx[p];pushdown(p);return max(ask(ls, l, mid, ql, qr), ask(rs, mid + 1, r, ql, qr));}
} sgt;
struct node {int l, r, v;};
vector<node> v[maxn];
int main() {scanf("%d", &n);for (int i = 1; i <= n; ++i) scanf("%d", a + i);for (int i = 1; i <= n; ++i) scanf("%d%d", L + i, R + i);init(), init2();
// for (int i = 1; i <= n; ++i)
// printf("i:%d, ll:%d, lr:%d, rl:%d, rr:%d\n",
// i, ll[i], lr[i], rl[i], rr[i]);for (int i = 1; i <= n; ++i) {if (a[i] > R[i] || (!lr[i] && !rl[i])) continue;if (rl[i]) { // 可能存在 [i, n] 区间内的数的最大值都小于 L[i]v[rl[i]].push_back(node{ll[i], i, 1});v[rr[i] + 1].push_back(node{ll[i], i, -1});}if (lr[i]) { // 可能存在 [1, i] 区间内的数的最大值都小于 L[i]v[i].push_back(node{ll[i], lr[i], 1});v[rl[i]].push_back(node{ll[i], lr[i], -1});}}for (int i = 1; i <= n; ++i) {for (node p : v[i]) sgt.mdf(1, 1, n, p.l, p.r, p.v);dp[i] = sgt.ask(1, 1, n, 1, i);sgt.mdf(1, 1, n, i + 1, i + 1, dp[i]);
// printf("dp[%d]:%d\n", i, dp[i]);}printf("%d\n", dp[n]);return 0;
}