programming-examples/java/Data_Structures/SegmentTreeSumLowerBound.java
2019-11-15 12:59:38 +01:00

100 lines
2.3 KiB
Java

/**
 * Segment tree over {@code n} int slots (all initially 0) that supports point updates, range sums
 * and a prefix-sum search inside a range ({@link #lower_bound(int, int, int)}).
 *
 * <p>Element indices are 0-based. Sums are accumulated in {@code int} and are not checked for
 * overflow. {@code lower_bound} encodes "not found" as the bitwise complement of a partial sum,
 * so it is only meaningful while every stored element is non-negative.
 */
public class SegmentTreeSumLowerBound {
    /** s[node] is the sum of the elements covered by node; node 1 is the root, children are 2k and 2k+1. */
    int[] s;
    /** Number of elements; valid element indices are 0..n-1. */
    int n;

    /**
     * Creates a tree of {@code n} elements, all zero.
     *
     * @param n number of elements, must be positive and small enough that {@code 4 * n} fits in an int
     * @throws IllegalArgumentException if {@code n} is not in that range
     */
    public SegmentTreeSumLowerBound(int n) {
        // n == 0 used to recurse forever in buildTree; huge n overflowed the 4 * n array size.
        if (n <= 0 || n > Integer.MAX_VALUE / 4) {
            throw new IllegalArgumentException("n must be in [1, " + (Integer.MAX_VALUE / 4) + "]: " + n);
        }
        this.n = n;
        s = new int[4 * n];
        buildTree(1, 0, n - 1);
    }

    /**
     * Recomputes the sums of all inner nodes below {@code node} from their children.
     * Leaves are left untouched.
     */
    void buildTree(int node, int left, int right) {
        if (left >= right) {
            // A leaf (or an empty range) has nothing to aggregate.
            return;
        }
        int mid = (left + right) >> 1;
        buildTree(node * 2, left, mid);
        buildTree(node * 2 + 1, mid + 1, right);
        s[node] = s[node * 2] + s[node * 2 + 1];
    }

    /**
     * T[i] += value. The resulting element must stay non-negative, otherwise
     * {@link #lower_bound(int, int, int)} will not work.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, n-1]}
     */
    public void add(int i, int value) {
        // Without this check the descent below would end at some unrelated leaf and corrupt it.
        checkIndex(i);
        add(i, value, 1, 0, n - 1);
    }

    void add(int i, int value, int node, int left, int right) {
        if (left == right) {
            s[node] += value;
            return;
        }
        int mid = (left + right) >> 1;
        if (i <= mid) {
            add(i, value, node * 2, left, mid);
        } else {
            add(i, value, node * 2 + 1, mid + 1, right);
        }
        s[node] = s[node * 2] + s[node * 2 + 1];
    }

    /**
     * Returns min(p | a <= p <= b && sum[a..p] >= sum). If no such p exists, returns ~sum[a..b],
     * which is negative as long as all elements are non-negative; apply {@code ~} again to read
     * the total. Parts of {@code [a, b]} outside {@code [0, n-1]} are treated as empty.
     */
    public int lower_bound(int a, int b, int sum) {
        return lower_bound(a, b, sum, 1, 0, n - 1);
    }

    int lower_bound(int a, int b, int sum, int node, int left, int right) {
        if (left > b || right < a) {
            // Disjoint from the query: contributes a partial sum of 0.
            return ~0;
        }
        if (left >= a && right <= b && s[node] < sum) {
            // Whole node is inside the query but too small: report its sum, complemented.
            return ~s[node];
        }
        if (left == right) {
            return left;
        }
        int mid = (left + right) >> 1;
        int res1 = lower_bound(a, b, sum, node * 2, left, mid);
        if (res1 >= 0) {
            return res1;
        }
        // The left half consumed ~res1 of the target; search the right half for the remainder.
        int res2 = lower_bound(a, b, sum - ~res1, node * 2 + 1, mid + 1, right);
        if (res2 >= 0) {
            return res2;
        }
        return ~(~res1 + ~res2);
    }

    /**
     * Returns sum[a..b]. Parts of {@code [a, b]} outside {@code [0, n-1]} contribute nothing,
     * and an empty range ({@code a > b}) yields 0.
     */
    public int sum(int a, int b) {
        return sum(a, b, 1, 0, n - 1);
    }

    int sum(int a, int b, int node, int left, int right) {
        if (left > b || right < a) {
            // Disjoint from the query. This guard also guarantees termination at every leaf,
            // which out-of-range queries previously violated (unbounded recursion).
            return 0;
        }
        if (left >= a && right <= b) {
            return s[node];
        }
        int mid = (left + right) >> 1;
        return sum(a, b, node * 2, left, mid) + sum(a, b, node * 2 + 1, mid + 1, right);
    }

    /**
     * Returns T[i].
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, n-1]}
     */
    public int get(int i) {
        checkIndex(i);
        return sum(i, i);
    }

    /**
     * T[i] = value. {@code value} should be non-negative for
     * {@link #lower_bound(int, int, int)} to keep working.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, n-1]}
     */
    public void set(int i, int value) {
        add(i, -get(i) + value);
    }

    /** Rejects element indices outside [0, n-1]. */
    private void checkIndex(int i) {
        if (i < 0 || i >= n) {
            throw new IndexOutOfBoundsException("index " + i + " out of range [0, " + (n - 1) + "]");
        }
    }

    // Usage example
    public static void main(String[] args) {
        SegmentTreeSumLowerBound t = new SegmentTreeSumLowerBound(4);
        t.set(0, 1);
        t.set(1, 5);
        t.set(2, 2);
        t.set(3, 3);
        System.out.println(1 == t.lower_bound(1, 3, 5));
        t.set(1, 3);
        System.out.println(3 == t.get(1));
        System.out.println(2 == t.lower_bound(1, 3, 5));
    }
}