// C++ program: count the values <= x inside an index range of an array,
// with point updates, using a persistent segment tree.
#include <bits/stdc++.h>
using namespace std;
// Persistent Segment Tree
class PST {
public:
    // One tree node: `val` is the total count stored in the node's range;
    // `lc` / `rc` are the indices of its children in `t`.
    struct Node {
        int val;
        int lc;
        int rc;
        Node(int v = 0, int l = -1, int r = -1)
            : val(v), lc(l), rc(r)
        {
        }
    };

    // Node pool shared by every version. t[0] is the empty tree: count 0 with
    // both children pointing back to t[0], so descending into a missing
    // subtree is always safe. update() only writes to nodes it has just
    // created, so different versions can share subtrees.
    vector<Node> t;

    PST() { t.push_back(Node(0, 0, 0)); }

    // Returns the root of a new version equal to version `cur` with `diff`
    // added at position `idx`; the tree spans positions [l, r]. Version `cur`
    // itself is left unchanged. Pass cur = 0 to start from the empty tree.
    // Throws out_of_range if idx lies outside [l, r].
    int update(int cur, int l, int r, int idx, int diff)
    {
        if (idx < l || idx > r) {
            throw out_of_range("PST::update: idx outside [l, r]");
        }
        if (cur < 0) {
            cur = 0;  // a negative child index means "no subtree"
        }
        // Work on a copy: push_back below may reallocate `t`.
        Node node = t[cur];
        node.val += diff;
        const int new_node = static_cast<int>(t.size());
        t.push_back(node);
        if (l == r) {
            return new_node;
        }
        const int mid = mid_of(l, r);
        // Obtain the child index before indexing `t` again: the recursive
        // call appends to `t` and may move its storage.
        if (idx <= mid) {
            const int child = update(node.lc, l, mid, idx, diff);
            t[new_node].lc = child;
        }
        else {
            const int child = update(node.rc, mid + 1, r, idx, diff);
            t[new_node].rc = child;
        }
        return new_node;
    }

    // Returns the sum of counts at positions [ql, qr] in the version rooted
    // at `cur`, whose tree spans [l, r]. Positions outside [l, r] and empty
    // query ranges (ql > qr) contribute 0.
    int query(int cur, int l, int r, int ql, int qr)
    {
        if (cur <= 0 || ql > qr || qr < l || r < ql) {
            return 0;  // empty subtree, or no overlap with the query
        }
        if (ql <= l && r <= qr) {
            return t[cur].val;
        }
        const int mid = mid_of(l, r);
        return query(t[cur].lc, l, mid, ql, qr)
            + query(t[cur].rc, mid + 1, r, ql, qr);
    }

private:
    // floor((l + r) / 2) for l <= r, computed without int overflow so that a
    // tree may span the full int range.
    static int mid_of(int l, int r)
    {
        return l + static_cast<int>((static_cast<long long>(r) - l) / 2);
    }
};
// Range counting with point updates: counts indices i in a range with
// a[i] <= x. root[i] is a PST version over the whole int value domain that
// holds the multiset {a[0], ..., a[i]}; a range count is the difference of
// two prefix versions.
class MST {
public:
    vector<int> a;
    PST pst;
    vector<int> root;
    int n;

    // Builds one prefix version per element; an empty `v` yields no versions.
    MST(vector<int> v)
        : a(std::move(v)), n(static_cast<int>(a.size()))
    {
        root.reserve(a.size());
        int version = 0;  // start from the empty tree
        for (int i = 0; i < n; i++) {
            version = pst.update(version, kValMin, kValMax, a[i], 1);
            root.push_back(version);
        }
    }

    // Returns a PST root holding the multiset {a[l], ..., a[r]}, built by
    // merging the two halves recursively. Parts of [l, r] outside [0, n - 1]
    // are ignored; an empty range yields 0, the empty tree.
    int build(int l, int r)
    {
        l = max(l, 0);
        r = min(r, n - 1);
        if (l > r) {
            return 0;
        }
        if (l == r) {
            return pst.update(0, kValMin, kValMax, a[l], 1);
        }
        const int mid = l + (r - l) / 2;
        const int left = build(l, mid);
        const int right = build(mid + 1, r);
        return merge(left, right, l, r);
    }

    // Returns a PST root whose counts are the position-wise sum of the trees
    // rooted at `left` and `right`; both must span this class's value domain.
    // The inputs stay valid. The result may share nodes with them, which is
    // safe because PST::update never writes to a node it did not just create.
    // The position bounds are not needed; they are kept for the existing
    // signature.
    int merge(int left, int right, int /*l*/, int /*r*/)
    {
        return merge_nodes(left, right);
    }

    // Counts indices i in [ql, qr], intersected with [l, r] and [0, cur],
    // such that a[i] <= x. With cur = n - 1 and [l, r] = [0, n - 1] this is a
    // plain count over [ql, qr].
    int query(int cur, int l, int r, int ql, int qr, int x)
    {
        const int lo = max({ ql, l, 0 });
        const int hi = min({ qr, r, cur, n - 1 });
        if (lo > hi) {
            return 0;
        }
        int res = count_le(root[hi], x);
        if (lo > 0) {
            res -= count_le(root[lo - 1], x);
        }
        return res;
    }

    // Sets a[idx] = new_val and refreshes every prefix version containing
    // index idx (root[idx .. n-1]) by moving one count from the old value to
    // the new one. Superseded PST nodes are kept in the pool, not freed.
    // Throws out_of_range if idx is not in [0, n - 1].
    void update(int idx, int new_val)
    {
        if (idx < 0 || idx >= n) {
            throw out_of_range("MST::update: index out of range");
        }
        const int old_val = a[idx];
        if (old_val == new_val) {
            return;
        }
        a[idx] = new_val;
        for (int i = idx; i < n; i++) {
            root[i] = pst.update(root[i], kValMin, kValMax, old_val, -1);
            root[i] = pst.update(root[i], kValMin, kValMax, new_val, 1);
        }
    }

private:
    // Value domain of every PST version: the whole int range.
    static constexpr int kValMin = numeric_limits<int>::min();
    static constexpr int kValMax = numeric_limits<int>::max();

    // Number of stored values <= x in the PST version rooted at `version`.
    int count_le(int version, int x)
    {
        return pst.query(version, kValMin, kValMax, kValMin, x);
    }

    // Recursive worker for merge(): index <= 0 is treated as an empty tree.
    int merge_nodes(int x, int y)
    {
        if (x <= 0) {
            return y;
        }
        if (y <= 0) {
            return x;
        }
        const int lc = merge_nodes(pst.t[x].lc, pst.t[y].lc);
        const int rc = merge_nodes(pst.t[x].rc, pst.t[y].rc);
        const int val = pst.t[x].val + pst.t[y].val;
        pst.t.push_back(PST::Node(val, lc, rc));
        return static_cast<int>(pst.t.size()) - 1;
    }
};
// Driver code
int main()
{
    vector<int> v = { 1, 4, 2, 3, 5 };
    MST mst(v);
    // Index of the last element, converted explicitly from size_t.
    const int last = static_cast<int>(v.size()) - 1;

    // Count values <= 3 at indices [1, 3] ({4, 2, 3}); expected 2.
    cout << mst.query(last, 0, last, 1, 3, 3) << '\n';

    // Point update: a[2] = 5.
    mst.update(2, 5);

    // Same query, now over {4, 5, 3}; expected 1.
    cout << mst.query(last, 0, last, 1, 3, 3) << '\n';
    return 0;
}