You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

126 lines
3.6 KiB
Java

import edu.princeton.cs.introcs.StdOut;
/*************************************************************************
* Compilation: javac SparseVector.java
* Execution: java SparseVector
*
* A sparse vector, implemented using a symbol table.
*
* [Not clear we need the instance variable N except for error checking.]
*
*************************************************************************/
/**
 * A sparse vector of fixed length {@code N}. Only nonzero entries are stored,
 * as index-value pairs in a symbol table; every other entry reads as 0.0.
 */
public class SparseVector {
    private final int N;                   // length
    private final ST<Integer, Double> st;  // nonzero entries, as index-value pairs

    /**
     * Initializes the all-zeros vector of length {@code N}.
     *
     * @param N the length of the vector
     * @throws IllegalArgumentException if {@code N} is negative
     */
    public SparseVector(int N) {
        if (N < 0) throw new IllegalArgumentException("Vector length must be nonnegative");
        this.N = N;
        this.st = new ST<Integer, Double>();
    }

    /**
     * Sets entry {@code i} to {@code value}. Storing 0.0 (or -0.0) removes the
     * entry, so the table only ever holds nonzero values.
     *
     * @param i the index
     * @param value the new value of entry {@code i}
     * @throws IndexOutOfBoundsException unless {@code 0 <= i < N}
     */
    public void put(int i, double value) {
        if (i < 0 || i >= N) throw new IndexOutOfBoundsException("Illegal index");
        if (value == 0.0) st.delete(i);
        else              st.put(i, value);
    }

    /**
     * Returns entry {@code i}, or 0.0 if it is not stored.
     *
     * @param i the index
     * @return the value of entry {@code i}
     * @throws IndexOutOfBoundsException unless {@code 0 <= i < N}
     */
    public double get(int i) {
        if (i < 0 || i >= N) throw new IndexOutOfBoundsException("Illegal index");
        if (st.contains(i)) return st.get(i);
        else                return 0.0;
    }

    /**
     * Returns the number of stored (nonzero) entries.
     *
     * @return the number of nonzero entries
     */
    public int nnz() {
        return st.size();
    }

    /**
     * Returns the length of the vector.
     *
     * @return the length {@code N}
     */
    public int size() {
        return N;
    }

    /**
     * Returns the dot product of this vector with {@code that}.
     *
     * @param that the other vector
     * @return the dot product
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public double dot(SparseVector that) {
        if (this.N != that.N) throw new IllegalArgumentException("Vector lengths disagree");
        // Walk the vector with fewer nonzeros and probe the other, so the cost is
        // proportional to the smaller nnz. Keys come straight from st, so the
        // bounds checks in get() are unnecessary here.
        SparseVector sparser = (this.st.size() <= that.st.size()) ? this : that;
        SparseVector denser  = (sparser == this) ? that : this;
        double sum = 0.0;
        for (int i : sparser.st.keys()) {
            if (denser.st.contains(i)) sum += sparser.st.get(i) * denser.st.get(i);
        }
        return sum;
    }

    /**
     * Returns the dot product of this vector with the dense array {@code that}.
     *
     * @param that the dense vector; must have length {@code N}
     * @return the dot product
     * @throws IllegalArgumentException if {@code that.length != N}
     */
    public double dot(double[] that) {
        // Same length contract as dot(SparseVector); previously a short array
        // failed with ArrayIndexOutOfBoundsException and a long one was accepted.
        if (that.length != N) throw new IllegalArgumentException("Vector lengths disagree");
        double sum = 0.0;
        for (int i : st.keys()) sum += that[i] * st.get(i);
        return sum;
    }

    /**
     * Returns the Euclidean (2-) norm of this vector.
     *
     * @return the square root of this vector dotted with itself
     */
    public double norm() {
        return Math.sqrt(this.dot(this));
    }

    /**
     * Returns a new vector equal to {@code alpha} times this vector.
     * Products that evaluate to 0.0 are not stored.
     *
     * @param alpha the scalar
     * @return {@code alpha * this}
     */
    public SparseVector scale(double alpha) {
        SparseVector c = new SparseVector(N);
        for (int i : st.keys()) c.put(i, alpha * st.get(i));
        return c;
    }

    /**
     * Returns a new vector equal to this vector plus {@code that}.
     * Entries that sum to 0.0 are not stored.
     *
     * @param that the other vector
     * @return {@code this + that}
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public SparseVector plus(SparseVector that) {
        if (this.N != that.N) throw new IllegalArgumentException("Vector lengths disagree");
        SparseVector c = new SparseVector(N);
        for (int i : this.st.keys()) c.put(i, this.st.get(i));            // c = this
        for (int i : that.st.keys()) c.put(i, that.st.get(i) + c.get(i)); // c = c + that
        return c;
    }

    /**
     * Returns the stored entries as {@code "(index, value) "} pairs, in the
     * symbol table's key iteration order.
     *
     * @return a string representation of the nonzero entries
     */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        for (int i : st.keys()) {
            s.append('(').append(i).append(", ").append(st.get(i)).append(") ");
        }
        return s.toString();
    }

    // test client
    public static void main(String[] args) {
        SparseVector a = new SparseVector(10);
        SparseVector b = new SparseVector(10);
        a.put(3, 0.50);
        a.put(9, 0.75);
        a.put(6, 0.11);
        a.put(6, 0.00);
        b.put(3, 0.60);
        b.put(4, 0.90);
        StdOut.println("a = " + a);
        StdOut.println("b = " + b);
        StdOut.println("a dot b = " + a.dot(b));
        StdOut.println("a + b = " + a.plus(b));
    }
}