当前位置: 首页 > news >正文

【机器学习】 Flux.jl 求解 XOR 分类问题的神经网络模型

Flux.jl 搭建神经网络基本流程

Chain(Dense, BatchNorm, Dense)
DataLoader
setup( Adam, RMSProp, Momentum...)
数据准备
搭建多层感知器
建立优化问题
选择算法训练神经网络
输出结果
using Flux, Statistics

# 生成XOR问题的数据
noisy = rand(Float32, 2, 200)  # 2×200的矩阵
truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)]  # 200元素向量
target = Flux.onehotbatch(truth, [true, false])  # 2×200的OneHotMatrix

# 定义模型,一个具有3个隐藏层的多层感知器
model = Chain(
    Dense(2 => 3, tanh),  # 使用tanh激活函数
    BatchNorm(3),         # 批量归一化
    Dense(3 => 2)         # 输出层
)

# 模型输出
out1 = model(noisy)
probs1 = softmax(out1)  # 使用softmax函数获取概率

# 为训练准备目标数据


# 创建数据加载器
loader = Flux.DataLoader((noisy, target), batchsize=64, shuffle=true)

# 设置优化器
optim = Flux.setup(Flux.Adam(0.01), model)  # Adam 策略随机梯度方法

# 训练循环,遍历整个数据集1000次
losses = []
for epoch in 1:1000
    for (x, y) in loader
        loss, grads = Flux.withgradient(model) do m
            y_hat = m(x)
            Flux.logitcrossentropy(y_hat, y)
        end
        Flux.update!(optim, model, grads[1])
        push!(losses, loss)
    end
end

# 训练后的模型输出
out2 = model(noisy)
probs2 = softmax(out2)

# 计算准确率
accuracy = mean((probs2[1,:] .> 0.5) .== truth)
println("Accuracy: $(accuracy * 100)%")

using Plots  # to draw the above figure

p_true = scatter(noisy[1,:], noisy[2,:], zcolor=truth, title="True classification", legend=false)
p_raw =  scatter(noisy[1,:], noisy[2,:], zcolor=probs1[1,:], title="Untrained network", label="", clims=(0,1))
p_done = scatter(noisy[1,:], noisy[2,:], zcolor=probs2[1,:], title="Trained network", legend=false)

plot(p_true, p_raw,layout=(1,3), size=(200,330))

输出分类效果
在这里插入图片描述

http://www.dtcms.com/a/13046.html

相关文章:

  • 修改Opcenter EXFN 页面超时时间(Adjust UI Session Extend Token)
  • C++中move和forword的区别
  • 时尚与科技的融合,戴上更轻更悦耳的QCY C30耳夹耳机,随时享受好音乐
  • 《论软件架构建模技术与应用》写作框架,软考高级系统架构设计师
  • 伊犁云计算22-1 apache 安装rhel8
  • CorePress Pro 网站加载慢 WordPress
  • 研究生三年概括
  • Trapezoidal Decomposition梯形分解算法(TCD)
  • JS设计模式之组合模式:打造灵活高效的对象层次结构
  • 学校快递站点管理|基于springboot学校快递站点管理设计与实现(源码+数据库+文档)
  • 【Unity】对象池 - 未更新完
  • 使用vite+react+ts+Ant Design开发后台管理项目(三)
  • 2024.9.26 Spark学习
  • 钉钉 钉钉打卡 钉钉定位 2024 免费试用 保用
  • 使用 Rust 和 wasm-pack 开发 WebAssembly 应用
  • ubuntu数据硬盘故障导致系统启动失败
  • Kafka集群扩容(新增一台kafka节点)
  • Windows 10 on ARM, version 22H2 (updated Sep 2024) ARM64 AArch64 中文版、英文版下载
  • 缓存穿透 问题(缓存空对象)
  • 513. 找树左下角的值
  • 常见场景题3(面试)
  • Netty简介
  • 时序数据库 TDengine 的入门体验和操作记录
  • java 框架组件
  • 24暑假实习信息、25秋招提前批信息,地信、测绘、遥感、地质相关岗位招聘汇总
  • C++——输入三个整数,按照由小到大的顺序输出。用指针方法处理。
  • ubuntu错误GPG error: http://repo.mysql.com/apt/ubuntu noble InRelease
  • Contact Form 7最新5.9.8版错误修复方案
  • Redisson 总结
  • QT窗口无法激活弹出问题排查记录