tensorflow 2.0 学习(二)线性回归问题

摘要:
主要的数学代码可以理解,但只有参数在梯度的相反方向上更新,这不是很容易理解!这里不使用Tensorflow。下次,更新基础知识!

线性回归问题

 1 # encoding: utf-8
 2 
 3 import numpy as np
 4 import matplotlib.pyplot as plt
 5 
 6 data = []
 7 for i in range(100):
 8     x = np.random.uniform(-10., 10.)    #均匀分布产生x
 9     eps = np.random.normal(0., 0.01)    #高斯分布产生一个误差值
10     y = 1.477*x + 0.089 +eps            #计算得到y值
11     data.append([x, y])                 #保存到data中
12 
13 data = np.array(data)                   #转成数组的方式方便处理
14 plt.plot(data[:,0], data[:,1], 'b')    #自己引入用于观察原始数据
15 #plt.show()
16 plt.savefig('original data.png')
17 
18 
19 def mse(b, w, points):                  #计算所有点的预测值和真实值之间的均方误差
20     totalError = 0
21     for i in range(0, len(points)):
22         x = points[i, 0]
23         y = points[i, 1]
24         totalError += (y -(w*x + b))**2     #真实值减预测值的平方
25     return totalError/float(len(points))    #返回平均误差值
26 
27 
28 def step_gradient(b_current, w_current, points, lr):    #预测模型中梯度下降方式优化b和w
29     b_gradient = 0
30     w_gradient = 0
31     M = float(len(points))
32     for i in range(0, len(points)):
33         x = points[i, 0]
34         y = points[i, 1]
35         b_gradient += (2/M) * ((w_current*x + b_current) - y)   #求偏导数的公式可知
36         w_gradient += (2/M)*x*((w_current*x + b_current) - y)   #求偏导数的公式可知
37     new_b = b_current - (lr*b_gradient)      #更新参数,使用了梯度下降法
38     new_w = w_current - (lr*w_gradient)      #更新参数,使用了梯度下降法
39     return [new_b, new_w]
40 
41 
42 def gradient_descent(points, starting_b, starting_w, lr, num_iterations):   #循环更新w,b多次
43     b = starting_b
44     w = starting_w
45     loss_data = []
46     for step in range(num_iterations):  #计算并更新一次
47         b, w = step_gradient(b, w, np.array(points), lr)        #更新了这一次的b,w
48         loss = mse(b, w, points)
49         loss_data.append([step+1, loss])
50         if step % 50 == 0:              #每50次输出一回
51             print(f"iteration:{step}, loss{loss}, w:{w}, b:{b}")
52     return [b, w, loss_data]
53 
54 
55 def main():
56     lr = 0.01               #学习率,梯度下降算法中的参数
57     initial_b = 0           #初值
58     initial_w = 0
59     num_iterations = 1000   #学习100轮
60     [b, w, loss_data] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)
61     loss = mse(b, w, data)
62     print(f'Final loss:{loss}, w:{w}, b:{b}')
63 
64     plt.figure()            #观察loss每一步情况
65     loss_data = np.array(loss_data)
66     plt.plot(loss_data[:,0], loss_data[:,1], 'g')
67     plt.savefig('loss.png')
68     #plt.show()
69 
70     plt.figure()        #观察最终的拟合效果
71     y_fin = w*data[:,0] + b + eps
72     plt.plot(data[:,0], y_fin, 'r')
73     #plt.show()
74     plt.savefig('final data.png')
75 
76 
77 if __name__ == '__main__':
78     main()

original data (y = w*x + b +eps)

tensorflow 2.0 学习(二)线性回归问题第1张

loss rate

tensorflow 2.0 学习(二)线性回归问题第2张

final data (y' = w' *x + b' + eps )

tensorflow 2.0 学习(二)线性回归问题第3张

最终loss趋近9.17*10^-5, w趋近1.4768, b趋近0.0900

真实的w值1.477, b为0.089

对于线性回归问题,适用性挺好!

主要的数学代码能理解,唯有取梯度的反方向更新参数,不是很能理解!

tensorflow 2.0 学习(二)线性回归问题第4张

 这里还没有用到tensorflow,下一次更新基础知识!

免责声明:文章转载自《tensorflow 2.0 学习(二)线性回归问题》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇Delphi 回调函数Linux内存描述之内存页面page--Linux内存管理(四)下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

dom 绑定数据

一、绑定/修改 .jQuery修改属性值,都是在内存中进行的,并不会修改 DOM 1. 对象绑定$(selector).data(name) $("#form").data("name") 2. dom 绑定 $.data(element,name, val); jQuery.data($("#form")[0], "testing", 123); 3....

关于jQuery中的attr和data问题

今天在使用data获取属性并且赋值时遇到一个小问题,写下来防止以后再跳坑。 在使用jQuery获取自定义属性值时,我们习惯用 $(selector).attr('data-value'); jQuery赋值: $(selector).attr('data-value','123456'); 而data的取值: $(selector).data('value...

Elasticsearch集群角色类型node.master及node.data

在Elasticsearch当中,ES分为三种角色:master、data、client。 三种角色由elasticsearch.yml配置文件中的node.master、node.true来控制。 如果不修改elasticsearch的节点角色信息,那么默认就是node.master: true、node.data: true 默认情况下,es集群中的每...

AJAX全套

概述  对于WEB应用程序:用户浏览器发送请求,服务器接收并处理请求,然后返回结果,往往返回就是字符串(HTML),浏览器将字符串(HTML)渲染并显示浏览器上。 AJAX类似于偷偷像后台发送数据。 1、传统的Web应用 一个简单操作需要重新加载全局数据 2、AJAX AJAX,Asynchronous JavaScript and XML (异步的Ja...

微信小程序——data-*自定义属性

在jQuery的attr与prop提到过在IE9之前版本中如果使用property不当会造成内存泄露问题,而且关于Attribute和Property的区别也让人十分头痛,在HTML5中添加了data-*的方式来自定义属性,所谓data-*实际上上就是data-前缀加上自定义的属性名,使用这样的结构可以进行数据存放。使用data-*可以解决自定义属性混乱无...

TensorFlow 编程基础

1、TensorFlow   安装:https://www.cnblogs.com/pam-sh/p/12239387.html      https://www.cnblogs.com/pam-sh/p/12241942.html • 是一个开放源代码软件库,用于进行高性能数值计算• 借助其灵活的架构,用户可以轻松地将计算工作部署到多种平台(CPU、G...