当前位置: 首页 > news >正文

Keras使用1

目录

  • Keras Sequential
    • compile写法
    • 训练模型
    • 完整示例

Keras Sequential

最简单的使用方式是使用Keras Sequential, 中文叫做顺序模型, 可以用来表示多个网络层的线性堆叠.

使用的时候可以通过讲网络层实例的列表传递给Sequential, 比如:

from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequential([
    Dense(32, input_shape=(784,)),
    Activation('relu'),
    Dense(10),
    Activation('softmax'),
])

也可以简单的使用.add()方法将各层添加到模型中:

model = Sequential()
model.add(Dense(32, input_dim=784))
model.add(Activation('relu'))

对于输入层,需要指定输入数据的尺寸,通过Dense对象中的input_shape属性.注意无需写batch的大小.

input_shape=(784,) 等价于我们在神经网络中定义的shape为(None, 784)的Tensor.

模型创建成功之后,需要进行编译.使用.compile()方法对创建的模型进行编译.compile()方法主要需要指定一下几个参数:

  1. 优化器 optimizer:可以是Keras定义好的优化器的字符串名字,比如’rmsprop’也可以是Optimizer类的实例对象.常见的优化器有: SGD,
    RMSprop, Adagrad, Adadelta等.

  2. 损失函数 loss: 模型视图最小化的目标函数, 它可以是现有损失函数的字符串形式,比如:categorical_crossentropy, 也可以是一个目标函数.

  3. 评估标准 metrics. 评估算法性能的衡量指标.对于分类问题, 建议设置为metrics =[‘accuracy’].评估标准可以是现有的标准的字符串标识符,也可以是自定义的评估标准函数。

compile写法

以下为compile的常见写法:

# 多分类问题
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 二分类问题
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# 均方误差回归问题
model.compile(optimizer='rmsprop',
              loss='mse')

# 自定义评估标准函数
import keras.backend as K

def mean_pred(y_true, y_pred):
    return K.mean(y_pred)

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy', mean_pred])

训练模型

训练模型: 使用.fit()方法,将训练数据,训练次数(epoch), 批次尺寸(batch_size)传递给fit()方法即可.

# 对于具有 2 个类的单输入模型(二进制分类):

model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# 生成虚拟数据
import numpy as np
data = np.random.random((1000, 100))
labels = np.random.randint(2, size=(1000, 1))

# 训练模型,以 32 个样本为一个 batch 进行迭代
model.fit(data, labels, epochs=10, batch_size=32)

完整示例

下面用一个完整的例子来说明Keras的使用. 以下为使用神经网络来进行手写数字识别:

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop

batch_size = 128
num_classes = 10
epochs = 20

# 导入手写数字数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 对数据进行初步处理
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# 将标记结果转化为独热编码
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# 创建顺序模型
model = Sequential()
# 添加第一层网络, 512个神经元, 激活函数为relu
model.add(Dense(512, activation='relu', input_shape=(784,)))
# 添加Dropout
model.add(Dropout(0.2))
# 第二层网络
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
# 输出层
model.add(Dense(num_classes, activation='softmax'))

# 打印神经网络参数情况
model.summary()

# 编译
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

# 训练并打印中间过程
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))
# 计算预测数据的准确率
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

相关文章:

  • 【AI学习从零至壹】语⾔模型及词向量相关知识
  • linux多线(进)程编程——(6)共享内存
  • 链表代码实现(C++)
  • C语言--常见的编程示例
  • 医药采购系统平台第5天01:药品目录导入功能的实现Oracle触发器的定义供货商药品目录模块的分析供货商目录表和供货商控制表的分析、数据模型设计和优化
  • Rasa 模拟实现超简易医生助手(适合初学练手)
  • Tkinter表格与列表框应用
  • 制作像素风《饥荒》类游戏的整体蓝图和流程
  • ubuntu 2404 安装 vcs 2018
  • Doris 安装部署、实际应用及优化实践:对比 ClickHouse 的深度解析
  • 从零搭建高可用Kafka集群与EFAK监控平台:全流程实战总结
  • Foxmail邮件客户端跨站脚本攻击漏洞(CNVD-2025-06036)技术分析
  • Go:基本数据
  • leetcode 139. Word Break
  • < 自用文 Project-30.6 Crawl4AI > 为AI模型优化的网络爬虫工具 帮助收集和处理网络数据的工具
  • Java中的数组
  • 苍穹外卖Day-5
  • c# 新建不重名的唯一文件夹
  • STM32 HAL库时钟系统详解
  • AndroidTV 当贝播放器-v1.5.2-官方简洁无广告版
  • 澎湃回声丨23岁小伙“被精神病8年”续:今日将被移出“重精”管理系统
  • 顺利撤离空间站,神十九乘组踏上回家之旅
  • 王毅:坚持金砖团结合作,改革完善全球治理
  • 上海通报5起违反中央八项规定精神问题
  • 国家卫健委:工作相关肌肉骨骼疾病、精神和行为障碍成职业健康新挑战
  • 澎湃思想周报丨数字时代的育儿;凛冬已至好莱坞