1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
#include <algorithm>
#include <array>
#include <cctype>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;
using ll = long long;

const ll mod = 998244353;
const int N = 1e6 + 10;  // capacity of the global arrays; indices 0..n+1 are used

// Buffered stdin/stdout helper built on fread/fwrite.
struct IO {
    static const int N = 1 << 20 | 1;
    char ibuf[N], *p, *q;
    char obuf[N], *o;
    int stk[50], t;

    IO() : p(ibuf), q(ibuf), o(obuf) {}

    // Returns the next input byte as an unsigned char value, or EOF.
    // Returning int (not char) keeps EOF distinguishable and makes the
    // value safe to pass to isdigit (negative chars are UB there).
    int gc() {
        if (p == q) {
            p = ibuf;
            q = ibuf + fread(ibuf, 1, sizeof ibuf, stdin);
        }
        return p == q ? EOF : static_cast<unsigned char>(*p++);
    }

    // Reads the next (optionally negative) integer. A '-' counts only when
    // it immediately precedes the digits. Stops cleanly at EOF instead of
    // spinning forever; x is left as 0 if no digits remain.
    template <typename T>
    IO &operator>>(T &x) {
        x = 0;
        bool neg = false;
        int c = gc();
        while (c != EOF && !isdigit(c)) {
            neg = (c == '-');
            c = gc();
        }
        while (c != EOF && isdigit(c)) {
            x = x * 10 + (c - '0');
            c = gc();
        }
        if (neg) x = -x;
        return *this;
    }

    void flush() {
        fwrite(obuf, 1, o - obuf, stdout);
        o = obuf;
    }

    void wc(char x) {
        *o++ = x;
        if (o - obuf >= N) flush();
    }

    IO &operator<<(char x) {
        wc(x);
        return *this;
    }

    // Writes an integer in decimal. Digits are taken as absolute values of
    // x % 10 so negative inputs (including the type's minimum) print correctly.
    template <typename T>
    IO &operator<<(T x) {
        if (x < 0) wc('-');
        t = 0;
        do {
            int d = static_cast<int>(x % 10);
            stk[++t] = d < 0 ? -d : d;
            x /= 10;
        } while (x);
        while (t) wc(static_cast<char>(stk[t--] + '0'));
        return *this;
    }

    ~IO() { flush(); }
} io;

int n, k;
int a[N], pos[N];   // a: 1-based values; pos: positions in processing order
int pre[N], nxt[N]; // doubly linked list over positions; sentinels 0 and n+1

// Processes positions in the order given by pos[1..n] and, for each position
// x, sums a[x] * (number of subarrays in which x is the k-th element, counting
// only positions that have not been removed yet, with x itself the leftmost-
// processed of those k). Each processed position is unlinked afterwards, so
// when x is handled the list holds x plus every position processed after it.
//
// With pos sorted by ascending value, this makes a[x] the k-th largest value
// of each counted subarray (ties broken by processing order). With pos
// reversed, it is the k-th smallest. Returns the sum modulo `mod`, in [0, mod).
// Cost is O(n * k) list steps per call.
ll solve() {
    ll res = 0;
    for (int i = 1; i <= n; i++) {
        pre[i] = i - 1;
        nxt[i] = i + 1;
    }
    for (int i = 1; i <= n; i++) {
        const int x = pos[i];
        // Normalize so negative inputs still yield a residue in [0, mod).
        const ll val = ((a[x] % mod) + mod) % mod;

        // Build the leftmost window of k live positions containing x:
        // extend left as far as possible (up to k-1 steps), then right.
        int l = x, r = x, cnt = 1;
        while (cnt < k && pre[l] != 0) {
            l = pre[l];
            cnt++;
        }
        while (cnt < k && nxt[r] != n + 1) {
            r = nxt[r];
            cnt++;
        }

        if (cnt == k) {
            // Slide the k-element window right while it still contains x.
            // Subarray left ends lie in (pre[l], l], right ends in [r, nxt[r]).
            while (l <= x && r != n + 1) {
                const ll lefts = l - pre[l];
                const ll rights = nxt[r] - r;
                res = (res + val * rights % mod * lefts) % mod;
                l = nxt[l];
                r = nxt[r];
            }
        }

        // Unlink x (sentinel slots 0 and n+1 absorb boundary writes).
        nxt[pre[x]] = nxt[x];
        pre[nxt[x]] = pre[x];
    }
    return res;
}

// Reads n, k and a[1..n] from stdin. Prints, modulo 998244353, the sum over
// every subarray with at least k elements of (k-th largest + k-th smallest).
int main() {
    io >> n >> k;
    for (int i = 1; i <= n; i++) io >> a[i];

    iota(pos + 1, pos + 1 + n, 1);
    sort(pos + 1, pos + 1 + n, [](int A, int B) { return a[A] < a[B]; });

    ll sum = solve();              // ascending order: k-th largest values
    reverse(pos + 1, pos + 1 + n);
    sum = (sum + solve()) % mod;   // descending order: k-th smallest values

    io << sum << '\n';
    io.flush();
    return 0;
}
|