当前位置: 首页 > news >正文

人工智能学习57-TF训练

人工智能学习概述—快手视频
人工智能学习57-TF训练—快手视频
人工智能学习58-TF训练—快手视频
人工智能学习59-TF预测—快手视频

训练示例代码

#导入keras.utils 工具包 
import keras.utils 
#导入mnist数据集 
from keras.datasets import mnist 
#引入tensorflow 类库 
import tensorflow.compat.v1 as tf 
#关闭tensorflow 版本2的功能,仅使用tensorflow版本1的功能 
tf.disable_v2_behavior() 
#引用numpy 处理矩阵操作 
import numpy as np 
#引用图形处理类库 
import matplotlib.pyplot as plt 
import matplotlib 
#引入操作系统类库,方便处理文件与目录 
import os 
#避免多库依赖警告信息 
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 
#设置tensorflow 训练模型所在目录 
model_path = '../log/model.ckpt' 
#设置神经网络分类数量,0-9个数字需要10个分类 
num_classes = 10 
#从数据集mnist装入训练数据集和测试数据集,mnist提供load_data方法 
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
#灰度图编码范围0-255,将编码归一化,转化为0-1之间数值 
x_train = x_train.astype('float32') / 255.0 
x_test = x_test.astype('float32') / 255.0 
#将训练和测试标注数据转化为张量(batch,num_classes) 
y_train = keras.utils.to_categorical(y_train, num_classes) 
y_test = keras.utils.to_categorical(y_test, num_classes) 
#定义输入张量,分配地址 
x = tf.placeholder(tf.float32, [None, 784]) 
y = tf.placeholder(tf.float32, [None, 10]) 
#初始化第一层权重W矩阵 
w1 = tf.Variable(tf.random_normal([784, 128])) 
#初始化第一层偏置B向量 
b1 = tf.Variable(tf.zeros([128])) 
#定义第一层激活函数输入值 X*W+B 
hc1 = tf.add(tf.matmul(x,w1),b1) 
#定义第一层输出,调用激活函数sigmoid 
h1 = tf.sigmoid(hc1) 
#初始化第二层权重W矩阵 
w2 = tf.Variable(tf.random_normal([128, 10])) 
#初始化第二层偏置B向量 
b2 = tf.Variable(tf.zeros([10])) 
#使用激活函数softmax,预测第二层输出 
pred = tf.nn.softmax(tf.matmul(h1, w2) + b2) 
#使用交叉熵定义代价函数 
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1)) 
#定义学习率 
learn_rate = 0.01 
#使用梯度下降法优化网络 
optimizer = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost) 
epoch_list = [] 
cost_list = [] train_epoch = 30 
batch_size = 100 
display_step = 1 
#定义神经网络模型保存对象 
saver = tf.train.Saver() 
#触发tensorflow初始化,为定义变量赋值 
init = tf.global_variables_initializer() 
#启动tensorflow会话 
with tf.Session() as sess: if os.path.exists('../log/model.ckpt.meta'): saver.restore(sess, model_path) #如果存在网络模型,在会话中装入网络模型 else: sess.run(init) #如果不存在网络模型,会话执行初始化工作 #循环遍历每次训练 for epoch in range(train_epoch): #定义平均损失 avg_cost = 0. #计算总批次 total_batch = int(x_train.shape[0] / batch_size) #循环每批次样本数据 for i in range(total_batch): #读取每批次训练样本数据 batch_xs = x_train[i*batch_size: (i+1)*batch_size] #读取每批次训练标签样本数据 batch_ys = y_train[i*batch_size: (i+1)*batch_size] #转化样本数据格式,添加第一维度代表样本数量 batch_xs = np.reshape(batch_xs, (100, -1)) #启动tensorflow进行样本数据训练 _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: 
batch_ys}) #累计每批次的平均损失 avg_cost += c / total_batch #记录平均损失 epoch_list.append(epoch+1) cost_list.append(avg_cost) #每隔display_step次训练,输出一次统计信息 if (epoch+1) % display_step == 0: print('Epochs: ','%04d' % (epoch+1), 'Costs ', 
'{:.9f}'.format(avg_cost)) print('Train Finished') #保存tensorflow训练的模型 
save_dir = saver.save(sess, model_path) 
print('TensorFlow model save as file %s' % save_dir) 
#图形显示训练结果 
matplotlib.rcParams['font.family'] = 'SimHei' 
plt.plot(epoch_list, cost_list, '.') 
plt.title('Train Model') 
plt.xlabel('Epoch') 
plt.ylabel('Cost') 
plt.show() 

在这里插入图片描述
在这里插入图片描述

测试示例代码

