这是一道经典的树状数组题目,同时也是莫队、线段树算法很好的练习题。
题目描述
HH 有一串由各种漂亮的贝壳组成的项链。HH 相信不同的贝壳会带来好运,所以每次散步完后,他都会随意取出一段贝壳,思考它们所表达的含义。HH 不断地收集新的贝壳,因此,他的项链变得越来越长。
有一天,他突然提出了一个问题:某一段贝壳中,包含了多少种不同的贝壳?这个问题很难回答…… 因为项链实在是太长了。于是,他只好求助睿智的你,来解决这个问题。
输入格式
一行一个正整数 n,表示项链长度。
第二行 n 个正整数 a_i,表示项链中第 i 个贝壳的种类。
第三行一个整数 m,表示 HH 询问的个数。
接下来 m 行,每行两个整数 l,r,表示询问的区间。
输出格式
输出 m 行,每行一个整数,依次表示询问对应的答案。
样例 #1
样例输入 #1
6
1 2 3 4 3 5
3
1 2
3 5
2 6
样例输出 #1
2
2
4
提示
【数据范围】
1\le n,m,a_i \leq 10^6,1\le l \le r \le n。
本题可能需要较快的读入方式,最大数据点读入数据约 20MB
题解
这道题需要维护一个区间中出现数字的数量,因为数字可能在任何地方出现,所以不可能一次维护好,然后 O(log_n)读取,所以我们要将数据离线下来。
将数据离线下来肯定还不够,既然都离线了,那么进行排序肯定也是可以的,因为排序之后,对树状数组进行添加或者删除都可以在 \mathcal{O}(nlog_n) 的总时间完成。我们按照询问的右值排序,依次处理。
排完序之后,我们要求出这个答案,不妨说是在重叠情况下最靠右的答案合集。什么意思呢,在求以 r 为右端点的区间时,我只需要处理每一类最靠近 r 的数字,这样 l 不管是多大,答案都保证是正确的。换一种说法,就是靠 r 比较远的相同的数字可以直接忽略。
如何忽略这些数字,现在来考虑树状数组记录的状态。我们让树状数组记录 x 的种类个数的前缀和,查询就是 query(r) – query(l – 1)。然后用一个 pre 数组保存上一个 a_i 的下标,最后只需要通过这个 pre 数组就可以完成下标的处理,也就是自动更新到最右的取值。每次往右看的时候,在加上一个种类的时候总要把 pre[i] 减掉。
今天头有点晕,题解如果写的不好请看代码
Raw Code 原生代码(bushi
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define maxd(x, y) ((x) > (y) ? (x) : (y))
#define mind(x, y) ((x) < (y) ? (x) : (y))
#define abs_sub(x, y) ((x) > (y) ? (x) - (y) : (y) - (x))
const double eps = 1e-8;
const double pi = acos(-1.0);
const int N = 1e6 + 10;
inline int read() {
char ch = getchar(); int p = 1, x = 0;
while(!isdigit(ch)) {if(ch == '-') {p = -1;} ch = getchar();}
while(isdigit(ch)) {x = x * 10 + ch - '0'; ch = getchar();}
return x;}
int n, a, m, tree[N], pre[N], now[N], ans[N];
inline int lowbit(int x) {return x & -x;}
inline void update(int x, int d) {
if(x == 0) return;
while(x <= n) {
tree[x] += d;
x += lowbit(x);
}
}
inline int query(int x) {
int ans = 0;
while(x > 0) {
ans += tree[x];
x -= lowbit(x);
}
return ans;
}
struct Qu {
int l, r, pos;
}que[N];
inline bool cmp(Qu x, Qu y) {
return x.r < y.r;
}
int main()
{
// freopen("P1972.in", "r", stdin);
// freopen(".out", "w", stdout);
n = read();
for(int i = 1; i <= n; i++) {
a = read();
pre[i] = now[a];
now[a] = i;
}
// for(int i = 1; i <= n; i++) {
// printf("%d ", pre[i]);
// }
m = read();
for(int i = 1; i <= m; i++) {que[i].l = read(); que[i].r =read(); que[i].pos = i;}
sort(que + 1, que + 1 + m, cmp);
//其实区间只要按照右值排序就可以了,当我右值更新的时候,记录上一个右值,遍历到现在的右值把前面的减掉把新的加上
int prer = 0;
for(int i = 1; i <= m; i++) {
while(prer < que[i].r) {
update(++prer, 1);
if(pre[prer]) update(pre[prer], -1);//减去上一个
}
// for(int i = 1; i <= n; i++) {
// printf("%d ", query(i));
// }
// puts("");
ans[que[i].pos] = query(que[i].r) - query(que[i].l - 1);
}
for(int i = 1; i <= m; i++) {
printf("%d\n", ans[i]);
}
return 0;
}