最近在写一个荧光图像分析软件，需要自己拟合方程。一元线性回归方程的算法参考了《Java数值方法》，拟合度 R^2（决定系数）的计算是自己写的，欢迎讨论。计算结果与 Excel 完全一致。
总共三个文件:
DataPoint.java
/**
 * One (x, y) sample consumed by the interpolation and regression classes.
 */
public class DataPoint {

    /** Horizontal coordinate of the sample. */
    public float x;

    /** Vertical coordinate of the sample. */
    public float y;

    /**
     * Creates a sample located at the given coordinates.
     *
     * @param x the x value
     * @param y the y value
     */
    public DataPoint(float x, float y) {
        this.y = y;
        this.x = x;
    }
}
/**
* A least-squares regression line function.
*/
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;
public class RegressionLine
//implements Evaluatable
{
/** sum of x */ private double sumX;
/** sum of y */ private double sumY;
/** sum of x*x */ private double sumXX;
/** sum of x*y */ private double sumXY;
/** sum of y*y */ private double sumYY;
/** sum of yi-y */ private double sumDeltaY;
/** sum of sumDeltaY^2 */ private double sumDeltaY2;
/**误差 */
private double sse;
private double sst;
private double E;
private String[] xy ;
private ArrayList listX ;
private ArrayList listY ;
private int XMin,XMax,YMin,YMax;
/** line coefficient a0 */ private float a0;
/** line coefficient a1 */ private float a1;
/** number of data points */ private int pn ;
/** true if coefficients valid */ private boolean coefsValid;
/**
* Constructor.
*/
public RegressionLine() {
XMax = 0;
YMax = 0;
pn = 0;
xy =new String[2];
listX = new ArrayList();
listY = new ArrayList();
}
/**
* Constructor.
* @param data the array of data points
*/
public RegressionLine(DataPoint data[])
{
pn = 0;
xy =new String[2];
listX = new ArrayList();
listY = new ArrayList();
for (int i &#61; 0; i addDataPoint(data[i]); } } /** * Return the current number of data points. * &#64;return the count */ public int getDataPointCount() { return pn; } /** * Return the coefficient a0. * &#64;return the value of a0 */ public float getA0() { validateCoefficients(); return a0; } /** * Return the coefficient a1. * &#64;return the value of a1 */ public float getA1() { validateCoefficients(); return a1; } /** * Return the sum of the x values. * &#64;return the sum */ public double getSumX() { return sumX; } /** * Return the sum of the y values. * &#64;return the sum */ public double getSumY() { return sumY; } /** * Return the sum of the x*x values. * &#64;return the sum */ public double getSumXX() { return sumXX; } /** * Return the sum of the x*y values. * &#64;return the sum */ public double getSumXY() { return sumXY; } public double getSumYY() { return sumYY; } public int getXMin() { return XMin; } public int getXMax() { return XMax; } public int getYMin() { return YMin; } public int getYMax() { return YMax; } /** * Add a new data point: Update the sums. 
* &#64;param dataPoint the new data point */ public void addDataPoint(DataPoint dataPoint) { sumX &#43;&#61; dataPoint.x; sumY &#43;&#61; dataPoint.y; sumXX &#43;&#61; dataPoint.x*dataPoint.x; sumXY &#43;&#61; dataPoint.x*dataPoint.y; sumYY &#43;&#61; dataPoint.y*dataPoint.y; if(dataPoint.x > XMax){ XMax &#61; (int)dataPoint.x; } if(dataPoint.y > YMax){ YMax &#61; (int)dataPoint.y; } //把每个点的具体坐标存入ArrayList中&#xff0c;备用 xy[0] &#61; (int)dataPoint.x&#43; ""; xy[1] &#61; (int)dataPoint.y&#43; ""; if(dataPoint.x!&#61;0 && dataPoint.y !&#61; 0){ System.out.print(xy[0]&#43;","); System.out.println(xy[1]); try{ //System.out.println("n:"&#43;n); listX.add(pn,xy[0]); listY.add(pn,xy[1]); } catch(Exception e){ e.printStackTrace(); } /* System.out.println("N:" &#43; n); System.out.println("ArrayList listX:"&#43; listX.get(n)); System.out.println("ArrayList listY:"&#43; listY.get(n)); */ } &#43;&#43;pn; coefsValid &#61; false; } /** * Return the value of the regression line function at x. * (Implementation of Evaluatable.) * &#64;param x the value of x * &#64;return the value of the function at x */ public float at(int x) { if (pn <2) return Float.NaN; validateCoefficients(); return a0 &#43; a1*x; } public float at(float x) { if (pn <2) return Float.NaN; validateCoefficients(); return a0 &#43; a1*x; } /** * Reset. */ public void reset() { pn &#61; 0; sumX &#61; sumY &#61; sumXX &#61; sumXY &#61; 0; coefsValid &#61; false; } /** * Validate the coefficients. 
* 计算方程系数 y&#61;ax&#43;b 中的a */ private void validateCoefficients() { if (coefsValid) return; if (pn >&#61; 2) { float xBar &#61; (float) sumX/pn; float yBar &#61; (float) sumY/pn; a1 &#61; (float) ((pn*sumXY - sumX*sumY) /(pn*sumXX - sumX*sumX)); a0 &#61; (float) (yBar - a1*xBar); } else { a0 &#61; a1 &#61; Float.NaN; } coefsValid &#61; true; } /** * 返回误差 */ public double getR(){ //遍历这个list并计算分母 for(int i &#61; 0; i float Yi&#61; (float)Integer.parseInt(listY.get(i).toString()); float Y &#61; at(Integer.parseInt(listX.get(i).toString())); float deltaY &#61; Yi - Y; float deltaY2 &#61; deltaY*deltaY; /* System.out.println("Yi:" &#43; Yi); System.out.println("Y:" &#43; Y); System.out.println("deltaY:" &#43; deltaY); System.out.println("deltaY2:" &#43; deltaY2); */ sumDeltaY2 &#43;&#61; deltaY2; //System.out.println("sumDeltaY2:" &#43; sumDeltaY2); } sst &#61; sumYY - (sumY*sumY)/pn; //System.out.println("sst:" &#43; sst); E &#61;1- sumDeltaY2/sst; return round(E,4) ; } //用于实现精确的四舍五入 public double round(double v,int scale){ if(scale<0){ throw new IllegalArgumentException( "The scale must be a positive integer or zero"); } BigDecimal b &#61; new BigDecimal(Double.toString(v)); BigDecimal one &#61; new BigDecimal("1"); return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue(); } public float round(float v,int scale){ if(scale<0){ throw new IllegalArgumentException( "The scale must be a positive integer or zero"); } BigDecimal b &#61; new BigDecimal(Double.toString(v)); BigDecimal one &#61; new BigDecimal("1"); return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue(); } } 演示程序&#xff1a; LinearRegression.java /** * Linear Regression * * Demonstrate linear regression by constructing the regression line for a set * of data points. 
* * require DataPoint.java,RegressionLine.java * * 为了计算对于给定数据点的最小方差回线&#xff0c;需要计算SumX,SumY,SumXX,SumXY; (注&#xff1a;SumXX &#61; Sum (X^2)) * 回归直线方程如下&#xff1a; f(x)&#61;a1x&#43;a0 * 斜率和截距的计算公式如下&#xff1a; * * a1&#61;(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2) * * * * 画线的原理&#xff1a;两点成一直线&#xff0c;只要能确定两个点即可 * 第一点&#xff1a;(0,a0) 再随意取一个x1值代入方程&#xff0c;取得y1&#xff0c;连结(0,a0)和(x1,y1)两点即可。 * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax&#xff0c;即两点为(0,a0),(Xmax,Y)。如果y&#61;a1*Xmax&#43;a0,y大于 * 纵坐标最大值Ymax&#xff0c;则不用这个点。改用y取最大值Ymax&#xff0c;算得此时x的值&#xff0c;使用(X,Ymax)&#xff0c; 即两点为(0,a0),(X,Ymax) * * 拟合度计算&#xff1a;(即Excel中的R^2) * *R2 &#61; 1 - E * 误差E的计算&#xff1a;E &#61; SSE/SST * SSE&#61;sum((Yi-Y)^2) SST&#61;sumYY - (sumY*sumY)/n; * */ public class LinearRegression { private static final int MAX_POINTS &#61; 10; private double E; /** * Main program. * * &#64;param args * the array of runtime arguments */ public static void main(String args[]) { RegressionLine line &#61; new RegressionLine(); line.addDataPoint(new DataPoint(20, 136)); line.addDataPoint(new DataPoint(40, 143)); line.addDataPoint(new DataPoint(60, 152)); line.addDataPoint(new DataPoint(80, 162)); line.addDataPoint(new DataPoint(100, 167)); printSums(line); printLine(line); } /** * Print the computed sums. * * &#64;param line * the regression line */ private static void printSums(RegressionLine line) { System.out.println("\n数据点个数 n &#61; " &#43; line.getDataPointCount()); System.out.println("\nSum x &#61; " &#43; line.getSumX()); System.out.println("Sum y &#61; " &#43; line.getSumY()); System.out.println("Sum xx &#61; " &#43; line.getSumXX()); System.out.println("Sum xy &#61; " &#43; line.getSumXY()); System.out.println("Sum yy &#61; " &#43; line.getSumYY()); } /** * Print the regression line function. 
* * &#64;param line * the regression line */ private static void printLine(RegressionLine line) { System.out.println("\n回归线公式: y &#61; " &#43; line.getA1() &#43; "x &#43; " &#43; line.getA0()); System.out.println("拟合度&#xff1a; R^2 &#61; " &#43; line.getR()); } }
n: 数据点个数
a0 = (SumY - SumX * a1) / n
（也可表达为 a0 = averageY - a1 * averageX）