100 lines
2.3 KiB
Java
100 lines
2.3 KiB
Java
|
/**
 * Segment tree over an int array {@code T[0..n-1]} (initially all zero) supporting point
 * updates, range sums, and a "lower bound" search: the first position {@code p} at which the
 * running sum {@code T[a] + ... + T[p]} reaches a target.
 *
 * <p>{@link #lower_bound} is only correct while every element of {@code T} is non-negative,
 * because it relies on running sums being non-decreasing. All sums use {@code int} arithmetic
 * and can overflow for large values.
 */
public class SegmentTreeSumLowerBound {

    // s[node] holds the sum of T over the segment covered by node. The root is node 1 and the
    // children of node k are 2k and 2k+1; 4*n slots are enough for this recursive layout.
    int[] s;

    // Number of elements in T.
    int n;

    /**
     * Creates a tree over {@code n} elements, all set to zero.
     *
     * @param n number of elements; may be zero
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public SegmentTreeSumLowerBound(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        this.n = n;
        // Keep at least the root slot (index 1) even when n == 0 so queries can read it.
        s = new int[4 * Math.max(n, 1)];
        buildTree(1, 0, n - 1);
    }

    // Recomputes the internal sums of the subtree rooted at node, which covers T[left..right],
    // from its leaves. In the constructor every slot is still zero, so this only writes zeros.
    void buildTree(int node, int left, int right) {
        // left == right is a leaf; left > right is the empty range of an n == 0 tree and must
        // not be split further (splitting it never shrinks it, so it would recurse forever).
        if (left >= right) {
            return;
        }
        int mid = (left + right) >>> 1;
        buildTree(node * 2, left, mid);
        buildTree(node * 2 + 1, mid + 1, right);
        s[node] = s[node * 2] + s[node * 2 + 1];
    }

    /**
     * Adds {@code value} to {@code T[i]}.
     *
     * <p>{@code value} may be negative (as used by {@link #set}), but {@link #lower_bound}
     * stays correct only if every element remains non-negative afterwards.
     *
     * @param i index into T
     * @param value amount to add
     * @throws IndexOutOfBoundsException if {@code i} is outside {@code [0, n)}
     */
    public void add(int i, int value) {
        checkIndex(i);
        add(i, value, 1, 0, n - 1);
    }

    // Descends to the leaf for index i inside node's segment T[left..right], adds value there,
    // and recomputes the sums on the way back up.
    void add(int i, int value, int node, int left, int right) {
        if (left == right) {
            s[node] += value;
            return;
        }
        int mid = (left + right) >>> 1;
        if (i <= mid) {
            add(i, value, node * 2, left, mid);
        } else {
            add(i, value, node * 2 + 1, mid + 1, right);
        }
        s[node] = s[node * 2] + s[node * 2 + 1];
    }

    /**
     * Returns the smallest {@code p <= b} such that {@code T[a] + ... + T[p] >= sum}.
     *
     * <p>Requires every element of T to be non-negative. The range {@code [a, b]} is intersected
     * with {@code [0, n)}.
     *
     * @param a first index of the range
     * @param b last index of the range (inclusive)
     * @param sum target running sum
     * @return {@code p} if it exists; otherwise {@code ~total}, where {@code total} is the sum of
     *     T over {@code [a, b]}. For non-negative elements that value is negative, so callers can
     *     distinguish the two cases with {@code result >= 0}.
     */
    public int lower_bound(int a, int b, int sum) {
        if (n == 0) {
            return ~0; // empty tree: nothing found, total is 0
        }
        return lower_bound(a, b, sum, 1, 0, n - 1);
    }

    // Searches node's segment T[left..right], restricted to [a, b]. Returns the answer index
    // (>= 0) if the target is reached inside this segment, or ~(sum of the restricted part)
    // otherwise, so the caller can subtract that part before searching the right sibling.
    int lower_bound(int a, int b, int sum, int node, int left, int right) {
        if (left > b || right < a) {
            return ~0; // disjoint: contributes nothing
        }
        if (left >= a && right <= b && s[node] < sum) {
            return ~s[node]; // fully inside but too small: skip it wholesale
        }
        if (left == right) {
            return left; // a single in-range element that reaches the target
        }
        int mid = (left + right) >>> 1;
        int res1 = lower_bound(a, b, sum, node * 2, left, mid);
        if (res1 >= 0) {
            return res1;
        }
        // ~res1 is the sum consumed by the left child; only the remainder is needed on the right.
        int res2 = lower_bound(a, b, sum - ~res1, node * 2 + 1, mid + 1, right);
        if (res2 >= 0) {
            return res2;
        }
        return ~(~res1 + ~res2);
    }

    /**
     * Returns {@code T[a] + ... + T[b]}, restricted to indices in {@code [0, n)}.
     *
     * @param a first index (inclusive)
     * @param b last index (inclusive)
     * @return the range sum, or 0 if the range is empty ({@code a > b}) or outside the array
     */
    public int sum(int a, int b) {
        return sum(a, b, 1, 0, n - 1);
    }

    // Sum of T over the intersection of [a, b] with node's segment T[left..right].
    int sum(int a, int b, int node, int left, int right) {
        // The disjoint check guarantees termination for every (a, b), including a > b: such a
        // range can never fully contain a leaf, so every leaf it reaches is disjoint.
        if (left > b || right < a) {
            return 0;
        }
        if (left >= a && right <= b) {
            return s[node];
        }
        int mid = (left + right) >>> 1;
        return sum(a, b, node * 2, left, mid) + sum(a, b, node * 2 + 1, mid + 1, right);
    }

    /**
     * Returns {@code T[i]}.
     *
     * @param i index into T
     * @return the current value at {@code i}
     * @throws IndexOutOfBoundsException if {@code i} is outside {@code [0, n)}
     */
    public int get(int i) {
        checkIndex(i);
        return sum(i, i);
    }

    /**
     * Sets {@code T[i] = value}.
     *
     * @param i index into T
     * @param value new value; should be non-negative for {@link #lower_bound} to stay correct
     * @throws IndexOutOfBoundsException if {@code i} is outside {@code [0, n)}
     */
    public void set(int i, int value) {
        add(i, value - get(i));
    }

    // Rejects positions outside [0, n). Without this check, the descent in add() would silently
    // update T[0] or T[n-1] instead of failing.
    private void checkIndex(int i) {
        if (i < 0 || i >= n) {
            throw new IndexOutOfBoundsException("index " + i + " out of range [0, " + n + ")");
        }
    }

    // Usage example
    public static void main(String[] args) {
        SegmentTreeSumLowerBound t = new SegmentTreeSumLowerBound(4);
        t.set(0, 1);
        t.set(1, 5);
        t.set(2, 2);
        t.set(3, 3);
        System.out.println(1 == t.lower_bound(1, 3, 5));
        t.set(1, 3);
        System.out.println(3 == t.get(1));
        System.out.println(2 == t.lower_bound(1, 3, 5));
    }
}
|