当前位置: 首页 > news >正文

用Python实现神经网络(一)

Python实现神经网络有多种方法。这里我们使用keras框架。你必须安装 tensorflow或theano, 和 keras然后才能实现神经网络。

1. 下载数据集并提取训练和测试集(见“NN.ipynb”)

from keras.datasets import mnist

import matplotlib.pyplot as plt

%matplotlib inline

# load (downloaded if needed) the MNIST dataset

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# plot 4 images as gray scale

plt.subplot(221)

plt.imshow(X_train[0], cmap=plt.get_cmap('gray'))

plt.subplot(222)

plt.imshow(X_train[1], cmap=plt.get_cmap('gray'))

plt.subplot(223)

plt.imshow(X_train[2], cmap=plt.get_cmap('gray'))

plt.subplot(224)

plt.imshow(X_train[3], cmap=plt.get_cmap('gray'))

# show the plot

plt.show()

4-18. 输出

2. 导入相关包:

import numpy as np

from keras.datasets import mnist

from keras.models import Sequential

from keras.layers import Dense

from keras.layers import Dropout

from keras.utils import np_utils

3. 预处理数据集:

num_pixels = X_train.shape[1] * X_train.shape[2]

# reshape the inputs so that they can be passed to the

vanilla NN

X_train = X_train.reshape(X_train.shape[0],num_pixels

).astype('float32')

X_test = X_test.reshape(X_test.shape[0],num_pixels).

astype('float32')

# scale inputs

X_train = X_train / 255

X_test = X_test / 255

# one hot encode the output

y_train = np_utils.to_categorical(y_train)

y_test = np_utils.to_categorical(y_test)

num_classes = y_test.shape[1]

4. 构建模型:

# building the model

model = Sequential()

# add 1000 units in the hidden layer

# apply relu activation in hidden layer

model.add(Dense(1000, input_dim=num_pixels,activation='relu'))

# initialize the output layer

model.add(Dense(num_classes, activation='softmax'))

# compile the model

model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])

# extract the summary of model

model.summary()

5. 运行模型:

model.fit(X_train, y_train, validation_data=(X_test,

y_test), epochs=5, batch_size=1024, verbose=1)

注意随着epochs增加, 测试集的准确率也增加。另外在keras里我们只要在第一层指明输入的维,它会自己动的推出余下各层的维。

http://www.dtcms.com/a/284447.html

相关文章:

  • 基于FPGA的IIC控制EEPROM读写(2)
  • 解决 MyBatis/MyBatis-Plus 中 UUID 类型转换错误的最佳实践
  • OC—初识UIStackView
  • 线程安全集合——CopyOnWriteArrayList
  • FRP配置( CentOS 7 上安装 FRP教程 )
  • MySql查询 值存在但查不到
  • 深度学习G3周:CGAN入门(生成手势图像)
  • 理解欧拉角:定义、转换与应用
  • HTTPS的工作原理及DNS的工作过程
  • 【LeetCode 热题 100】108. 将有序数组转换为二叉搜索树
  • SpringBoot使用ThreadLocal共享数据
  • 2021-07-21 VB窗体求范围质数(Excel复制工作簿)
  • Python 基础语法与数据类型(十三) - 实例方法、类方法、静态方法
  • 【测试100问】没有接口文档的情况下,如何做接口测试?
  • MinIO:开源对象存储解决方案的领先者
  • DiffPy-CMI详细安装教程
  • 【Vue进阶学习笔记】组合式API(Composition API)
  • Go 程序无法使用 /etc/resolv.conf 的 DNS 配置排查记录
  • React hooks——memo
  • 【软件开发】主流 AI 编码插件
  • 关于el-table异步获取数据渲染动态列数据赋值列数据渲染时title高度异常闪过问题
  • 深度解析:基于EasyX的C++黑白棋AI实现 | 算法核心+图形化实战
  • 数据呈现进阶:漏斗图与雷达图的实战指南
  • 基于Echarts的气象数据可视化网站系统的设计与实现(Python版)
  • Idea使用git不提示账号密码登录,而是输入token问题解决
  • 【解决方案】yakit流量转发到mitmproxy
  • 浅谈 awk 中管道的用法
  • zynq mpsoc switch级联ssd高速存储方案
  • 贴吧项目总结二
  • mysql——搭建MGR集群