当前位置: 首页 > news >正文

解析keras.layers.Layer中的权重参数

文章目录

    • 概要
    • __init__()
    • build()
    • add_weight()

概要

keras.layers.Layers是所有层对象的父类,在keras.layers下所有实现类都是其子类,自定义层时需要继承该类。

init()

Layer的构造函数,需要注意两个参数trainablename
trainable:指定该层所有的权重参数是否参与训练更新过程。为False时,该层所有权重参数被冻结,不参与训练,常用于微调。默认为True。
name:指定该层的名字,这在模型参数保存时很有用,可以根据指定名字找到对应层的所有权重参数。

import keras.layers as layers
import tensorflow as tf


class MLP(layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(name="my_MLP", **kwargs)
        self.dense_1 = layers.Dense(units, trainable=False, name="MLP_1")
        self.dense_2 = layers.Dense(units, name="MLP_2")

    def call(self, inputs, *args, **kwargs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        return x


class Projection(layers.Layer):
    def __init__(self, units, kernel_size, **kwargs):
        super().__init__(name="my_Projection", **kwargs)
        self.Conv1d = layers.Conv1D(units, kernel_size, name="ConV_1d")
        self.mlp = MLP(units)
        self.dense = layers.Dense(units, trainable=False, name="Dense")

    def call(self, inputs, *args, **kwargs):
        x = self.Conv1d(inputs)
        x = self.mlp(x)
        x = self.dense(x)
        return x


if __name__ == '__main__':
    inputs = tf.random.uniform((1, 128, 3))
    projection = Projection(units=16, kernel_size=3)
    projection(inputs)
    trainable_variables = projection.trainable_variables
    print("可训练参数:")
    for i in range(len(trainable_variables)):
        print(trainable_variables[i].name)

    print("不可训练参数:")
    non_trainable_variables = projection.non_trainable_variables
    for i in range(len(non_trainable_variables)):
        print(non_trainable_variables[i].name)




可训练参数:
my_Projection/ConV_1d/kernel:0
my_Projection/ConV_1d/bias:0
my_Projection/my_MLP/MLP_2/kernel:0
my_Projection/my_MLP/MLP_2/bias:0
不可训练参数:
my_Projection/my_MLP/MLP_1/kernel:0
my_Projection/my_MLP/MLP_1/bias:0
my_Projection/Dense/kernel:0
my_Projection/Dense/bias:0

自定义了两个层MLP类和Projection类,其中MLP作为Projection的成员变量。super().init()调用其父类Layer的构造函数,并指定name参数和trainable参数。然后分别打印可训练参数和不可训练参数。
注意 : 只用当层对象(Layer)调用完build()方法后,层中才存在权重参数(否则打印权重参数时,为空集)。当第一次调用call()方法时,build()方法会在call()调用之前先执行,所以代码中在打印权重参数之前,先调用了call方法(projection(inputs))。
另外权重参数kernel和bias的命名方式和文件路径类似,它反应了Projection类与其成员变量的包含关系。

如果想查看Projection类中某一层kernel和bias的具体参数,可通过.numpy()(替换.name)打印出。

build()

生成并初始化所有层中的权重参数。自定义层时,可以重写该方法,定义自己需要的权重参数。

class Stacked_Sum(tf.keras.layers.Layer):
    def __init__(self, dim_output):
        super(Stacked_Sum, self).__init__()
        self.num_outputs = dim_output

    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel",
                                      shape=[int(input_shape[-2]),
                                             int(input_shape[-1]),
                                             self.num_outputs])

    def call(self, inputs):
        x = tf.einsum('abcd, cde->abe', inputs, self.kernel)
        return tf.keras.activations.swish(x)

在build()方法中通过add_weight()添加了一个新的kernel, 并在call方法中使用,这个kernel的参数会在训练中不断更新。其中input_shape是call方法的输入的shape。build()方法在call()调用之前自动被调用。

add_weight()

可以在build方法中添加与输入shape有关的权重参数,也可以在__init__方法中添加与输入无关的权重参数

class Stacked_Sum(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.kernel = self.add_weight(name="my_kernel",
                                      shape=[4, 3])

    def call(self, inputs):
        return inputs

上述代码是在init方法中添加与输入无关的权重参数。

http://www.dtcms.com/a/113736.html

相关文章:

  • Linux内核——段描述符详解
  • SeaTunnel系列之:Apache SeaTunnel编译和安装
  • 《SQL赋能人工智能:解锁特征工程的隐秘力量》
  • python基础-11-调试程序
  • DrissionPage高级技巧:从爬虫到自动化测试
  • Python FastApi(13):APIRouter
  • 操作系统知识点(二)
  • 超级科学软件实验室(中国) : Super Scientific Software Laboratory (SSSLab)
  • Vue2与Vue3不同
  • Deformable DETR(复习专用)
  • 基于Spark的哔哩哔哩舆情数据分析系统
  • 【RK3588 嵌入式图形编程】-SDL2-扫雷游戏-创建网格
  • liunx输入法
  • 网安小白筑基篇五:web后端基础之Python(补充Python的魔术方法)
  • Scade One - 将MBD技术从少数高安全领域向更广泛的安全嵌入式软件普及
  • 使用MATIO库读取Matlab数据文件中的cell结构数据
  • 【设计模式】命令模式
  • mine craft经典信封
  • 力扣刷题-热题100题-第31题(c++、python)
  • 博途 TIA Portal之1200做主站与200SMART的S7通讯
  • 《减压宝典》Python篇
  • leetcode每日一题:替换子串得到平衡字符串
  • vue3实现markdown预览和编辑
  • Cursor 无限续杯 Windows版
  • 智能体开发实战指南:提示词设计、开发框架与工作流详解
  • ROS多设备交互
  • 用C语言控制键盘上的方向键
  • LightRAG核心原理和数据流
  • Cisco Packet Tracer 8.0(新版)
  • 【神经网络】python实现神经网络(三)——正向学习的模拟演练