Java Implementation of Least-Squares Polynomial Fitting

Background

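The project needed to learn a univariate linear function y = ax + b from existing data: given sample values of x and y, the coefficients a and b can be obtained by polynomial fitting with gradient descent or with the least-squares method. The first idea was to learn the model with Spark MLlib, but the results were unsatisfactory: there was little documentation and the parameters were hard to tune. So the approach was changed to polynomial fitting with the least-squares method.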

Least-squares polynomial fitting in brief (the following is based on https://blog.csdn.net/funnyrand/article/details/46742561):

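Suppose the given data points are (x1, y1), (x2, y2), ... (xm, ym). The goal is a polynomial f(x) = a0 + a1 * x + a2 * pow(x, 2) + ... + an * pow(x, n) that minimizes the sum of squared differences between f(xi) and the actual yi over all given points; in other words, we need to compute the coefficients a0, a1, ... an.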

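By the least-squares principle, this problem reduces to solving the linear system Ga = B, where a = (a0, a1, ..., an) is the vector of unknown coefficients.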
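Written out with optional per-point weights w_i (all equal to 1 in the unweighted case, matching the weight array in the code below) and f(x) = a_0 + a_1 x + ... + a_n x^n, the system Ga = B is the standard set of normal equations, obtained by setting each partial derivative of the weighted sum of squares S to zero. This is a standard derivation shown for reference:

```latex
S(a_0,\dots,a_n) = \sum_{i=1}^{m} w_i \bigl(f(x_i) - y_i\bigr)^2,
\qquad
\frac{\partial S}{\partial a_j} = 0 \;\Longrightarrow\;
\sum_{k=0}^{n} \Bigl(\sum_{i=1}^{m} w_i x_i^{\,j+k}\Bigr) a_k = \sum_{i=1}^{m} w_i x_i^{\,j} y_i,
\quad j = 0,\dots,n

\begin{pmatrix}
\sum w_i & \sum w_i x_i & \cdots & \sum w_i x_i^{n} \\
\sum w_i x_i & \sum w_i x_i^{2} & \cdots & \sum w_i x_i^{n+1} \\
\vdots & \vdots & \ddots & \vdots \\
\sum w_i x_i^{n} & \sum w_i x_i^{n+1} & \cdots & \sum w_i x_i^{2n}
\end{pmatrix}
\begin{pmatrix} a_0 \\ a_1 \\ \vdots \\ a_n \end{pmatrix}
=
\begin{pmatrix} \sum w_i y_i \\ \sum w_i x_i y_i \\ \vdots \\ \sum w_i x_i^{n} y_i \end{pmatrix}
```

In compute(), s[k] holds the sum of w_i x_i^k and b[j] holds the sum of w_i x_i^j y_i. Note that the code's parameter n counts coefficients (degree + 1), so n = 2 in the code corresponds to degree 1 here.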

From a programming point of view, this means doing two things:

1) Compute the coefficients of the linear system:

Computing the coefficients is simple: perform the corresponding sums over the given (x1, y1), (x2, y2), ... (xm, ym). Relevant code:

private void compute() {
  ...
}

2) Solve the linear system:

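Solving the linear system is somewhat more involved. Gaussian elimination is used here: the matrix is transformed recursively, each step eliminating one unknown, so the number of polynomial coefficients left to solve for shrinks until a single equation remains. Relevant code: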

private double[] calcLinearEquation(double[][] a, double[] b) {
  ...
}
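The recursive solver allocates a new sub-matrix at every level and uses the first non-zero element it finds as the pivot, which can lose precision when that element is very small. A common alternative is iterative Gaussian elimination with partial pivoting (choosing the largest absolute value in each column). The sketch below is not the author's code; the class name and the absolute 1e-12 singularity threshold are assumptions:

```java
import java.util.Arrays;

public class GaussianElimination {
    /**
     * Solves a * x = b by Gaussian elimination with partial pivoting.
     * Works on a copy, so the caller's arrays are not modified.
     * Returns null if a pivot falls below an absolute tolerance (treated as singular).
     */
    public static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        double[][] m = new double[n][];
        for (int i = 0; i < n; i++) {
            m[i] = Arrays.copyOf(a[i], n + 1); // augmented row [a_i | b_i]
            m[i][n] = b[i];
        }
        for (int col = 0; col < n; col++) {
            // Partial pivoting: pick the row with the largest |value| in this column
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(m[r][col]) > Math.abs(m[pivot][col])) {
                    pivot = r;
                }
            }
            if (Math.abs(m[pivot][col]) < 1e-12) {
                return null; // singular or nearly singular (absolute tolerance, an assumption)
            }
            double[] tmp = m[col];
            m[col] = m[pivot];
            m[pivot] = tmp;
            // Eliminate this column from every row below the pivot row
            for (int r = col + 1; r < n; r++) {
                double factor = m[r][col] / m[col][col];
                for (int c = col; c <= n; c++) {
                    m[r][c] -= factor * m[col][c];
                }
            }
        }
        // Back substitution on the upper-triangular system
        double[] x = new double[n];
        for (int r = n - 1; r >= 0; r--) {
            double sum = m[r][n];
            for (int c = r + 1; c < n; c++) {
                sum -= m[r][c] * x[c];
            }
            x[r] = sum / m[r][r];
        }
        return x;
    }

    public static void main(String[] args) {
        // 2x + y = 5 and x + 3y = 10 have the solution x = 1, y = 3
        double[] x = solve(new double[][]{{2, 1}, {1, 3}}, new double[]{5, 10});
        System.out.println(Arrays.toString(x)); // prints [1.0, 3.0]
    }
}
```

Partial pivoting keeps every elimination factor at most 1 in absolute value, which is what limits error growth compared with taking the first non-zero pivot.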

Java code:

public class JavaLeastSquare {
    private double[] x;
    private double[] y;
    private double[] weight;
    // Number of coefficients to fit (polynomial degree + 1); n = 2 fits a straight line
    private int n;
    private double[] coefficient;

    /**
     * Constructor method (every sample gets weight 1).
     * @param x Array of x
     * @param y Array of y
     * @param n Number of coefficients, i.e. polynomial degree + 1 (must be >= 2)
     */
    public JavaLeastSquare(double[] x, double[] y, int n) {
        if (x == null || y == null || x.length < 2 || x.length != y.length
                || n < 2) {
            throw new IllegalArgumentException(
                    "x and y must be non-null arrays of equal length >= 2, and n must be >= 2");
        }
        this.x = x;
        this.y = y;
        this.n = n;
        weight = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            weight[i] = 1;
        }
        compute();
    }

    /**
     * Constructor method.
     * @param x      Array of x
     * @param y      Array of y
     * @param weight Array of weights, one per sample
     * @param n      Number of coefficients, i.e. polynomial degree + 1 (must be >= 2)
     */
    public JavaLeastSquare(double[] x, double[] y, double[] weight, int n) {
        if (x == null || y == null || weight == null || x.length < 2
                || x.length != y.length || x.length != weight.length || n < 2) {
            throw new IllegalArgumentException(
                    "x, y and weight must be non-null arrays of equal length >= 2, and n must be >= 2");
        }
        this.x = x;
        this.y = y;
        this.n = n;
        this.weight = weight;
        compute();
    }

    /**
     * Get coefficient of polynomial.
     * @return coefficients in ascending order of power (coefficient[i] multiplies x^i),
     *         or null if the system could not be solved
     */
    public double[] getCoefficient() {
        return coefficient;
    }

    /**
     * Evaluate the fitted polynomial at the given x.
     * @param x x
     * @return y, the sum of coefficient[i] * x^i
     */
    public double fit(double x) {
        if (coefficient == null) {
            return 0;
        }
        double sum = 0;
        for (int i = 0; i < coefficient.length; i++) {
            sum += Math.pow(x, i) * coefficient[i];
        }
        return sum;
    }

    /**
     * Use Newton's method to find x such that fit(x) = y, starting from x = 1.
     * @param y y
     * @return x
     */
    public double solve(double y) {
        return solve(y, 1.0d);
    }

    /**
     * Use Newton's method to find x such that fit(x) = y.
     * @param y      y
     * @param startX The start point of x
     * @return x, or NaN if the iteration reaches a point where the derivative is zero
     */
    public double solve(double y, double startX) {
        final double EPS = 0.0000001d;
        final int MAX_ITERATIONS = 1000;
        if (coefficient == null) {
            return 0;
        }
        double x1;
        double x2 = startX;
        int iterations = 0;
        do {
            x1 = x2;
            double derivative = calcDerivative(x1);
            if (derivative == 0) {
                return Double.NaN; // Newton's method cannot step from a flat point
            }
            x2 = x1 - (fit(x1) - y) / derivative;
        } while (Math.abs(x1 - x2) > EPS && ++iterations < MAX_ITERATIONS);
        return x2;
    }

    /*
     * Calculate the first derivative of the fitted polynomial at x.
     * @param x x
     * @return f'(x), the sum of i * coefficient[i] * x^(i-1)
     */
    private double calcDerivative(double x) {
        if (coefficient == null) {
            return 0;
        }
        double sum = 0;
        for (int i = 1; i < coefficient.length; i++) {
            sum += i * Math.pow(x, i - 1) * coefficient[i];
        }
        return sum;
    }

    /*
     * Build the normal equations G * a = B, with G[i][j] = sum of w * x^(i+j)
     * and B[i] = sum of w * x^i * y, then solve them for the coefficients.
     */
    private void compute() {
        if (x == null || y == null || x.length <= 1 || x.length != y.length
                || x.length < n || n < 2) {
            return;
        }
        // s[k] = sum over all samples of weight * x^k, for k = 0 .. 2(n - 1)
        double[] s = new double[(n - 1) * 2 + 1];
        for (int i = 0; i < s.length; i++) {
            for (int j = 0; j < x.length; j++) {
                s[i] += Math.pow(x[j], i) * weight[j];
            }
        }
        // b[i] = sum over all samples of weight * x^i * y
        double[] b = new double[n];
        for (int i = 0; i < b.length; i++) {
            for (int j = 0; j < x.length; j++) {
                b[i] += Math.pow(x[j], i) * y[j] * weight[j];
            }
        }
        // G is a Hankel matrix: each entry depends only on i + j
        double[][] a = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = s[i + j];
            }
        }

        // Solve G * a = B for the polynomial coefficients
        coefficient = calcLinearEquation(a, b);
    }

    /*
     * Solve the linear system a * x = b by recursive elimination.
     * @param a two-dimensional array (square matrix)
     * @param b one-dimensional array
     * @return x, one-dimensional array, or null if the system is singular
     */
    private double[] calcLinearEquation(double[][] a, double[] b) {
        if (a == null || b == null || a.length == 0 || a.length != b.length) {
            return null;
        }

        for (double[] row : a) {
            if (row == null || row.length != a.length) {
                return null;
            }
        }

        int len = a.length - 1;
        double[] result = new double[a.length];

        if (len == 0) {
            if (a[0][0] == 0.0d) {
                return null;
            }
            result[0] = b[0] / a[0][0];
            return result;
        }

        double[][] aa = new double[len][len];
        double[] bb = new double[len];
        // Use the first non-zero element (row posx, column posy) as the pivot
        int posx = -1, posy = -1;
        for (int i = 0; i <= len; i++) {
            for (int j = 0; j <= len; j++) {
                if (a[i][j] != 0.0d) {
                    posy = j;
                    break;
                }
            }
            if (posy != -1) {
                posx = i;
                break;
            }
        }
        if (posx == -1) {
            return null;
        }

        // Eliminate column posy from every other row by cross-multiplication,
        // giving a system with one fewer equation and one fewer unknown
        int count = 0;
        for (int i = 0; i <= len; i++) {
            if (i == posx) {
                continue;
            }
            bb[count] = b[i] * a[posx][posy] - b[posx] * a[i][posy];
            int count2 = 0;
            for (int j = 0; j <= len; j++) {
                if (j == posy) {
                    continue;
                }
                aa[count][count2] = a[i][j] * a[posx][posy] - a[posx][j] * a[i][posy];
                count2++;
            }
            count++;
        }

        // Solve the reduced linear system recursively
        double[] result2 = calcLinearEquation(aa, bb);
        if (result2 == null) {
            return null;
        }

        // Back-substitute into the pivot row to obtain the eliminated unknown
        double sum = b[posx];
        count = 0;
        for (int i = 0; i <= len; i++) {
            if (i == posy) {
                continue;
            }
            sum -= a[posx][i] * result2[count];
            result[i] = result2[count];
            count++;
        }
        result[posy] = sum / a[posx][posy];
        return result;
    }

    public static void main(String[] args) {
        // Samples from y = 0.7x + 20, with noise added to the first seven points
        JavaLeastSquare leastSquare = new JavaLeastSquare(
                new double[]{
                        2, 14, 20, 25, 26, 34,
                        47, 87, 165, 265, 365, 465,
                        565, 665
                },
                new double[]{
                        0.7 * 2 + 20 + 0.4,
                        0.7 * 14 + 20 + 0.5,
                        0.7 * 20 + 20 + 3.4,
                        0.7 * 25 + 20 + 5.8,
                        0.7 * 26 + 20 + 8.27,
                        0.7 * 34 + 20 + 0.4,

                        0.7 * 47 + 20 + 0.1,
                        0.7 * 87 + 20,
                        0.7 * 165 + 20,
                        0.7 * 265 + 20,
                        0.7 * 365 + 20,
                        0.7 * 465 + 20,

                        0.7 * 565 + 20,
                        0.7 * 665 + 20
                },
                2);

        double[] coefficients = leastSquare.getCoefficient();
        for (double c : coefficients) {
            System.out.println(c);
        }

        // Test: evaluate the fitted line at x = 4
        System.out.println(leastSquare.fit(4));
    }
}

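Output (the two coefficients in ascending order of power, i.e. the intercept b and then the slope a, followed by fit(4)):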

22.27966881467629
0.6952475907448203
25.06065917765557

Process finished with exit code 0
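Since the project only needs y = ax + b (two coefficients), the 2 x 2 normal equations can also be solved in closed form by Cramer's rule, which is handy for cross-checking the output above. A minimal, dependency-free sketch; the class and method names are illustrative. Unlike getCoefficient(), it returns {a, b} with the slope first:

```java
public class SimpleLinearFit {
    /** Returns {a, b} for y = a*x + b, minimizing the sum of squared vertical residuals. */
    public static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two (x, y) pairs of equal length");
        }
        int m = x.length;
        double sx = 0, sy = 0, sxx = 0, sxy = 0;
        for (int i = 0; i < m; i++) {
            sx += x[i];
            sy += y[i];
            sxx += x[i] * x[i];
            sxy += x[i] * y[i];
        }
        // Determinant of [[m, sx], [sx, sxx]]; (near) zero when all x are equal,
        // only the exact-zero case is rejected here
        double det = m * sxx - sx * sx;
        if (det == 0) {
            throw new IllegalArgumentException("all x values are identical; slope is undefined");
        }
        double a = (m * sxy - sx * sy) / det;
        double b = (sy - a * sx) / m;
        return new double[]{a, b};
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4};
        double[] y = {3, 5, 7, 9}; // exactly y = 2x + 1
        double[] ab = fit(x, y);
        System.out.println("a = " + ab[0] + ", b = " + ab[1]); // prints a = 2.0, b = 1.0
    }
}
```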

Using an open-source library

You can also use the Apache Commons Math library (http://commons.apache.org/proper/commons-math/userguide/fitting.html), which offers more powerful fitting functionality:

<dependency>  
    <groupId>org.apache.commons</groupId>  
    <artifactId>commons-math3</artifactId>  
    <version>3.5</version>  
</dependency>  

Implementation:

import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;

public class WeightedObservedPointsTest {
    public static void main(String[] args) {
        final WeightedObservedPoints obs = new WeightedObservedPoints();
        obs.add(2, 0.7 * 2 + 20 + 0.4);
        obs.add(12, 0.7 * 12 + 20 + 0.3);
        obs.add(32, 0.7 * 32 + 20 + 3.4);
        obs.add(34, 0.7 * 34 + 20 + 5.8);
        obs.add(58, 0.7 * 58 + 20 + 8.4);
        obs.add(43, 0.7 * 43 + 20 + 0.28);
        obs.add(27, 0.7 * 27 + 20 + 0.4);

        // Instantiate a second-degree (quadratic) polynomial fitter;
        // PolynomialCurveFitter.create(1) would fit a straight line y = ax + b instead.
        final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(2);

        // Retrieve fitted parameters (coefficients of the polynomial function).
        final double[] coeff = fitter.fit(obs.toList());
        for (double c : coeff) {
            System.out.println(c);
        }
    }
}

Test output (coefficients in ascending order of power: constant, linear, quadratic):

20.47425047847121
0.6749744063035112
0.002523043547711147

Process finished with exit code 0

Implementing least squares with org.ujmp (a matrix library):

Add org.ujmp to pom.xml:

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.dtgroup</groupId>
    <artifactId>dtgroup</artifactId>
    <version>0.0.1-SNAPSHOT</version>

    <repositories>
        <repository>
            <id>limaven</id>
            <name>aliyun maven</name>
            <url>http://maven.aliyun.com/nexus/content/groups/public/</url>
            <layout>default</layout>
            <releases>
                <enabled>true</enabled>
            </releases>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
    <dependencies>
        <dependency>
            <groupId>org.ujmp</groupId>
            <artifactId>ujmp-core</artifactId>
            <version>0.3.0</version>
        </dependency>
    </dependencies>
</project>

Java code:

    /**
     * Fit the polynomial coefficients with the least-squares method and print the intermediate matrices.
     * @param sampleCount  number of sample points
     * @param featureCount number of polynomial coefficients (2 for y = ax + b; only column 1 of X
     *                     is filled with x, so other values are not supported as written)
     * @param samples      sample points
     * **/
    private static void leastsequare(int sampleCount, int featureCount, List<Sample> samples) {
        // Design matrix X (sampleCount x featureCount), initialized to 1.0;
        // column 0 stays 1 for the intercept, column 1 receives x
        Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, featureCount);

        for (int i = 0; i < samples.size(); i++) {
            matrixX.setAsDouble(samples.get(i).getX(), i, 1);
        }

        System.out.println("--------------------------------------");
        // Column vector Y (sampleCount x 1)
        Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);

        for (int i = 0; i < samples.size(); i++) {
            matrixY.setAsDouble(samples.get(i).getY(), i, 0);
        }

        // Transpose of X
        Matrix matrixXTrans = matrixX.transpose();

        // Product X^T * X
        Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);
        System.out.println(matrixMtimes);

        System.out.println("--------------------------------------");
        // Inverse (X^T * X)^-1
        Matrix matrixMtimesInv = matrixMtimes.inv();
        System.out.println(matrixMtimesInv);

        // (X^T * X)^-1 * X^T
        System.out.println("--------------------------------------");
        Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans);
        System.out.println(matrixMtimesInvMtimes);

        System.out.println("--------------------------------------");
        // theta = (X^T * X)^-1 * X^T * Y: theta(0,0) is the intercept b, theta(1,0) is the slope a
        Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY);
        System.out.println(theta);
    }
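For reference, leastsequare evaluates the textbook closed form of the normal equations. With a design matrix whose first column is all ones (intercept) and whose second column holds the x values, the quantities it computes are:

```latex
X = \begin{pmatrix} 1 & x_1 \\ 1 & x_2 \\ \vdots & \vdots \\ 1 & x_m \end{pmatrix},\quad
Y = \begin{pmatrix} y_1 \\ y_2 \\ \vdots \\ y_m \end{pmatrix},\quad
\theta = \begin{pmatrix} \theta_0 \\ \theta_1 \end{pmatrix} = (X^{\mathsf T} X)^{-1} X^{\mathsf T} Y

X^{\mathsf T} X = \begin{pmatrix} m & \sum x_i \\ \sum x_i & \sum x_i^2 \end{pmatrix},\qquad
X^{\mathsf T} Y = \begin{pmatrix} \sum y_i \\ \sum x_i y_i \end{pmatrix}
```

So matrixMtimes is X^T X, matrixMtimesInv is its inverse, θ0 is the intercept b and θ1 is the slope a. X^T X is exactly the matrix G of Ga = B for a straight line with unit weights. Solving X^T X θ = X^T Y directly (for example by Gaussian elimination) instead of forming the inverse is generally the more numerically robust choice.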

Test code:

    public static void main(String[] args) {
        // Model: y = a*x + b, with expected ranges
        // a in (0, 1], b in [5, 20], x in [0, 500], y >= 5

        // Samples are generated from y = 0.8d*x + 15 plus small noise; Sample(y, x) takes y first.
        // When the same x has several y values, they should be averaged.
        List<Sample> samples = new ArrayList<Sample>();
        samples.add(new Sample(0.8d * 1 + 15 + 1, 1d));
        samples.add(new Sample(0.8d * 4 + 15 + 0.8, 4d));
        samples.add(new Sample(0.8d * 3 + 15 + 0.7, 3d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        samples.add(new Sample(0.8d * 5 + 15 + 0.3, 5d));
        samples.add(new Sample(0.8d * 10 + 15 + 0.4, 10d));
        samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        samples.add(new Sample(0.8d * 7 + 15 + 0.3, 7d));
        // Outlier: y is computed for x = 1000 (intercept 23) but stored at x = 70
        samples.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d));

        int sampleCount = samples.size();
        int featureCount = 2;

        leastsequare(sampleCount, featureCount, samples);
    }

Filtering noise points out of the samples:

    public static void main(String[] args) {
        // Model: y = a*x + b, with expected ranges
        // a in (0, 1], b in [5, 20], x in [0, 500], y >= 5

        // Samples are generated from y = 0.8d*x + 15 plus small noise; Sample(y, x) takes y first.
        // When the same x has several y values, they should be averaged.
        List<Sample> samples = new ArrayList<Sample>();
        samples.add(new Sample(0.8d * 1 + 15 + 1, 1d));
        samples.add(new Sample(0.8d * 4 + 15 + 0.8, 4d));
        samples.add(new Sample(0.8d * 3 + 15 + 0.7, 3d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        samples.add(new Sample(0.8d * 5 + 15 + 0.3, 5d));
        samples.add(new Sample(0.8d * 10 + 15 + 0.4, 10d));
        samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        samples.add(new Sample(0.8d * 7 + 15 + 0.3, 7d));
        // Outlier: y is computed for x = 1000 (intercept 23) but stored at x = 70
        samples.add(new Sample(0.8d * 1000 + 23 + 0.3, 70d));

        // samples = filterSample(samples);
        sortSample(samples);
        FilterSampleByGradientResult result = filterSampleByGradient(0, samples);

        // Repeat until a full pass removes nothing
        while (!result.isComplete()) {
            List<Sample> newSamples = result.getSamples();
            sortSample(newSamples);
            result = filterSampleByGradient(result.getIndex(), newSamples);
        }
        samples = result.getSamples();

        for (Sample sample : samples) {
            System.out.println(sample);
        }

        int sampleCount = samples.size();
        int featureCount = 2;

        leastsequare(sampleCount, featureCount, samples);
    }

    /**
     * Sort the sample points by x in ascending order.
     * @param samples sample points
     * **/
    private static void sortSample(List<Sample> samples) {
        // comparingDouble returns 0 for equal x and so honours the Comparator contract;
        // always returning -1 for equal keys can make List.sort throw IllegalArgumentException
        samples.sort(Comparator.comparingDouble(Sample::getX));
    }

    /**
     * Filter noise points out of the samples. Rule: for neighbouring points compute the slope
     * theta = (y2 - y1) / (x2 - x1) and filter according to its value.
     * @param index index at which the previous call stopped
     * @param samples sample points (noise points are removed from this list)
     * **/
    private static FilterSampleByGradientResult filterSampleByGradient(int index, List<Sample> samples) {
        int sampleSize = samples.size();
        for (int i = index; i < sampleSize - 1; i++) {
            double delta_x = samples.get(i).getX() - samples.get(i + 1).getX();
            double delta_y = samples.get(i).getY() - samples.get(i + 1).getY();
            // The two x values are less than 1 apart: merge the points, averaging y
            if (Math.abs(delta_x) < 1) {
                double newY = (samples.get(i).getY() + samples.get(i + 1).getY()) / 2;
                double newX = samples.get(i).getX();

                // Remove i + 1 before i so that both original points are removed
                samples.remove(i + 1);
                samples.remove(i);
                // Sample(y, x); the caller re-sorts the list before the next pass
                samples.add(new Sample(newY, newX));

                return new FilterSampleByGradientResult(false, i, samples);
            } else {
                double gradient = delta_y / delta_x;
                if (gradient > 1.5) {
                    if (i == 0) {
                        // Unclear which of the first two points is the noise: keep both and
                        // continue, since returning without removing anything would loop forever
                        continue;
                    }
                    samples.remove(i + 1);
                    return new FilterSampleByGradientResult(false, i, samples);
                }
            }
        }

        return new FilterSampleByGradientResult(true, 0, samples);
    }
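The restart-based loop above (filter, return, re-sort, call again) can also be written as a single pass over the sorted points. A simplified sketch under stated assumptions: a hypothetical Point class stands in for the article's Sample class (which is not shown), the slope is measured from the last kept point, only steep positive slopes are rejected (as with gradient > 1.5 above), and the first point is always kept:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class SlopeFilter {
    /** Hypothetical minimal (x, y) point for this sketch. */
    public static final class Point {
        public final double x;
        public final double y;

        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public String toString() {
            return "(" + x + ", " + y + ")";
        }
    }

    /**
     * One pass over the points sorted by x. A point whose x is within minDx of the last kept
     * point is merged into it (pairwise average of y). Otherwise it is kept only if the slope
     * from the last kept point is at most maxSlope.
     */
    public static List<Point> filter(List<Point> input, double maxSlope, double minDx) {
        List<Point> sorted = new ArrayList<>(input);
        sorted.sort(Comparator.comparingDouble((Point p) -> p.x));
        List<Point> kept = new ArrayList<>();
        for (Point p : sorted) {
            if (kept.isEmpty()) {
                kept.add(p); // the first point is always kept
                continue;
            }
            int lastIdx = kept.size() - 1;
            Point last = kept.get(lastIdx);
            double dx = p.x - last.x; // >= 0 because the list is sorted
            if (dx < minDx) {
                kept.set(lastIdx, new Point(last.x, (last.y + p.y) / 2));
            } else if ((p.y - last.y) / dx <= maxSlope) {
                kept.add(p);
            }
            // otherwise the slope is too steep: p is treated as noise and skipped
        }
        return kept;
    }

    public static void main(String[] args) {
        List<Point> pts = new ArrayList<>();
        for (double x : new double[]{1, 4, 3, 24, 5, 10, 14, 7}) {
            pts.add(new Point(x, 0.8 * x + 15));
        }
        pts.add(new Point(70, 0.8 * 1000 + 23.3)); // the same outlier as in the article's data
        System.out.println(filter(pts, 1.5, 1.0)); // the x = 70 point is dropped
    }
}
```

This costs one sort plus one linear pass, instead of re-sorting and rescanning after every removal.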

Filtering by distance:

    private static List<Sample> filterSample(List<Sample> samples) {
        // x = {x1, x2, x3 ... xn}
        // u = E(x): the expectation (mean) of x
        // σ = sqrt(((x1-u)^2 + (x2-u)^2 + ... + (xn-u)^2) / n): the standard deviation of x
        // Noise rule: if xi is outside (u - 3σ, u + 3σ), treat it as noise.

        // Alternative: apply the rule above to both x and y, then combine the two results
        // with AND or OR to decide whether to drop a point.
        // Approach used here: treat each sample as a point, compute a center point (the mean),
        // and filter on the distance from every point to that center.
        int sampleCount = samples.size();
        double sumX = 0d;
        double sumY = 0d;

        for (Sample sample : samples) {
            sumX += sample.getX();
            sumY += sample.getY();
        }

        // Center point: mean of x and mean of y
        double centerX = (sumX / sampleCount);
        double centerY = (sumY / sampleCount);

        List<Double> distanItems = new ArrayList<Double>();
        // Distance from every point to the center point
        for (int i = 0; i < samples.size(); i++) {
            Sample sample = samples.get(i);
            double xyPow2 = Math.pow(sample.getX() - centerX, 2) + Math.pow(sample.getY() - centerY, 2);
            distanItems.add(Math.sqrt(xyPow2));
        }

        // Filter on these distances (each point's distance to the center point)
        double sumDistan = 0d;
        double distanceU = 0d;
        for (Double distance : distanItems) {
            sumDistan += distance;
        }
        distanceU = sumDistan / sampleCount;

        double deltaPowSum = 0d;
        double distanceTheta = 0d;
        // sqrt((d1-u)^2 + ... + (dn-u)^2): without dividing by n this equals σ * sqrt(n), not σ
        for (Double distance : distanItems) {
            deltaPowSum += Math.pow((distance - distanceU), 2);
        }
        distanceTheta = Math.sqrt(deltaPowSum);

        // Empirical cutoff used instead of the 3σ rule above:
        // keep distances strictly inside (u - 0.5 * distanceTheta, u + 0.5 * distanceTheta)
        double minDistance = distanceU - 0.5 * distanceTheta;
        double maxDistance = distanceU + 0.5 * distanceTheta;
        List<Integer> willbeRemoveIdxs = new ArrayList<Integer>();
        // Collect indices from the end so that removing them in order does not shift pending ones
        for (int i = distanItems.size() - 1; i >= 0; i--) {
            Double distance = distanItems.get(i);
            if (distance <= minDistance || distance >= maxDistance) {
                willbeRemoveIdxs.add(i);
                System.out.println("will be removed: index " + i);
            }
        }

        for (int willbeRemoveIdx : willbeRemoveIdxs) {
            samples.remove(willbeRemoveIdx);
        }

        return samples;
    }
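For comparison, here is a textbook k-sigma filter in which σ is the population standard deviation (dividing by n inside the square root). A minimal sketch; the class name and the choice to flag only values strictly more than k·σ from the mean are assumptions:

```java
import java.util.ArrayList;
import java.util.List;

public class SigmaFilter {
    /**
     * Returns the indices of values lying strictly more than k standard deviations
     * from the mean, with sigma = sqrt(sum((v - mean)^2) / n).
     */
    public static List<Integer> outlierIndices(double[] values, double k) {
        int n = values.length;
        List<Integer> outliers = new ArrayList<>();
        if (n == 0) {
            return outliers;
        }
        double mean = 0;
        for (double v : values) {
            mean += v;
        }
        mean /= n;
        double sq = 0;
        for (double v : values) {
            sq += (v - mean) * (v - mean);
        }
        double sigma = Math.sqrt(sq / n); // dividing by n is what makes this the standard deviation
        for (int i = 0; i < n; i++) {
            if (Math.abs(values[i] - mean) > k * sigma) {
                outliers.add(i);
            }
        }
        return outliers;
    }

    public static void main(String[] args) {
        double[] distances = {1, 1, 1, 1, 10};
        System.out.println(outlierIndices(distances, 1.5)); // prints [4]
        System.out.println(outlierIndices(distances, 3.0)); // prints []
    }
}
```

With k = 3 nothing is flagged here: among n values no point can lie more than σ·sqrt(n - 1) from the mean (Samuelson's inequality), so a plain 3σ rule can never fire for n <= 10. That may be one reason the article's code uses small empirical multipliers on an un-normalized spread instead.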

Test with real business data:

package com.zjanalyse.spark.maths;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

public class LastSquare {
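    /**
     * Model y = ax + b with a in (0, 1], b in [5, 20], x in [0, 500], y >= 5.
     */
    public static void main(String[] args) {
        // Samples are generated from y = 0.8d*x + 15 plus noise; Sample(y, x) takes y first.
        // When the same x has several y values, they should be averaged.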
        List<Sample> samples = new ArrayList<Sample>();
        samples.add(new Sample(0.8d * 11 + 15 + 1, 11d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.8, 24d));
        samples.add(new Sample(0.8d * 33 + 15 + 0.7, 33d));
        samples.add(new Sample(0.8d * 24 + 15 + 0.4, 24d));
        samples.add(new Sample(0.8d * 47 + 15 + 0.3, 47d));
        samples.add(new Sample(0.8d * 60 + 15 + 0.4, 60d));
        samples.add(new Sample(0.8d * 14 + 15 + 0.2, 14d));
        samples.add(new Sample(0.8d * 57 + 15 + 0.3, 57d));
        samples.add(new Sample(0.8d * 70 + 60 + 0.3, 70d));
        samples.add(new Sample(0.8d * 80 + 60 + 0.3, 80d));
        samples.add(new Sample(0.8d * 40 + 30 + 0.3, 40d));

        sortSample(samples);
        System.out.println("Original sample data");
        for (Sample sample : samples) {
            System.out.println(sample);
        }

        System.out.println("Start: removing points outside the business value range:");
        // Filter by business rules
        filterByBusiness(samples);
        System.out.println("Done: removing points outside the business value range:");

        for (Sample sample : samples) {
            System.out.println(sample);
        }

        int sampleCount = samples.size();
        int fetureCout = 2;
        System.out.println("First fit...");
        Matrix theta = leastsequare(sampleCount, fetureCout, samples);

        // theta(0,0) is the intercept b, theta(1,0) is the slope a
        double wear_loss = theta.getAsDouble(0, 0);
        double path_loss = theta.getAsDouble(1, 0);

        System.out.println("wear loss " + wear_loss);
        System.out.println("path loss " + path_loss);

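        System.out.println("Start: removing points by the spread of their distances to the first fitted line:");
        samples = filterSample(wear_loss, path_loss, samples);
        System.out.println("Done: removing points by the spread of their distances to the first fitted line:");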

        for (Sample sample : samples) {
            System.out.println(sample);
        }

        System.out.println("Second fit...");
        sampleCount = samples.size();
        fetureCout = 2;

        if (sampleCount >= 2) {
            theta = leastsequare(sampleCount, fetureCout, samples);

            wear_loss = theta.getAsDouble(0, 0);
            path_loss = theta.getAsDouble(1, 0);

            System.out.println("wear loss " + wear_loss);
            System.out.println("path loss " + path_loss);
        }
        System.out.println("complete...");
    }

    /**
     * Keep only samples inside the valid business value range
     */
    private static void filterByBusiness(List<Sample> samples) {
        for (int i = 0; i < samples.size(); i++) {
            double x = samples.get(i).getX();
            double y = samples.get(i).getY();
            if (x >= 500) {
                System.out.println(x + " x is outside the valid range [0,500)");
                samples.remove(i);
                i--;
            }
            // y= 0.8d*x+15
            else if (y < 0 * x + 5 || y > 1 * x + 30) {
                System.out.println(
                        y + " y is outside the valid range [(0*x+5), (1*x+30)] where x=" + x + ", i.e. [" + (0 * x + 5) + "," + (1 * x + 30) + "]");
                samples.remove(i);
                i--;
            }
        }
    }

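    /**
     * Description: distance from a point to a line
     * 
     * @param x1
     *            x coordinate of the point
     * @param y1
     *            y coordinate of the point
     * @param A
     *            coefficient A of the line in general form Ax + By + C = 0
     * @param B
     *            coefficient B of the line in general form
     * @param C
     *            coefficient C of the line in general form
     * @return distance from the point to the line
     * @see Distance from the point (0, 1) to the line y = x, i.e. -x + y + 0 = 0: <br>
     *      double distance = getDistanceOfPerpendicular(0, 1, -1, 1, 0);<br>
     *      System.out.println(distance); // 1 / sqrt(2), about 0.7071<br>
     */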
    private static double getDistanceOfPerpendicular(double x1, double y1, double A, double B, double C) {
        double distance = Math.abs((A * x1 + B * y1 + C) / Math.sqrt(A * A + B * B));
        return distance;
    }

    private static List<Sample> filterSample(double wear_loss, double path_loss, List<Sample> samples) {
        // x = {x1, x2, x3 ... xn}
        // u = E(x): the expectation (mean) of x
        // σ = sqrt(((x1-u)^2 + (x2-u)^2 + ... + (xn-u)^2) / n): the standard deviation of x
        // Noise rule: if xi is outside (u - 3σ, u + 3σ), treat it as noise.

        // Compute every point's distance to the line from the first fit
        int sampleCount = samples.size();
        List<Double> distanItems = new ArrayList<Double>();
        for (int i = 0; i < samples.size(); i++) {
            Sample sample = samples.get(i);
            // Line y = path_loss * x + wear_loss, written as path_loss * x - y + wear_loss = 0
            double distance = getDistanceOfPerpendicular(sample.getX(), sample.getY(), path_loss, -1, wear_loss);
            // Note: the square root of the distance is stored, which compresses large distances
            distanItems.add(Math.sqrt(distance));
        }

        // Filter on these values (derived from each point's distance to the first fitted line)
        double sumDistan = 0d;
        double distanceU = 0d;
        for (Double distance : distanItems) {
            sumDistan += distance;
        }
        distanceU = sumDistan / sampleCount;

        double deltaPowSum = 0d;
        double distanceTheta = 0d;
        // sqrt((d1-u)^2 + ... + (dn-u)^2): without dividing by n this equals σ * sqrt(n), not σ
        for (Double distance : distanItems) {
            deltaPowSum += Math.pow((distance - distanceU), 2);
        }
        distanceTheta = Math.sqrt(deltaPowSum);

        // Empirical cutoff used instead of the 3σ rule above:
        // keep values strictly inside (u - 0.25 * distanceTheta, u + 0.25 * distanceTheta)
        double minDistance = distanceU - 0.25 * distanceTheta;
        double maxDistance = distanceU + 0.25 * distanceTheta;
        List<Integer> willbeRemoveIdxs = new ArrayList<Integer>();

        for (int i = distanItems.size() - 1; i >= 0; i--) {
            Double distance = distanItems.get(i);
            if (distance <= minDistance || distance >= maxDistance) {
                System.out.println(distance + " out of range [" + minDistance + "," + maxDistance + "]");
                willbeRemoveIdxs.add(i);
            } else {
                System.out.println(distance);
            }
        }

        for (int willbeRemoveIdx : willbeRemoveIdxs) {
            Sample sample = samples.get(willbeRemoveIdx);
            System.out.println("remove " + sample);
            samples.remove(willbeRemoveIdx);
        }

        return samples;
    }

    /**
     * Sort the sample points by x in ascending order.
     * 
     * @param samples
     *            sample points
     **/
    private static void sortSample(List<Sample> samples) {
        // comparingDouble returns 0 for equal x and so honours the Comparator contract;
        // always returning -1 for equal keys can make List.sort throw IllegalArgumentException
        samples.sort(Comparator.comparingDouble(Sample::getX));
    }

    /**
     * Description: fit the polynomial coefficients with the least-squares method.
     * 
     * @param sampleCount
     *            number of sample points
     * @param featureCount
     *            number of polynomial coefficients (2 for y = ax + b; only column 1 of X
     *            is filled with x, so other values are not supported as written)
     * @param samples
     *            sample points
     * @return theta: theta(0,0) is the intercept b, theta(1,0) is the slope a
     **/
    private static Matrix leastsequare(int sampleCount, int featureCount, List<Sample> samples) {
        // Design matrix X (sampleCount x featureCount), initialized to 1.0;
        // column 0 stays 1 for the intercept, column 1 receives x
        Matrix matrixX = DenseMatrix.Factory.ones(sampleCount, featureCount);

        for (int i = 0; i < samples.size(); i++) {
            matrixX.setAsDouble(samples.get(i).getX(), i, 1);
        }

        // Column vector Y (sampleCount x 1)
        Matrix matrixY = DenseMatrix.Factory.ones(sampleCount, 1);

        for (int i = 0; i < samples.size(); i++) {
            matrixY.setAsDouble(samples.get(i).getY(), i, 0);
        }

        // Transpose of X
        Matrix matrixXTrans = matrixX.transpose();

        // Product X^T * X
        Matrix matrixMtimes = matrixXTrans.mtimes(matrixX);

        // Inverse (X^T * X)^-1
        Matrix matrixMtimesInv = matrixMtimes.inv();

        // (X^T * X)^-1 * X^T
        Matrix matrixMtimesInvMtimes = matrixMtimesInv.mtimes(matrixXTrans);

        // theta = (X^T * X)^-1 * X^T * Y
        Matrix theta = matrixMtimesInvMtimes.mtimes(matrixY);

        return theta;
    }
}
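The LastSquare flow (business-range filter, first fit, distance-based filter, second fit) can be condensed into one dependency-free sketch. Assumptions that differ from the code above: points are plain double[]{x, y} pairs, σ is the population standard deviation, only points farther than mean + k·σ from the first line are dropped (the article also drops points that are unusually close), and no square root is applied to the distances. The class and method names are hypothetical:

```java
import java.util.ArrayList;
import java.util.List;

public class TwoPassLineFit {
    /** Closed-form least-squares line y = a*x + b; returns {a, b}. Needs two distinct x values. */
    public static double[] fitLine(List<double[]> pts) {
        int m = pts.size();
        double sx = 0, sy = 0, sxx = 0, sxy = 0;
        for (double[] p : pts) {
            sx += p[0];
            sy += p[1];
            sxx += p[0] * p[0];
            sxy += p[0] * p[1];
        }
        double a = (m * sxy - sx * sy) / (m * sxx - sx * sx);
        return new double[]{a, (sy - a * sx) / m};
    }

    /** Perpendicular distance from (x, y) to the line a*x - y + b = 0. */
    public static double distance(double x, double y, double a, double b) {
        return Math.abs(a * x - y + b) / Math.sqrt(a * a + 1);
    }

    /** Range filter, first fit, drop points with distance > mean + k*sigma, second fit. */
    public static double[] fitWithOutlierRemoval(List<double[]> input, double k) {
        List<double[]> pts = new ArrayList<>(input); // work on a copy
        // Same bounds as filterByBusiness: x < 500 and 5 <= y <= x + 30
        pts.removeIf(p -> p[0] >= 500 || p[1] < 5 || p[1] > p[0] + 30);
        if (pts.size() < 2) {
            throw new IllegalArgumentException("fewer than two points inside the business range");
        }
        double[] first = fitLine(pts);

        int n = pts.size();
        double[] d = new double[n];
        double mean = 0;
        for (int i = 0; i < n; i++) {
            d[i] = distance(pts.get(i)[0], pts.get(i)[1], first[0], first[1]);
            mean += d[i];
        }
        mean /= n;
        double sq = 0;
        for (double v : d) {
            sq += (v - mean) * (v - mean);
        }
        double sigma = Math.sqrt(sq / n); // population standard deviation

        List<double[]> kept = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (d[i] <= mean + k * sigma) {
                kept.add(pts.get(i));
            }
        }
        // Fall back to the first fit if too few points survive
        return kept.size() >= 2 ? fitLine(kept) : first;
    }

    public static void main(String[] args) {
        List<double[]> pts = new ArrayList<>();
        for (int x = 10; x <= 100; x += 10) {
            pts.add(new double[]{x, 0.8 * x + 15}); // points on y = 0.8x + 15
        }
        pts.add(new double[]{50, 70});  // inside the business range but far from the line
        pts.add(new double[]{600, 40}); // x >= 500: removed by the range filter
        double[] ab = fitWithOutlierRemoval(pts, 1.0);
        System.out.println("a = " + ab[0] + ", b = " + ab[1]); // close to a = 0.8, b = 15
    }
}
```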

Disclaimer: the article "最小二乘法多项式拟合的Java实现" (Java Implementation of Least-Squares Polynomial Fitting) is reposted for learning and reference only.
