当前位置: 首页 > news >正文

强化学习的数学原理-六、随机近似与随机梯度下降

代码来自up主【强化学习的数学原理-作业】GridWorld示例代码(已更新至DQN、REINFORCE、A2C)_哔哩哔哩_bilibili

SGD、GD、MGD举例:

# 先初始化一个列表,未来要在这100个样本里面再sample出来
np.random.seed(0)
X = np.linspace(-10, 10, 1000)
Y = 2 * X ** 2 + 3*X +5 # 用作真实值

#定义二次函数,找到一组参数a、b、c使得损失函数的值最小
def quadratic_function(X, a, b, c):
    return a * X ** 2 + b * X + c

#定义损失函数
def loss_function(Y_pred, Y):
    return np.mean((Y_pred - Y)**2)

def train(learning_rate, batch_size, note):
    a = np.random.randn() 
    b = np.random.randn()
    c = np.random.randn()
    loss = 1000
    cnt = 0
    results = np.array([0])
    while loss > 0.01:
        cnt += 1
        batch = np.random.randint(0,1000,size=(1,batch_size)) # 大小为1 * batch_size

        x = X[batch]
        y = Y[batch]
        if cnt < 2:
            print(batch)
            print(x)
        y_pred = quadratic_function(x,a,b,c)

        loss = loss_function(y_pred,y)
        results = np.append(results,loss)
        # 这些是计算得到的梯度,是最小化损失函数,通过损失函数对a、b、c分别求导
        grad_a = (2 * (y_pred - y) * x ** 2).mean()
        grad_b = (2 * (y_pred - y) * x).mean()
        grad_c = (2 * (y_pred - y)).mean()

        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
        c -= learning_rate * grad_c

        # 检验误差
        valid_batch = np.random.randint(0,1000,size=(1,5))
        x = X[valid_batch]
        y = Y[valid_batch]
        y_pred = quadratic_function(x,a,b,c)
        loss = loss_function(y_pred,y)
        
        # results = np.append(results,loss)
    
    print("最终系数为:",a,b,c)
    print("最后迭代次数:",cnt)
    y_pred = quadratic_function(X,a,b,c)
    plt.figure(figsize=(8,3))
    # plt.plot(X,y_pred,label="predict")
    plt.plot(X,Y,label="target")
    plt.plot(X,y_pred,label="predict")
    plt.title(note)
    plt.legend()
    plt.show()
    # print(a,b,c)

    plt.figure(figsize=(8,3))
    plt.plot(results[:150],label='x')
    # plt.plot(results[:,1],label='y')
    # plt.yticks(np.arange(-5,5,1))
    plt.legend()
    plt.title(note)
    plt.show()

相关文章:

  • HTML之JavaScript DOM简介
  • Python中的闭包和装饰器
  • 静态时序分析:时钟组间的逻辑独立、物理独立和异步的区别
  • Perplexity AI:通过OpenAI与DeepSeek彻底革新搜索和商业策略
  • 过程监督(Process Supervision)融入到 GRPO (Group Relative Policy Optimization)
  • MT7628基于原厂的SDK包, 修改ra1网卡的MAC方法。
  • 【ORB-SLAM3】鲁棒核函数的阈值设置
  • docker-rss:容器更新的RSS订阅源
  • 卷积与动态特征选择:重塑YOLOv8的多尺度目标检测能力
  • 商汤绝影发布全新端到端自动驾驶技术路线R-UniAD
  • 【Python爬虫(49)】分布式爬虫:在新兴技术浪潮下的蜕变与展望
  • 从0开始:OpenCV入门教程【图像处理基础】
  • 【网络】高级IO
  • sklearn中的决策树
  • Java子类调用父类构造器的应用场景
  • STM32-有关内存堆栈、map文件
  • ROS2 应用:按键控制 MoveIt2 中 Panda 机械臂关节位置
  • golang内存泄漏
  • 下载CentOS 10
  • 探索未知:alpha测试的神秘序章【量化理论】
  • 巴基斯坦称未违反停火协议
  • 高龄老人骨折后,生死可能就在家属一念之间
  • 2025柯桥时尚周启幕:国际纺都越来越时尚
  • 苹果Safari浏览器上的搜索量首次下降
  • 首届上海老年学习课程展将在今年10月举办
  • 陕西澄城打造“中国樱桃第一县”:从黄土高原走向海外,年产值超30亿