475 lines
14 KiB
Java
475 lines
14 KiB
Java
package com.jwetherell.algorithms.data_structures;
|
|
|
|
import java.util.ArrayList;
|
|
import java.util.Collection;
|
|
import java.util.Collections;
|
|
import java.util.Comparator;
|
|
import java.util.HashSet;
|
|
import java.util.List;
|
|
import java.util.Set;
|
|
|
|
/**
 * An interval tree is an ordered tree data structure to hold intervals.
 * Specifically, it allows one to efficiently find all intervals that overlap
 * with any given interval or point.
 *
 * This implementation is a centered interval tree: every node stores a center
 * point, the intervals which contain that point, a left subtree holding the
 * intervals which end before the center and a right subtree holding the
 * intervals which start after the center. The tree is built once, in the
 * constructor, and is not modified by queries.
 *
 * http://en.wikipedia.org/wiki/Interval_tree
 *
 * @author Justin Wetherell <phishman3579@gmail.com>
 */
public class IntervalTree<O extends Object> {

    /** Root node, or null when the tree was built from an empty list. */
    private Interval<O> root = null;

    /** Orders interval data by ascending start point. */
    private static final Comparator<IntervalData<?>> START_COMPARATOR = new Comparator<IntervalData<?>>() {
        /**
         * {@inheritDoc}
         */
        @Override
        public int compare(IntervalData<?> arg0, IntervalData<?> arg1) {
            // Explicit comparisons instead of a subtraction, which could overflow
            if (arg0.start < arg1.start)
                return -1;
            if (arg1.start < arg0.start)
                return 1;
            return 0;
        }
    };

    /** Orders interval data by descending end point (largest end first). */
    private static final Comparator<IntervalData<?>> END_COMPARATOR = new Comparator<IntervalData<?>>() {
        /**
         * {@inheritDoc}
         */
        @Override
        public int compare(IntervalData<?> arg0, IntervalData<?> arg1) {
            // Reverse order: the larger end sorts first
            if (arg0.end > arg1.end)
                return -1;
            if (arg1.end > arg0.end)
                return 1;
            return 0;
        }
    };

    /**
     * Create interval tree from list of IntervalData objects. An empty list
     * produces an empty tree whose queries return null.
     *
     * @param intervals
     *            is a list of IntervalData objects; must not be null
     */
    public IntervalTree(List<IntervalData<O>> intervals) {
        if (intervals.isEmpty())
            return;

        root = createFromList(intervals);
    }

    /**
     * Recursively build the node (and its subtrees) for the given intervals.
     * The center of the node is the midpoint of the interval found in the
     * middle of the list; intervals ending before the center go to the left
     * subtree, intervals starting after it go to the right subtree and all
     * others are stored in the node itself.
     *
     * NOTE(review): termination of the recursion relies on every interval
     * having start &lt;= end; this is not validated here.
     *
     * @param intervals
     *            non-empty list of intervals to store
     * @return node which is the root of the (sub)tree holding the intervals
     */
    protected static final <O extends Object> Interval<O> createFromList(List<IntervalData<O>> intervals) {
        Interval<O> newInterval = new Interval<O>();
        if (intervals.size() == 1) {
            IntervalData<O> middle = intervals.get(0);
            newInterval.center = midpoint(middle.start, middle.end);
            newInterval.add(middle);
            newInterval.sortOverlap();
            return newInterval;
        }

        int half = intervals.size() / 2;
        IntervalData<O> middle = intervals.get(half);
        newInterval.center = midpoint(middle.start, middle.end);
        List<IntervalData<O>> leftIntervals = new ArrayList<IntervalData<O>>();
        List<IntervalData<O>> rightIntervals = new ArrayList<IntervalData<O>>();
        for (IntervalData<O> interval : intervals) {
            if (interval.end < newInterval.center) {
                leftIntervals.add(interval);
            } else if (interval.start > newInterval.center) {
                rightIntervals.add(interval);
            } else {
                newInterval.add(interval);
            }
        }
        // Sort once at build time so that queries can stop scanning early
        // without ever having to re-order (and thereby mutate) the node.
        newInterval.sortOverlap();

        if (!leftIntervals.isEmpty())
            newInterval.left = createFromList(leftIntervals);
        if (!rightIntervals.isEmpty())
            newInterval.right = createFromList(rightIntervals);

        return newInterval;
    }

    /**
     * Midpoint of two longs which stays correct when start + end would
     * overflow. When the sum does not overflow the result is exactly
     * (start + end) / 2.
     *
     * @param start
     *            first end point
     * @param end
     *            second end point
     * @return a value between the two end points
     */
    private static long midpoint(long start, long end) {
        long sum = start + end;
        // The addition wrapped around iff both operands have the same sign
        // and the sum has the opposite sign.
        if (((start ^ sum) & (end ^ sum)) < 0)
            return start + (end - start) / 2; // same sign, so the difference cannot overflow
        return sum / 2;
    }

