当前位置: 首页 > news >正文

打卡day52

简单cnn 借助调参指南进一步提高精度

基础CNN模型代码

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical# 加载数据
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()# 数据预处理
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)# 基础CNN模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])history = model.fit(train_images, train_labels, epochs=10, batch_size=64,validation_data=(test_images, test_labels))

改进方法

增加模型复杂度

model = models.Sequential([layers.Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3), padding='same'),layers.BatchNormalization(),layers.Conv2D(64, (3, 3), activation='relu', padding='same'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.25),layers.Conv2D(128, (3, 3), activation='relu', padding='same'),layers.BatchNormalization(),layers.Conv2D(128, (3, 3), activation='relu', padding='same'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.25),layers.Conv2D(256, (3, 3), activation='relu', padding='same'),layers.BatchNormalization(),layers.Conv2D(256, (3, 3), activation='relu', padding='same'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.25),layers.Flatten(),layers.Dense(512, activation='relu'),layers.BatchNormalization(),layers.Dropout(0.5),layers.Dense(10, activation='softmax')
])

优化器调参

from tensorflow.keras.optimizers import Adamoptimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07)
model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy'])

数据增强

from tensorflow.keras.preprocessing.image import ImageDataGeneratordatagen = ImageDataGenerator(rotation_range=15,width_shift_range=0.1,height_shift_range=0.1,horizontal_flip=True,zoom_range=0.1
)
datagen.fit(train_images)history = model.fit(datagen.flow(train_images, train_labels, batch_size=64),epochs=50,validation_data=(test_images, test_labels))

早停和模型检查点

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpointcallbacks = [EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),ModelCheckpoint('best_model.h5', monitor='val_accuracy', save_best_only=True)
]history = model.fit(..., callbacks=callbacks, epochs=100)

相关文章:

  • Appium + Python 测试全流程
  • FFmpeg是什么?
  • 106.给AI回答添加点赞收藏功能
  • AI技术专题:电商AI专题
  • PERST#、Hot Reset、Link Disable
  • 什么是序列化?反序列化? 场景使用? 怎么实现???
  • GitHub Desktop Failure when receiving data from the peer
  • Redis的常用配置详解
  • Chapter07-信息披漏
  • 数据管理四部曲:元数据管理、数据整合、数据治理、数据质量管控
  • 修改FFMpeg的日志函数av_log,使其在记录日志时能显示调用该函数的位置(文件名和行号)
  • SGDvsAdamW 优化策略详细解释
  • C++-入门到精通【18】string类和字符串流处理的深入剖析
  • 结构型模式 (7种)
  • 今日行情明日机会——20250612
  • 深度解析Git错误:`fatal: detected dubious ownership in repository` 的根源与解决方案
  • 通过同步压缩小波变换实现信号的分解和重构
  • 线程池启动报null :Caused by: java.lang.IllegalArgumentException: null
  • 成功解决 ValueError: Unable to find resource t64.exe in package pip._vendor.distlib
  • 准确---配置全局代理
  • 新媒体营销实训报告总结/seo综合查询怎么关闭
  • 已经有了域名和服务器怎么做网站/广告免费发布信息平台
  • wordpress和vue/seo首页优化
  • 中山技术支持中山网站建设/南宁seo公司
  • 做网站标题代码/广州各区最新动态
  • 有没有a站可以打开/做企业网站建设公司哪家好