题目链接

这是一道经典的树状数组题目,同时也是莫队、线段树算法很好的练习题。

题目描述

HH 有一串由各种漂亮的贝壳组成的项链。HH 相信不同的贝壳会带来好运,所以每次散步完后,他都会随意取出一段贝壳,思考它们所表达的含义。HH 不断地收集新的贝壳,因此,他的项链变得越来越长。

有一天,他突然提出了一个问题:某一段贝壳中,包含了多少种不同的贝壳?这个问题很难回答…… 因为项链实在是太长了。于是,他只好求助睿智的你,来解决这个问题。

输入格式

一行一个正整数 n,表示项链长度。
第二行 n 个正整数 a_i,表示项链中第 i 个贝壳的种类。

第三行一个整数 m,表示 HH 询问的个数。
接下来 m 行,每行两个整数 l,r,表示询问的区间。

输出格式

输出 m 行,每行一个整数,依次表示询问对应的答案。

样例 #1

样例输入 #1
6
1 2 3 4 3 5
3
1 2
3 5
2 6
样例输出 #1
2
2
4

提示

【数据范围】

1\le n,m,a_i \leq 10^61\le l \le r \le n

本题可能需要较快的读入方式,最大数据点读入数据约 20MB

题解

这道题需要维护一个区间中出现数字的数量,因为数字可能在任何地方出现,所以不可能一次维护好,然后 O(log_n)读取,所以我们要将数据离线下来。

将数据离线下来肯定还不够,既然都离线了,那么进行排序肯定也是可以的,因为排序之后,对树状数组进行添加或者删除都可以在 \mathcal{O}(nlog_n) 的总时间完成。我们按照询问的右值排序,依次处理。

排完序之后,我们要求出这个答案,不妨说是在重叠情况下最靠右的答案合集。什么意思呢,在求以 r 为右端点的区间时,我只需要处理每一类最靠近 r 的数字,这样 l 不管是多大,答案都保证是正确的。换一种说法,就是靠 r 比较远的相同的数字可以直接忽略

如何忽略这些数字,现在来考虑树状数组记录的状态。我们让树状数组记录 x 的种类个数的前缀和,查询就是 query(r) – query(l – 1)。然后用一个 pre 数组保存上一个 a_i 的下标,最后只需要通过这个 pre 数组就可以完成下标的处理,也就是自动更新到最右的取值。每次往右看的时候,在加上一个种类的时候总要把 pre[i] 减掉。

今天头有点晕,题解如果写的不好请看代码

Raw Code 原生代码(bushi

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define maxd(x, y) ((x) > (y) ? (x) : (y))
#define mind(x, y) ((x) < (y) ? (x) : (y))
#define abs_sub(x, y) ((x) > (y) ? (x) - (y) : (y) - (x))
const double eps = 1e-8;
const double pi = acos(-1.0);
const int N = 1e6 + 10;
inline int read() {
  char ch = getchar(); int p = 1, x = 0;
  while(!isdigit(ch)) {if(ch == '-') {p = -1;} ch = getchar();}
  while(isdigit(ch)) {x = x * 10 + ch - '0'; ch = getchar();}
  return x;}
int n, a, m, tree[N], pre[N], now[N], ans[N];
inline int lowbit(int x) {return x & -x;}
inline void update(int x, int d) {
  if(x == 0) return;
  while(x <= n) {
    tree[x] += d;
    x += lowbit(x);
  }
}
inline int query(int x) {
  int ans = 0;
  while(x > 0) {
    ans += tree[x];
    x -= lowbit(x);
  }
  return ans;
}
struct Qu {
  int l, r, pos;
}que[N];
inline bool cmp(Qu x, Qu y) {
  return x.r < y.r;
}
int main()
{
  // freopen("P1972.in", "r", stdin);
  // freopen(".out", "w", stdout);
  n = read();
  for(int i = 1; i <= n; i++) {
    a = read();
    pre[i] = now[a];
    now[a] = i;
  }
  // for(int i = 1; i <= n; i++) {
  //   printf("%d ", pre[i]);
  // }
  m = read();
  for(int i = 1; i <= m; i++) {que[i].l = read(); que[i].r =read(); que[i].pos = i;}
  sort(que + 1, que + 1 + m, cmp);
  //其实区间只要按照右值排序就可以了,当我右值更新的时候,记录上一个右值,遍历到现在的右值把前面的减掉把新的加上
  int prer = 0;
  for(int i = 1; i <= m; i++) {
    while(prer < que[i].r) {
      update(++prer, 1);
      if(pre[prer]) update(pre[prer], -1);//减去上一个
    }
    // for(int i = 1; i <= n; i++) {
    //   printf("%d ", query(i));
    // }
    // puts("");
    ans[que[i].pos] = query(que[i].r) - query(que[i].l - 1);
  }
  for(int i = 1; i <= m; i++) {
    printf("%d\n", ans[i]);
  }
  return 0;
}