< >
Home » Turbot-DL入门教程 » Turbot-DL入门教程篇-TensorFlow应用-线性回归计算

Turbot-DL入门教程篇-TensorFlow应用-线性回归计算

Turbot-DL入门教程篇-TensorFlow应用-线性回归计算

说明:

  • 介绍如何使用tensorflow解决线性回归的问题

环境:

  • Python 3.5.2

步骤:

  • 创建数据集:
$ vim linear_regression_data.py 


import numpy as np
import matplotlib.pyplot as plt

trX = np.linspace(-1, 1, 101)
trY = 2 * trX + \
        np.ones(*trX.shape) * 4 + \
        np.random.randn(*trX.shape) * 0.03


plt.figure(1)
plt.plot(trX, trY, 'o')

plt.xlabel('trX')
plt.ylabel('trY')

plt.show()
  • 训练数据集线性回归模型
$ vim linear_regression.py


#!/usr/bin/env python

import tensorflow as tf
import numpy as np

trX = np.linspace(-1, 1, 101)
trY = 2 * trX + \
        np.ones(*trX.shape) * 4 + \
        np.random.randn(*trX.shape) * 0.03

X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)

w = tf.Variable(0.0, name="weights")
b = tf.Variable(0.0, name="biases")
y_model = tf.multiply(X, w) + b

cost = tf.square(Y - y_model)

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

for i in range(100):
        for (x, y) in zip(trX, trY):
                sess.run(train_op, feed_dict={X: x, Y: y})
w_ = sess.run(w)
b_ = sess.run(b)

print("Result : trY = " + str(w_) + "*trX + " + str(b_))
  • 运行
python3 linear_regression.py

纠错,疑问,交流: 请进入讨论区点击加入Q群

获取最新文章: 扫一扫右上角的二维码加入“创客智造”公众号


标签: turbot-dl入门教程篇