    /**
     * Stabbing query
     *
     * @param index
     *            to query for.
     * @return combined data of all intervals containing index, or null if
     *         there is none (including when the tree is empty).
     */
    public IntervalData<O> query(long index) {
        if (root == null)
            return null;
        return root.query(index);
    }

    /**
     * Range query
     *
     * @param start
     *            of range to query for.
     * @param end
     *            of range to query for.
     * @return combined data of all intervals overlapping the range, or null
     *         if there is none (including when the tree is empty).
     */
    public IntervalData<O> query(long start, long end) {
        if (root == null)
            return null;
        return root.query(start, end);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String toString() {
        return IntervalTreePrinter.getString(this);
    }

    /**
     * Renders a tree as multi-line text, one node per line.
     */
    protected static class IntervalTreePrinter {

        // Connectors are written as unicode escapes so that this file compiles
        // identically regardless of the source encoding handed to the compiler.
        // NOTE(review): the single-space child indent is narrower than the
        // four-character connectors -- confirm the intended width.
        /** Connector for the last child: box-drawing corner, two dashes, space. */
        private static final String TAIL_CONNECTOR = "\u2514\u2500\u2500 ";
        /** Connector for any other child: box-drawing tee, two dashes, space. */
        private static final String BRANCH_CONNECTOR = "\u251C\u2500\u2500 ";
        /** Indent added below a last child. */
        private static final String TAIL_INDENT = " ";
        /** Indent added below any other child: box-drawing vertical bar, space. */
        private static final String BRANCH_INDENT = "\u2502 ";

        /**
         * Render the whole tree.
         *
         * @param tree
         *            to render
         * @return text representation, or "Tree has no nodes." for an empty tree
         */
        public static <O extends Object> String getString(IntervalTree<O> tree) {
            if (tree.root == null)
                return "Tree has no nodes.";
            return getString(tree.root, "", true);
        }

        /**
         * Render a node followed by its children (left first, then right).
         *
         * @param interval
         *            node to render
         * @param prefix
         *            text placed in front of the node's line
         * @param isTail
         *            true if the node is the last child of its parent
         * @return text representation of the subtree
         */
        private static <O extends Object> String getString(Interval<O> interval, String prefix, boolean isTail) {
            StringBuilder builder = new StringBuilder();

            builder.append(prefix).append(isTail ? TAIL_CONNECTOR : BRANCH_CONNECTOR).append(interval.toString()).append("\n");

            List<Interval<O>> children = new ArrayList<Interval<O>>();
            if (interval.left != null)
                children.add(interval.left);
            if (interval.right != null)
                children.add(interval.right);

            String childPrefix = prefix + (isTail ? TAIL_INDENT : BRANCH_INDENT);
            int last = children.size() - 1;
            for (int i = 0; i <= last; i++)
                builder.append(getString(children.get(i), childPrefix, i == last));

            return builder.toString();
        }
    }

    /**
     * Node of the tree: a center point, the intervals containing that point
     * and the two subtrees.
     */
    public static final class Interval<O> {

        private long center = Long.MIN_VALUE;
        private Interval<O> left = null;
        private Interval<O> right = null;
        /** Intervals containing center; ascending by start once sortOverlap() ran. */
        private final List<IntervalData<O>> overlap = new ArrayList<IntervalData<O>>();
        /** Same intervals as overlap; descending by end once sortOverlap() ran. */
        private final List<IntervalData<O>> overlapByEnd = new ArrayList<IntervalData<O>>();

        private void add(IntervalData<O> data) {
            overlap.add(data);
        }

        /**
         * Bring both views of the node's intervals into their sorted order.
         * Called by the tree builder after the last add().
         */
        private void sortOverlap() {
            Collections.sort(overlap, START_COMPARATOR);
            overlapByEnd.clear();
            overlapByEnd.addAll(overlap);
            Collections.sort(overlapByEnd, END_COMPARATOR);
        }

        /**
         * Fold a partial query result into the accumulated result.
         *
         * @param accumulated
         *            result so far, may be null
         * @param found
         *            newly found data, may be null
         * @return the non-null argument, the combination of both, or null
         */
        private static <T> IntervalData<T> merge(IntervalData<T> accumulated, IntervalData<T> found) {
            if (found == null)
                return accumulated;
            if (accumulated == null)
                return found;
            // Both are copies produced by IntervalData.query, safe to mutate
            return accumulated.combined(found);
        }

        /**
         * Stabbing query
         *
         * @param index
         *            to query for.
         * @return combined data of all intervals in this subtree containing
         *         index, or null if there is none.
         */
        public IntervalData<O> query(long index) {
            IntervalData<O> results = null;
            if (index < center) {
                // Every interval here ends at or after center, so only the
                // start decides; the list is sorted by start point.
                for (IntervalData<O> data : overlap) {
                    if (data.start > index)
                        break;
                    results = merge(results, data.query(index));
                }
                if (left != null)
                    results = merge(results, left.query(index));
            } else {
                // Every interval here starts at or before center, so only the
                // end decides; the list is reverse sorted by end point.
                for (IntervalData<O> data : overlapByEnd) {
                    if (data.end < index)
                        break;
                    results = merge(results, data.query(index));
                }
                if (right != null)
                    results = merge(results, right.query(index));
            }
            return results;
        }

        /**
         * Range query
         *
         * @param start
         *            of range to query for.
         * @param end
         *            of range to query for.
         * @return combined data of all intervals in this subtree overlapping
         *         the range, or null if there is none.
         */
        public IntervalData<O> query(long start, long end) {
            IntervalData<O> results = null;
            // Sorted by start point: once an interval starts after the range
            // ends, all following ones do too.
            for (IntervalData<O> data : overlap) {
                if (data.start > end)
                    break;
                results = merge(results, data.query(start, end));
            }

            // The left subtree only holds intervals ending before center
            if (left != null && start < center)
                results = merge(results, left.query(start, end));

            // The right subtree only holds intervals starting after center
            if (right != null && end >= center)
                results = merge(results, right.query(start, end));

            return results;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append("Center=").append(center);
            builder.append(" Set=").append(overlap);
            return builder.toString();
        }
    }

    /**
     * Data structure representing an interval.
     */
    public static class IntervalData<O> implements Comparable<IntervalData<O>> {

        private long start = Long.MIN_VALUE;
        private long end = Long.MAX_VALUE;
        private Set<O> set = new HashSet<O>();

        /**
         * Interval data covering a single index, using O as its unique
         * identifier
         *
         * @param index
         *            used as both start and end of the interval
         * @param object
         *            Object which defines the interval data
         */
        public IntervalData(long index, O object) {
            this.start = index;
            this.end = index;
            this.set.add(object);
        }

        /**
         * Interval data using O as its unique identifier
         *
         * @param start
         *            of the interval (inclusive)
         * @param end
         *            of the interval (inclusive)
         * @param object
         *            Object which defines the interval data
         */
        public IntervalData(long start, long end, O object) {
            this.start = start;
            this.end = end;
            this.set.add(object);
        }

        /**
         * Interval data set which should all be unique
         *
         * @param start
         *            of the interval (inclusive)
         * @param end
         *            of the interval (inclusive)
         * @param set
         *            of interval data objects; stored by reference, not copied
         */
        public IntervalData(long start, long end, Set<O> set) {
            this.start = start;
            this.end = end;
            this.set = set;
        }

        /**
         * Get the start of this interval
         *
         * @return Start of interval
         */
        public long getStart() {
            return start;
        }

        /**
         * Get the end of this interval
         *
         * @return End of interval
         */
        public long getEnd() {
            return end;
        }

        /**
         * Get the data set in this interval
         *
         * @return Unmodifiable collection of data objects
         */
        public Collection<O> getData() {
            return Collections.unmodifiableCollection(this.set);
        }

        /**
         * Clear the indices and the data set.
         */
        public void clear() {
            this.start = Long.MIN_VALUE;
            this.end = Long.MAX_VALUE;
            this.set.clear();
        }

        /**
         * Combine this IntervalData with data: this object is widened to
         * cover both intervals and receives all objects of data.
         *
         * @param data
         *            to combine with.
         * @return this object, which now represents the combination.
         */
        public IntervalData<O> combined(IntervalData<O> data) {
            if (data.start < this.start)
                this.start = data.start;
            if (data.end > this.end)
                this.end = data.end;
            this.set.addAll(data.set);
            return this;
        }

        /**
         * Copy of this data with its own set; the objects inside the set
         * are shared with the original.
         *
         * @return copy.
         */
        public IntervalData<O> copy() {
            Set<O> copy = new HashSet<O>();
            copy.addAll(set);
            return new IntervalData<O>(start, end, copy);
        }

        /**
         * Query inside this data object.
         *
         * @param index
         *            to query for.
         * @return Copy of this data if it contains index or NULL if it
         *         doesn't match the query.
         */
        public IntervalData<O> query(long index) {
            if (index < this.start || index > this.end)
                return null;

            return copy();
        }

        /**
         * Query inside this data object.
         *
         * @param startOfQuery
         *            of range to query for.
         * @param endOfQuery
         *            of range to query for.
         * @return Copy of this data if it overlaps the range or NULL if it
         *         doesn't match the query.
         */
        public IntervalData<O> query(long startOfQuery, long endOfQuery) {
            if (endOfQuery < this.start || startOfQuery > this.end)
                return null;

            return copy();
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public int hashCode() {
            return 31 * ((int) (this.start + this.end)) + this.set.size();
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof IntervalData))
                return false;

            IntervalData<?> data = (IntervalData<?>) obj;
            if (this.start != data.start || this.end != data.end)
                return false;
            // Same size and every element of this set present in the other
            return this.set.size() == data.set.size() && data.set.containsAll(this.set);
        }

        /**
         * Orders by end point only, so two objects comparing as 0 are not
         * necessarily equal.
         *
         * {@inheritDoc}
         */
        @Override
        public int compareTo(IntervalData<O> d) {
            if (this.end < d.end)
                return -1;
            if (d.end < this.end)
                return 1;
            return 0;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append(start).append("->").append(end);
            builder.append(" set=").append(set);
            return builder.toString();
        }
    }
}
|