/*************************************************************************
 *  Compilation:  javac LinearRegression.java
 *  Execution:    java LinearRegression
 *
 *  Compute least squares solution to y = beta * x + alpha.
 *  Simple linear regression.
 *
 *  TODO: rename beta and alpha to slope and intercept.
 *
 *************************************************************************/

/**
 *  The <tt>LinearRegression</tt> class performs a simple linear regression
 *  on a set of <em>N</em> data points (<em>y</em><sub>i</sub>, <em>x</em><sub>i</sub>).
 *  That is, it fits a straight line <em>y</em> = α + β <em>x</em>
 *  (where <em>y</em> is the response variable, <em>x</em> is the predictor variable,
 *  α is the <em>y</em>-intercept, and β is the slope)
 *  that minimizes the sum of squared residuals of the linear regression model.
 *  It also computes associated statistics, including the coefficient of
 *  determination <em>R</em><sup>2</sup> and the standard deviation of the
 *  estimates for the slope and <em>y</em>-intercept.
 *
 *  @author Robert Sedgewick
 *  @author Kevin Wayne
 */
public class LinearRegression {
    private final int N;                     // number of data points
    private final double alpha, beta;        // y-intercept and slope of the best-fit line
    private final double R2;                 // coefficient of determination
    private final double svar;               // estimated variance of the residuals
    private final double svar0, svar1;       // variances of the intercept and slope estimates

    /**
     * Performs a linear regression on the data points <tt>(y[i], x[i])</tt>.
     *
     * @param  x the values of the predictor variable
     * @param  y the corresponding values of the response variable
     * @throws java.lang.IllegalArgumentException if the lengths of the two arrays are not equal
     */
    public LinearRegression(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("array lengths are not equal");
        }
        N = x.length;

        // first pass: compute sums and means
        double sumx = 0.0, sumy = 0.0, sumx2 = 0.0;
        for (int i = 0; i < N; i++) sumx  += x[i];
        for (int i = 0; i < N; i++) sumx2 += x[i]*x[i];
        for (int i = 0; i < N; i++) sumy  += y[i];
        double xbar = sumx / N;
        double ybar = sumy / N;

        // second pass: compute summary statistics
        double xxbar = 0.0, yybar = 0.0, xybar = 0.0;
        for (int i = 0; i < N; i++) {
            xxbar += (x[i] - xbar) * (x[i] - xbar);
            yybar += (y[i] - ybar) * (y[i] - ybar);
            xybar += (x[i] - xbar) * (y[i] - ybar);
        }
        beta  = xybar / xxbar;
        alpha = ybar - beta * xbar;

        // more statistical analysis
        double rss = 0.0;      // residual sum of squares
        double ssr = 0.0;      // regression sum of squares
        for (int i = 0; i < N; i++) {
            double fit = beta*x[i] + alpha;
            rss += (fit - y[i]) * (fit - y[i]);
            ssr += (fit - ybar) * (fit - ybar);
        }

        int degreesOfFreedom = N - 2;
        R2    = ssr / yybar;
        svar  = rss / degreesOfFreedom;
        svar1 = svar / xxbar;
        svar0 = svar/N + xbar*xbar*svar1;
    }

    /**
     * Returns the <em>y</em>-intercept α of the best-fit line <em>y</em> = α + β <em>x</em>.
     *
     * @return the <em>y</em>-intercept α of the best-fit line <em>y</em> = α + β <em>x</em>
     */
    public double intercept() {
        return alpha;
    }

    /**
     * Returns the slope β of the best-fit line <em>y</em> = α + β <em>x</em>.
     *
     * @return the slope β of the best-fit line <em>y</em> = α + β <em>x</em>
     */
    public double slope() {
        return beta;
    }

    /**
     * Returns the coefficient of determination <em>R</em><sup>2</sup>.
     *
     * @return the coefficient of determination <em>R</em><sup>2</sup>,
     *         which is a real number between 0 and 1
     */
    public double R2() {
        return R2;
    }

    /**
     * Returns the standard error of the estimate for the intercept.
     *
     * @return the standard error of the estimate for the intercept
     */
    public double interceptStdErr() {
        return Math.sqrt(svar0);
    }

    /**
     * Returns the standard error of the estimate for the slope.
     *
     * @return the standard error of the estimate for the slope
     */
    public double slopeStdErr() {
        return Math.sqrt(svar1);
    }

    /**
     * Returns the expected response <tt>y</tt> given the value of the predictor
     * variable <tt>x</tt>.
     *
     * @param  x the value of the predictor variable
     * @return the expected response <tt>y</tt> given the value of the predictor
     *         variable <tt>x</tt>
     */
    public double predict(double x) {
        return beta*x + alpha;
    }

    /**
     * Returns a string representation of the simple linear regression model.
     *
     * @return a string representation of the simple linear regression model,
     *         including the best-fit line and the coefficient of determination
     *         <em>R</em><sup>2</sup>
     */
    public String toString() {
        String s = "";
        s += String.format("%.2f x + %.2f", slope(), intercept());
        return s + " (R^2 = " + String.format("%.3f", R2()) + ")";
    }
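
    // The compilation header above advertises "java LinearRegression", but the
    // original file defines no main method. The test client below is a minimal
    // sketch of one: the five (x, y) sample points are made up for illustration
    // and are not from the original source.
    public static void main(String[] args) {
        double[] x = { 1.0, 2.0, 3.0, 4.0, 5.0 };
        double[] y = { 1.9, 4.1, 5.9, 8.2, 9.8 };   // roughly y = 2x, with noise

        LinearRegression regression = new LinearRegression(x, y);
        System.out.println(regression);   // best-fit line and R^2 via toString()
        System.out.println("predict(6.0)      = " + regression.predict(6.0));
        System.out.println("slope std err     = " + regression.slopeStdErr());
        System.out.println("intercept std err = " + regression.interceptStdErr());
    }
}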