programming-examples/java/Data_Structures/IntervalTree.java
2019-11-15 12:59:38 +01:00

475 lines
14 KiB
Java

package com.jwetherell.algorithms.data_structures;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
* An interval tree is an ordered tree data structure to hold intervals.
* Specifically, it allows one to efficiently find all intervals that overlap
* with any given interval or point.
*
* http://en.wikipedia.org/wiki/Interval_tree
*
* @author Justin Wetherell <phishman3579@gmail.com>
*/
public class IntervalTree<O extends Object> {
private Interval<O> root = null;
private static final Comparator<IntervalData<?>> START_COMPARATOR = new Comparator<IntervalData<?>>() {
/**
* {@inheritDoc}
*/
@Override
public int compare(IntervalData<?> arg0, IntervalData<?> arg1) {
// Compare start first
if (arg0.start < arg1.start)
return -1;
if (arg1.start < arg0.start)
return 1;
return 0;
}
};
private static final Comparator<IntervalData<?>> END_COMPARATOR = new Comparator<IntervalData<?>>() {
/**
* {@inheritDoc}
*/
@Override
public int compare(IntervalData<?> arg0, IntervalData<?> arg1) {
// Compare end first
if (arg0.end > arg1.end)
return -1;
if (arg1.end > arg0.end)
return 1;
return 0;
}
};
/**
* Create interval tree from list of IntervalData objects;
*
* @param intervals
* is a list of IntervalData objects
*/
public IntervalTree(List<IntervalData<O>> intervals) {
if (intervals.size() <= 0)
return;
root = createFromList(intervals);
}
protected static final <O extends Object> Interval<O> createFromList(List<IntervalData<O>> intervals) {
Interval<O> newInterval = new Interval<O>();
if (intervals.size()==1) {
IntervalData<O> middle = intervals.get(0);
newInterval.center = ((middle.start + middle.end) / 2);
newInterval.add(middle);
return newInterval;
}
int half = intervals.size() / 2;
IntervalData<O> middle = intervals.get(half);
newInterval.center = ((middle.start + middle.end) / 2);
List<IntervalData<O>> leftIntervals = new ArrayList<IntervalData<O>>();
List<IntervalData<O>> rightIntervals = new ArrayList<IntervalData<O>>();
for (IntervalData<O> interval : intervals) {
if (interval.end < newInterval.center) {
leftIntervals.add(interval);
} else if (interval.start > newInterval.center) {
rightIntervals.add(interval);
} else {
newInterval.add(interval);
}
}
if (leftIntervals.size() > 0)
newInterval.left = createFromList(leftIntervals);
if (rightIntervals.size() > 0)
newInterval.right = createFromList(rightIntervals);
return newInterval;
}
/**
* Stabbing query
*
* @param index
* to query for.
* @return data at index.
*/
public IntervalData<O> query(long index) {
return root.query(index);
}
/**
* Range query
*
* @param start
* of range to query for.
* @param end
* of range to query for.
* @return data for range.
*/
public IntervalData<O> query(long start, long end) {
return root.query(start, end);
}
/**
* {@inheritDoc}
*/
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append(IntervalTreePrinter.getString(this));
return builder.toString();
}
protected static class IntervalTreePrinter {
public static <O extends Object> String getString(IntervalTree<O> tree) {
if (tree.root == null)
return "Tree has no nodes.";
return getString(tree.root, "", true);
}
private static <O extends Object> String getString(Interval<O> interval, String prefix, boolean isTail) {
StringBuilder builder = new StringBuilder();
builder.append(prefix + (isTail ? "└── " : "├── ") + interval.toString() + "\n");
List<Interval<O>> children = new ArrayList<Interval<O>>();
if (interval.left != null)
children.add(interval.left);
if (interval.right != null)
children.add(interval.right);
if (children.size() > 0) {
for (int i = 0; i < children.size() - 1; i++)
builder.append(getString(children.get(i), prefix + (isTail ? " " : ""), false));
if (children.size() > 0)
builder.append(getString(children.get(children.size() - 1), prefix + (isTail ? " " : ""), true));
}
return builder.toString();
}
}
public static final class Interval<O> {
private long center = Long.MIN_VALUE;
private Interval<O> left = null;
private Interval<O> right = null;
private List<IntervalData<O>> overlap = new ArrayList<IntervalData<O>>();
private void add(IntervalData<O> data) {
overlap.add(data);
}
/**
* Stabbing query
*
* @param index
* to query for.
* @return data at index.
*/
public IntervalData<O> query(long index) {
IntervalData<O> results = null;
if (index < center) {
// overlap is sorted by start point
Collections.sort(overlap,START_COMPARATOR);
for (IntervalData<O> data : overlap) {
if (data.start > index)
break;
IntervalData<O> temp = data.query(index);
if (results == null && temp != null)
results = temp;
else if (results != null && temp != null)
results.combined(temp);
}
} else if (index >= center) {
// overlap is reverse sorted by end point
Collections.sort(overlap,END_COMPARATOR);
for (IntervalData<O> data : overlap) {
if (data.end < index)
break;
IntervalData<O> temp = data.query(index);
if (results == null && temp != null)
results = temp;
else if (results != null && temp != null)
results.combined(temp);
}
}
if (index < center) {
if (left != null) {
IntervalData<O> temp = left.query(index);
if (results == null && temp != null)
results = temp;
else if (results != null && temp != null)
results.combined(temp);
}
} else if (index >= center) {
if (right != null) {
IntervalData<O> temp = right.query(index);
if (results == null && temp != null)
results = temp;
else if (results != null && temp != null)
results.combined(temp);
}
}
return results;
}
/**
* Range query
*
* @param start
* of range to query for.
* @param end
* of range to query for.
* @return data for range.
*/
public IntervalData<O> query(long start, long end) {
IntervalData<O> results = null;
for (IntervalData<O> data : overlap) {
if (data.start > end)
break;
IntervalData<O> temp = data.query(start, end);
if (results == null && temp != null)
results = temp;
else if (results != null && temp != null)
results.combined(temp);
}
if (left != null && start < center) {
IntervalData<O> temp = left.query(start, end);
if (temp != null && results == null)
results = temp;
else if (results != null && temp != null)
results.combined(temp);
}
if (right != null && end >= center) {
IntervalData<O> temp = right.query(start, end);
if (temp != null && results == null)
results = temp;
else if (results != null && temp != null)
results.combined(temp);
}
return results;
}
/**
* {@inheritDoc}
*/
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append("Center=").append(center);
builder.append(" Set=").append(overlap);
return builder.toString();
}
}
/**
* Data structure representing an interval.
*/
public static class IntervalData<O> implements Comparable<IntervalData<O>> {
private long start = Long.MIN_VALUE;
private long end = Long.MAX_VALUE;
private Set<O> set = new HashSet<O>();
/**
* Interval data using O as it's unique identifier
*
* @param object
* Object which defines the interval data
*/
public IntervalData(long index, O object) {
this.start = index;
this.end = index;
this.set.add(object);
}
/**
* Interval data using O as it's unique identifier
*
* @param object
* Object which defines the interval data
*/
public IntervalData(long start, long end, O object) {
this.start = start;
this.end = end;
this.set.add(object);
}
/**
* Interval data list which should all be unique
*
* @param list
* of interval data objects
*/
public IntervalData(long start, long end, Set<O> set) {
this.start = start;
this.end = end;
this.set = set;
}
/**
* Get the start of this interval
*
* @return Start of interval
*/
public long getStart() {
return start;
}
/**
* Get the end of this interval
*
* @return End of interval
*/
public long getEnd() {
return end;
}
/**
* Get the data set in this interval
*
* @return Unmodifiable collection of data objects
*/
public Collection<O> getData() {
return Collections.unmodifiableCollection(this.set);
}
/**
* Clear the indices.
*/
public void clear() {
this.start = Long.MIN_VALUE;
this.end = Long.MAX_VALUE;
this.set.clear();
}
/**
* Combined this IntervalData with data.
*
* @param data
* to combined with.
* @return Data which represents the combination.
*/
public IntervalData<O> combined(IntervalData<O> data) {
if (data.start < this.start)
this.start = data.start;
if (data.end > this.end)
this.end = data.end;
this.set.addAll(data.set);
return this;
}
/**
* Deep copy of data.
*
* @return deep copy.
*/
public IntervalData<O> copy() {
Set<O> copy = new HashSet<O>();
copy.addAll(set);
return new IntervalData<O>(start, end, copy);
}
/**
* Query inside this data object.
*
* @param start
* of range to query for.
* @param end
* of range to query for.
* @return Data queried for or NULL if it doesn't match the query.
*/
public IntervalData<O> query(long index) {
if (index < this.start || index > this.end)
return null;
return copy();
}
/**
* Query inside this data object.
*
* @param startOfQuery
* of range to query for.
* @param endOfQuery
* of range to query for.
* @return Data queried for or NULL if it doesn't match the query.
*/
public IntervalData<O> query(long startOfQuery, long endOfQuery) {
if (endOfQuery < this.start || startOfQuery > this.end)
return null;
return copy();
}
/**
* {@inheritDoc}
*/
@Override
public int hashCode() {
return 31 * ((int)(this.start + this.end)) + this.set.size();
}
/**
* {@inheritDoc}
*/
@SuppressWarnings("unchecked")
@Override
public boolean equals(Object obj) {
if (!(obj instanceof IntervalData))
return false;
IntervalData<O> data = (IntervalData<O>) obj;
if (this.start == data.start && this.end == data.end) {
if (this.set.size() != data.set.size())
return false;
for (O o : set) {
if (!data.set.contains(o))
return false;
}
return true;
}
return false;
}
/**
* {@inheritDoc}
*/
@Override
public int compareTo(IntervalData<O> d) {
if (this.end < d.end)
return -1;
if (d.end < this.end)
return 1;
return 0;
}
/**
* {@inheritDoc}
*/
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append(start).append("->").append(end);
builder.append(" set=").append(set);
return builder.toString();
}
}
}