#导入keras.utils 工具包 
import keras.utils 
#导入mnist数据集 
from keras.datasets import mnist 
#引入tensorflow 类库 
import tensorflow.compat.v1 as tf 
#关闭tensorflow 版本2的功能,仅使用tensorflow版本1的功能 
tf.disable_v2_behavior() 
#引用图形处理类库 
import matplotlib.pyplot as plt 
#引用numpy处理矩阵操作 
import numpy as np 
#引入操作系统类库,方便处理文件与目录 
import os 
#避免多库依赖警告信息 
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 
#设置tensorflow 训练模型所在目录 
model_path = '../log/model.ckpt' 
#设置神经网络分类数量,0-9个数字需要10个分类 
num_classes = 10 
#从数据集mnist装入训练数据集和测试数据集,mnist提供load_data方法 
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
#灰度图编码范围0-255,将编码归一化,转化为0-1之间数值 
x_train = x_train.astype('float32') / 255.0 
x_test = x_test.astype('float32') / 255.0 
#将训练和测试标注数据转化为张量(batch,num_classes) 
y_train = keras.utils.to_categorical(y_train, num_classes) 
y_test = keras.utils.to_categorical(y_test, num_classes) 
#定义输入张量,分配地址 
x = tf.placeholder(tf.float32, [None, 784]) 
y = tf.placeholder(tf.float32, [None, 10]) 
#初始化第一层权重W矩阵 
w1 = tf.Variable(tf.random_normal([784, 128])) 
#初始化第一层偏置B向量 
b1 = tf.Variable(tf.zeros([128])) 
#定义第一层激活函数输入值 X*W+B 
hc1 = tf.add(tf.matmul(x,w1),b1) 
#定义第一层输出,调用激活函数sigmoid 
h1 = tf.sigmoid(hc1) 
#初始化第二层权重W矩阵 
w2 = tf.Variable(tf.random_normal([128, 10])) 
#初始化第二层偏置B向量 
b2 = tf.Variable(tf.zeros([10])) 
#使用激活函数softmax,预测第二层输出 
pred = tf.nn.softmax(tf.matmul(h1, w2) + b2) 
#测试样本集数量 
test_num = x_test.shape[0] 
#定义获取随机整数函数,返回0- test_num 
def rand_int(): 
rand = np.random.RandomState(None) 
return rand.randint(low=0, high=test_num) 
#定义神经网络模型保存对象 
saver = tf.train.Saver() 
n = rand_int() 
#启动tensorflow 会话 
with tf.Session() as sess: 
#在会话中装入网络模型 
saver.restore(sess, model_path) 
#读取测试样本数据 
batch_xs = x_test[n: n+2] 
#读取测试标签样本数据 
batch_ys = y_test[n: n+2] 
#转化样本数据格式,添加第一维度代表样本数量 
batch_xs = np.reshape(batch_xs, (2, -1)) 
#定义模型预测输出概率最大品类 
output = tf.argmax(pred, 1) 
#使用模型预算 
outputv, predv = sess.run([output, pred], feed_dict={x: batch_xs}) 
#图形输出 
plt.figure(figsize=(2, 3)) 
for i in range(batch_xs.ndim): 
plt.subplot(1, 2, i+1) 
plt.subplots_adjust(wspace=2) 
t = batch_xs[i].reshape(28, 28) 
plt.imshow(t, cmap='gray') 
if outputv[i] == batch_ys[i].argmax(): 
plt.title('%d,%d' 
% 
color='green') 
else: 
(outputv[i], 
batch_ys[i].argmax()), 
plt.title('%d,%d' % (outputv[i], batch_ys[i].argmax()), color='red') 
plt.xticks([]) 
plt.yticks([]) 
plt.show() 

在这里插入图片描述

相关文章:

  • Shell脚本中和||语法解析
  • tkinter 的 place() 布局管理器学习指南
  • 软件架构的发展历程——从早期的单体架构到如今的云原生与智能架构
  • FPGA基础 -- Verilog 的属性(Attributes)
  • 使用 Isaac Sim 模拟机器人
  • windows清理系统备份文件夹WinSxS文件夹清理
  • tkinter Text 组件学习指南
  • 初学python的我开始Leetcode题10-2
  • 大数据Hadoop集群搭建
  • 加密货币:比特币
  • 结构体的嵌套问题
  • Llama 4 模型卡及提示格式介绍
  • swift-14-可选项的本质、运算符重载、扩展( Extension )
  • 班车服务系统扩展到多场景(穿梭车、周转车)的升级过程中,遗传算法和蚁群算法的实现示例
  • RAG 知识库核心模块全解(产品视角 + 技术细节)
  • day37
  • 项目开发中途遇到困难的解决方案
  • 详解Redis的热点key问题
  • Python 数据分析与可视化 Day 2 - 数据清洗基础
  • 【云创智城】YunCharge充电桩系统-深度剖析OCPP 1.6协议及Java技术实现:构建高效充电桩通信系统
  • 建设工程监理考试网站/百度数据查询
  • 电脑网络题搜网站怎么做/潍坊百度快速排名优化
  • 百度网站提交地址/seo问答
  • qq刷赞网站怎么做的/宁波网站优化公司电话
  • 网站建设学什么的/事件营销成功案例
  • 信用网站建设情况/市场营销主要学什么