TensorFlow入门(3):使用神经网络拟合N元一次方程

背景

前面一篇文章《TensorFlow入门:求N元一次方程》在已知表达式形式的情况下,获得了各个参数的值,但是现实中大部分情况是不能简单使用N元一次方程这样的公式表达的,神经网络的出现,给这类问题提供了一个很好的解决方法。本文继续给出一个简单的例子,使用TensorFlow,利用神经网络对N元一次方程进行拟合。

关于神经网络的简单入门介绍,可以参考这篇文章

如何实现

在使用TensorFlow之前,还是要import相关的包:

首先回顾一下前面的功能,我们有一个函数,它有5个输入值,一个输出值,这里使用param_count表示输入值的个数,当前它的值为5:

我们需要一个已知的函数来生成数据,根据函数y=x*w,w是这个函数的参数,令它为大小为[param_size,1]的矩阵,这里我随便填了5个0到1000的数字:

对于训练的输入和输出值,使用placeholder进行表示:

而之前的w,因为我们使用神经网络表示了,因此不需要了,我们甚至不需要知道这个函数一定是个N元一次方程。接下来就是重点部分,构造神经网络。TensorFlow提供了很多高级API,这个问题是一个回归问题,回归问题,就是通过一定的值,预测值的问题,这个和前篇的分类是不同的问题。我们使用tf.contrib.learn.DNNRegressor来构造神经网络,首先需要告诉它输入有哪些参数,叫做特征列,因为我们只有一个x输入,它是一个大小为[1,param_size]的矩阵,因此定义一个输入x:

使用上面构造的特征列构造一个DNNRegressor的实例,这里先把隐藏层hidden_units设置为[5,5],表示有2个隐藏层,每层有5个神经元,关于这个值怎么设置,学问很大,我暂时还说不清楚,未来了解后再补充。另一个参数model_dir指定学习的结果存放的路径,如果存在则读取,不存在则创建,因为训练神经网络一般比较耗时,因此尽量将结果保存下来,这样即便中途中断,也可以恢复,如果格式不一样,比如特征列或者隐藏层数量不一样,TensorFlow会报错。

然后就可以开始训练过程了,训练的过程可以每次生成一组新的训练数据,然后调用regressor.fit函数训练2000次,其中参数x表示输入值,参数y表示输出值。然后生成一组新的测试数据调用regressor.evaluate函数进行评估,当Loss函数小于一定值的时候停止训练:

不出意外的话,现在就可以开始训练了。最终训练的目的是为了给出指定的输入值,返回一个预测值,我们生成一组预测值,并且看看预测效果:

完整代码如下:

这样训练大约75W次后(使用Z3740(1.33GHz)大约需要1小时的时间),Loss函数会降低到10以内,得到的预测值和实际结果已经相差很小了。

但是可以看到loss函数并不是很稳定,可能突增或者突降,因为每次提供的训练数据太少了,我们可以通过提高x和y的大小来加快训练,同时提高训练效果,可以通过修改x和y矩阵大小来达到目的,修改后的代码如下:

同时跑上面2个训练,可以发现优化后的训练速度大大加快了,loss函数降低很迅速。由于loss函数是20个值的标准差,所以相应要提高一些。神经网络训练出来的结果不是一个[5,1]的矩阵,因此对于验证和预测输入,不能只是大小为[1,5]的矩阵,需要是大小为[20,5]的矩阵,所以在预测的时候,可以填充无效值,结果只取y的第一个值就好了。

在经过150W次训练之后,得到了比较准确的预测效果:

通过函数传入训练数据

TensorFlow还提供通过函数的方式传入输入数据,上面的例子是在while循环中将训练数据生成好传入,如果训练数据比较复杂或者不想将其与训练的代码耦合太大,可以将读取训练数据封装成一个函数传给fit、evaluate和predict。

这个函数需要返回2个值,第一个返回值是输入,它是一个字典,Key是特征列,Value是特征值,第二个返回值是输入对应的输出值,比如上面的例子,可以这样构造训练集:

传入到fit的方式是这样的:

完整代码如下:

在训练了210W次后,loss函数降低到了50以内:

因为预测需要提供20组数据,如果我们只需要预测一组怎么办呢?第一维可以只传入一组数据,其实第一维的数量是可以变化的:

这样预测出来的结果如下:

实际值是30460,预测值是30488.938,可见预测还是挺准确的。

参考资料

 

anyShare分享到:

原文地址:http://godmoon.wicp.net/blog/index.php/post_489.html,转载请注明出处

Moon发表于2017年6月11日
打赏作者

您的支持将鼓励我们继续创作!

[微信] 扫描二维码打赏

[支付宝] 扫描二维码打赏

发布者

sytzz

学会用简单的语言将复杂的问题说清楚。

发表评论

电子邮件地址不会被公开。