Code Walkthrough: ReferenceNet

A Detailed Look at the ReferenceNet Code

Table of Contents

  • A Detailed Look at the ReferenceNet Code
    • The __init__ method
      • 1. Definition
      • 2. input & time
      • 3. class embedding
      • 4. down
      • 5. mid
      • 6. up
    • The forward method
      • 1. Initial part
      • 2. time
      • 3. pre-process
      • 4. down
      • 5. mid
      • 6. up

ReferenceNet is essentially the same architecture as the UNet 2D Condition Model; for background, see:

A structural analysis of the UNet 2D Condition Model used in Stable Diffusion (the diffusers library) - Zhihu

The __init__ method

1. Definition

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError("At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.")

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in
        # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.")
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.")
        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.")
        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.")
        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.")
        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.")
        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}.")

This defines the constructor (__init__) of the ReferenceNet class, which initializes the model's parameters and configuration. The main parts are:

  1. Class attributes:

    • _supports_gradient_checkpointing: set to True, indicating that gradient checkpointing is supported.
  2. Constructor arguments:

    • sample_size, in_channels, out_channels, etc.: basic input/output parameters.
    • down_block_types, up_block_types: the types of the downsampling and upsampling blocks.
    • block_out_channels: the number of output channels of each block.
    • layers_per_block: the number of layers in each block.
    • cross_attention_dim, attention_head_dim, etc.: parameters related to the attention mechanism.
    • Other parameters, such as act_fn (activation function) and norm_num_groups (number of normalization groups).
  3. Input validation:

    • Checks that down_block_types and up_block_types have the same length.
    • Checks that block_out_channels has the same length as down_block_types.
    • Checks that only_cross_attention, num_attention_heads, attention_head_dim, etc. have the same length as down_block_types when they are passed as sequences.
  4. Handling the number of attention heads:

    • If num_attention_heads is not set, it falls back to attention_head_dim.
    • If num_attention_heads is set explicitly, an exception is raised, because the current diffusers version does not yet support that argument (a naming issue in the library).
      These settings and checks ensure the configuration is valid and lay the groundwork for building the model; a small sketch of the fallback logic follows this list.
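
To make the head-count fallback and the per-block length check concrete, here is a minimal, standalone sketch (plain Python with made-up values; it does not import diffusers and is only meant to mirror the logic above):

    from typing import Tuple, Union

    def resolve_attention_heads(
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads=None,
        down_block_types: Tuple[str, ...] = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",),
    ):
        # Mirrors the constructor: passing num_attention_heads explicitly is rejected ...
        if num_attention_heads is not None:
            raise ValueError("`num_attention_heads` is not supported; use `attention_head_dim`.")
        # ... so the head count silently falls back to attention_head_dim.
        num_attention_heads = num_attention_heads or attention_head_dim
        # Per-block sequences must have one entry per down block.
        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError("`num_attention_heads` must have one entry per down block.")
        return num_attention_heads

    print(resolve_attention_heads())                  # 8 (scalar fallback)
    print(resolve_attention_heads((5, 10, 20, 20)))   # per-block tuple of length 4 -> accepted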

2. input & time

    # input
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)

    # time
    if time_embedding_type == "fourier":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
        if time_embed_dim % 2 != 0:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
        self.time_proj = GaussianFourierProjection(time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos)
        timestep_input_dim = time_embed_dim
    elif time_embedding_type == "positional":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
    else:
        raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.")

    self.time_embedding = TimestepEmbedding(
        timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
        post_act_fn=timestep_post_act,
        cond_proj_dim=time_cond_proj_dim,
    )

    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

    if encoder_hid_dim is None and encoder_hid_dim_type is not None:
        raise ValueError(f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.")

    if encoder_hid_dim_type == "text_proj":
        self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
        self.encoder_hid_proj = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2
        self.encoder_hid_proj = ImageProjection(image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim)
    elif encoder_hid_dim_type is not None:
        raise ValueError(f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'.")
    else:
        self.encoder_hid_proj = None

This part initializes the input convolution, the time embedding, and the projection of the encoder hidden states:

  1. Input convolution layer:

    • self.conv_in: an nn.Conv2d layer that maps the input channels to the output channels of the first block. The kernel size is conv_in_kernel and the padding is (conv_in_kernel - 1) // 2.
  2. Time embedding:

    • The embedding scheme is chosen by time_embedding_type:
      • fourier: a Gaussian Fourier projection (GaussianFourierProjection); time_embed_dim must be divisible by 2.
      • positional: sinusoidal position embeddings (Timesteps).
    • self.time_embedding: a TimestepEmbedding layer that projects the timestep features into the model's embedding space.
  3. Encoder hidden-state projection:

    • The projection is chosen by encoder_hid_dim_type:
      • text_proj: a linear layer (nn.Linear) that projects text embeddings to the cross-attention dimension.
      • text_image_proj: a TextImageProjection that jointly projects text and image embeddings.
      • image_proj: an ImageProjection for image embeddings.
    • If encoder_hid_dim_type is not set but encoder_hid_dim is, the type defaults to text_proj. The resulting embedding dimensions are sketched below.
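
As a rough sketch of the dimensions involved, assuming the default config above and that diffusers' public embedding helpers Timesteps and TimestepEmbedding keep their current signatures (the numbers are illustrative):

    import torch
    from diffusers.models.embeddings import TimestepEmbedding, Timesteps

    block_out_channels = (320, 640, 1280, 1280)
    flip_sin_to_cos, freq_shift = True, 0

    # "positional" branch: sinusoidal features of size block_out_channels[0],
    # then an MLP that lifts them to time_embed_dim = 4 * block_out_channels[0].
    time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
    time_embedding = TimestepEmbedding(block_out_channels[0], block_out_channels[0] * 4, act_fn="silu")

    t = torch.tensor([10, 500])      # one timestep per batch element
    t_emb = time_proj(t)             # -> (2, 320) sinusoidal features
    emb = time_embedding(t_emb)      # -> (2, 1280) conditioning vector
    print(t_emb.shape, emb.shape)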

3. class embedding

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError("`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set")
        # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
        # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
        # 2. it projects from an arbitrary input dimension.
        #
        # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
        # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
        # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
        self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif class_embed_type == "simple_projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError("`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set")
        self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
    else:
        self.class_embedding = None

    if addition_embed_type == "text":
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim
        self.add_embedding = TextTimeEmbedding(text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads)
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
        self.add_embedding = TextImageTimeEmbedding(text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type == "image":
        # Kandinsky 2.2
        self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet
        self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type is not None:
        raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

    if time_embedding_act_fn is None:
        self.time_embed_act = None
    else:
        self.time_embed_act = get_activation(time_embedding_act_fn)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention
        only_cross_attention = [only_cross_attention] * len(down_block_types)
    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)
    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)
    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)
    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    if class_embeddings_concat:
        # The time embeddings are concatenated with the class embeddings. The dimension of the
        # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
        # regular time embeddings
        blocks_time_embed_dim = time_embed_dim * 2
    else:
        blocks_time_embed_dim = time_embed_dim

This part initializes the class embedding and the additional ("addition") embeddings.

  1. Class embedding:

    • The type of class embedding is selected by class_embed_type.
    • If class_embed_type is None but num_class_embeds is not None, an nn.Embedding lookup table is used.
    • If class_embed_type is "timestep", a TimestepEmbedding is used.
    • If class_embed_type is "identity", nn.Identity is used.
    • If class_embed_type is "projection" or "simple_projection", a TimestepEmbedding or an nn.Linear is used respectively, after checking that projection_class_embeddings_input_dim is set.
  2. Addition embedding:

    • The type of additional embedding is selected by addition_embed_type.
    • "text": a TextTimeEmbedding whose input dimension is encoder_hid_dim if set, otherwise cross_attention_dim.
    • "text_image": a TextImageTimeEmbedding.
    • "text_time": initializes add_time_proj and add_embedding.
    • "image" or "image_hint": an ImageTimeEmbedding or ImageHintTimeEmbedding respectively.
    • Any other non-None value raises a ValueError.
  3. Time-embedding activation:

    • time_embedding_act_fn selects an optional activation applied to the time embedding.
  4. Module lists:

    • down_blocks and up_blocks are initialized as empty nn.ModuleList containers.
  5. Per-block parameter normalization:

    • Scalar values of only_cross_attention, num_attention_heads, attention_head_dim, cross_attention_dim, layers_per_block and transformer_layers_per_block are broadcast into one entry per down block (see the sketch below).
    • If class_embeddings_concat is True, blocks_time_embed_dim is set to twice time_embed_dim, otherwise it stays equal to time_embed_dim.
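
A minimal sketch of that scalar-to-per-block broadcast (plain Python, illustrative values only; it covers the integer-valued settings, the boolean only_cross_attention is handled analogously):

    down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")

    def per_block(value, n=len(down_block_types)):
        # A scalar config value is repeated once per down block; sequences pass through unchanged.
        return (value,) * n if isinstance(value, int) else tuple(value)

    print(per_block(8))             # (8, 8, 8, 8)              e.g. attention_head_dim
    print(per_block(1280))          # (1280, 1280, 1280, 1280)  e.g. cross_attention_dim
    print(per_block((1, 2, 4, 4)))  # an explicit per-block sequence is kept as-is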

4. down

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            downsample_padding=downsample_padding,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
        )
        self.down_blocks.append(down_block)

This part builds the downsampling blocks (down blocks):

  1. Initialize the output channels:

    • output_channel starts at the first block's output channels, block_out_channels[0].
  2. Iterate over the down block types:

    • enumerate(down_block_types) iterates over all down block types.
    • For each block, the input channels are the previous block's output channels, and the output channels come from block_out_channels[i].
    • is_final_block marks the last block and decides whether a downsampling layer is added (the last block does not downsample).
  3. Build each down block:

    • get_down_block constructs the block from its type and parameters.
    • The parameters include the number of layers, input/output channels, time-embedding channels, attention-related settings, and so on.
  4. Append to the list of down blocks:

    • Each constructed block is appended to self.down_blocks. The resulting channel progression is sketched below.
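
For the default config, the (input, output) channels and downsampling flags per block work out as in this plain-Python sketch (values are only those implied by block_out_channels above):

    block_out_channels = (320, 640, 1280, 1280)

    output_channel = block_out_channels[0]
    for i, out_ch in enumerate(block_out_channels):
        input_channel, output_channel = output_channel, out_ch
        is_final_block = i == len(block_out_channels) - 1
        print(f"down block {i}: {input_channel:>4} -> {output_channel:>4}, add_downsample={not is_final_block}")

    # down block 0:  320 ->  320, add_downsample=True
    # down block 1:  320 ->  640, add_downsample=True
    # down block 2:  640 -> 1280, add_downsample=True
    # down block 3: 1280 -> 1280, add_downsample=False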

5. mid

    # mid
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        self.mid_block = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
        )
    elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        self.mid_block = UNetMidBlock2DSimpleCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim[-1],
            attention_head_dim=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif mid_block_type is None:
        self.mid_block = None
    else:
        raise ValueError(f"unknown mid_block_type : {mid_block_type}")

This part initializes the UNet's middle block (mid block). Depending on mid_block_type, a different block is constructed:

  1. UNetMidBlock2DCrossAttn:

    • If mid_block_type is "UNetMidBlock2DCrossAttn", a UNetMidBlock2DCrossAttn is instantiated.
    • Its arguments include transformer_layers_per_block, in_channels, temb_channels, resnet_eps, resnet_act_fn, output_scale_factor, resnet_time_scale_shift, cross_attention_dim, num_attention_heads, resnet_groups, dual_cross_attention, use_linear_projection, upcast_attention and attention_type.
    • These configure the mid block's channels, time embedding, attention mechanism, and so on.
  2. UNetMidBlock2DSimpleCrossAttn:

    • If mid_block_type is "UNetMidBlock2DSimpleCrossAttn", a UNetMidBlock2DSimpleCrossAttn is instantiated.
    • Its arguments include in_channels, temb_channels, resnet_eps, resnet_act_fn, output_scale_factor, cross_attention_dim, attention_head_dim, resnet_groups, resnet_time_scale_shift, skip_time_act, only_cross_attention and cross_attention_norm.
    • These configure the simpler cross-attention variant of the mid block.
  3. None:

    • If mid_block_type is None, no mid block is created and self.mid_block is set to None.
  4. Error handling:

    • Any other value raises a ValueError ("unknown mid_block_type").

6. up

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
    only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
        )
        self.up_blocks.append(up_block)
        prev_output_channel = output_channel

    self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
    self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
    self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
    self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
    self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
    self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
    self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
    self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
    self.up_blocks[3].attentions[2].proj_out = Identity()

    if attention_type in ["gated", "gated-text-image"]:
        positive_len = 768
        if isinstance(cross_attention_dim, int):
            positive_len = cross_attention_dim
        elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
            positive_len = cross_attention_dim[0]

        feature_type = "text-only" if attention_type == "gated" else "text-image"
        self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type)

This part builds the upsampling path of the UNet. For each up block type, an up block is constructed and configured:

  1. Reverse the per-block parameters:

    • The code first reverses block_out_channels, num_attention_heads, layers_per_block, cross_attention_dim and transformer_layers_per_block, so the up path processes blocks starting from the deepest one.
  2. Build the up blocks:

    • It iterates over up_block_types and calls get_up_block for each type.
    • Each call sets the input channels, output channels, the previous block's output channels, the time-embedding channels, whether to add an upsampler, the ResNet parameters, the cross-attention parameters, and so on.
    • is_final_block marks the last block and decides whether an upsampler is added (the final block does not upsample).
  3. Special handling (ReferenceNet-specific):

    • The last up block receives special treatment: the final transformer block's attention and feed-forward layers are skipped to save computation, which also serves as an optimization for DDP training.
    • _LoRACompatibleLinear and Identity replace several components of that attention block to reduce the amount of computation.
  4. Position network:

    • If attention_type is "gated" or "gated-text-image", a PositionNet is created to handle positional (layout) embeddings.
    • positive_len is derived from cross_attention_dim and feature_type from attention_type. The reversed-channel pairing used above is sketched below.
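
For the default config, the reversed channel lists give the following (skip-input, previous-output, output) channels per up block; this is a plain-Python sketch using only the values quoted above:

    block_out_channels = (320, 640, 1280, 1280)
    reversed_block_out_channels = list(reversed(block_out_channels))  # [1280, 1280, 640, 320]

    output_channel = reversed_block_out_channels[0]
    for i in range(len(block_out_channels)):
        is_final_block = i == len(block_out_channels) - 1
        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
        print(f"up block {i}: skip_in={input_channel:>4}, prev={prev_output_channel:>4}, out={output_channel:>4}, add_upsample={not is_final_block}")

    # up block 0: skip_in=1280, prev=1280, out=1280, add_upsample=True
    # up block 1: skip_in= 640, prev=1280, out=1280, add_upsample=True
    # up block 2: skip_in= 320, prev=1280, out= 640, add_upsample=True
    # up block 3: skip_in= 320, prev= 640, out= 320, add_upsample=False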

The forward method

1. Initial part

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
                that are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

This code handles the UNet's input/output sizes, the attention masks, and input normalization:

  1. Upsampling-factor check:

    • default_overall_up_factor = 2**self.num_upsamplers is the overall upsampling factor (16 when there are 4 upsampling layers). The input height and width should be multiples of this factor, otherwise the upsampled feature maps will not line up with the skip connections.
    • If the input size is not a multiple of this factor, forward_upsample_size is set to True and the output size of each upsampler is later forced via interpolation.
  2. Attention-mask handling:

    • attention_mask and encoder_attention_mask are converted into additive bias masks: positions with mask value 1 become 0 and positions with mask value 0 become -10000, so masked-out positions receive essentially no attention.
    • unsqueeze(1) adds a dimension so the masks broadcast correctly in the attention layers.
  3. Input centering:

    • If center_input_sample is configured, the input sample is linearly mapped from [0, 1] to [-1, 1], which helps convergence and numerical stability.
      Overall, this block makes the input/output sizes compatible with the UNet and the attention layers, and normalizes the input. The mask-to-bias conversion is sketched below.
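
A tiny standalone PyTorch sketch of the mask-to-bias conversion (illustrative shapes; the real code casts to sample.dtype instead of float):

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])     # (batch, seq): 1 = keep, 0 = discard
    bias = (1 - attention_mask.float()) * -10000.0       # keep -> 0.0, discard -> -10000.0
    bias = bias.unsqueeze(1)                             # (batch, 1, seq), broadcast over query positions
    # Added to the attention scores before softmax, -10000 drives the masked weights to ~0.
    print(bias.shape)                                    # torch.Size([1, 1, 5])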

2. time

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)
            # `Timesteps` does not contain any weights and will always return f32 tensors
            # there might be better ways to encapsulate this.
            class_labels = class_labels.to(dtype=sample.dtype)

        class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    if self.config.addition_embed_type == "text":
        aug_emb = self.add_embedding(encoder_hidden_states)
    elif self.config.addition_embed_type == "text_image":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`")
        image_embs = added_cond_kwargs.get("image_embeds")
        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
        aug_emb = self.add_embedding(text_embs, image_embs)
    elif self.config.addition_embed_type == "text_time":
        # SDXL - style
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`")
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`")
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)
    elif self.config.addition_embed_type == "image":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`")
        image_embs = added_cond_kwargs.get("image_embeds")
        aug_emb = self.add_embedding(image_embs)
    elif self.config.addition_embed_type == "image_hint":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
            raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`")
        image_embs = added_cond_kwargs.get("image_embeds")
        hint = added_cond_kwargs.get("hint")
        aug_emb, hint = self.add_embedding(image_embs, hint)
        sample = torch.cat([sample, hint], dim=1)

    emb = emb + aug_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
        # Kadinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`")
        image_embeds = added_cond_kwargs.get("image_embeds")
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`")
        image_embeds = added_cond_kwargs.get("image_embeds")
        encoder_hidden_states = self.encoder_hid_proj(image_embeds)

This code prepares the conditioning embeddings for the conditional UNet:

  1. Timestep embedding (timesteps & t_emb)

    • The diffusion timestep input is converted to a tensor and projected into a high-dimensional feature used as conditioning downstream.
    • Dtype and device are kept consistent with the input so the code works across hardware and precisions.
  2. Main conditioning embedding (emb)

    • self.time_embedding further encodes the timestep features together with the optional timestep_cond.
  3. Class embedding (class_embedding)

    • If the model is class-conditional, the inputs are validated first.
    • Class labels can be embedded directly or first passed through the timestep projection.
    • The class features are either concatenated with or added to the main embedding to strengthen the conditioning.
  4. Addition embedding

    • Depending on config.addition_embed_type, several kinds of extra conditioning are supported:
      • text: text conditioning
      • text_image: text + image (e.g. Kandinsky 2.1)
      • text_time: text + time ids (e.g. SDXL)
      • image: image conditioning (e.g. Kandinsky 2.2)
      • image_hint: image + hint (e.g. Kandinsky 2.2 ControlNet)
    • All of them go through self.add_embedding; for image_hint the hint is additionally concatenated onto sample.
  5. Fusing the conditions (emb = emb + aug_emb)

    • The main embedding and the additional embedding are summed.
  6. Activation (time_embed_act)

    • If an activation is configured, it is applied to the fused embedding.
  7. Encoder hidden-state projection (encoder_hidden_states)

    • Depending on config.encoder_hid_dim_type, one of three projections is applied:
      • text_proj: project the text conditioning
      • text_image_proj: jointly project text and image conditioning
      • image_proj: project the image conditioning
    • This gives the cross-attention layers conditioning features of the expected dimension. The timestep normalization at the top of this block is sketched below.
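
A small standalone sketch of how a scalar timestep is normalized to a per-sample tensor (a simplified extraction of the logic above; the dtype choice mirrors the MPS special case):

    import torch

    def normalize_timestep(timestep, sample):
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # Scalars become 1-element tensors on the sample's device (64-bit, or 32-bit on MPS).
            is_mps = sample.device.type == "mps"
            dtype = (torch.float32 if is_mps else torch.float64) if isinstance(timestep, float) else (torch.int32 if is_mps else torch.int64)
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif timesteps.ndim == 0:
            timesteps = timesteps[None].to(sample.device)
        # Broadcast to one timestep per batch element (ONNX/Core ML friendly).
        return timesteps.expand(sample.shape[0])

    sample = torch.randn(4, 4, 64, 64)
    print(normalize_timestep(999, sample))               # tensor([999, 999, 999, 999])
    print(normalize_timestep(torch.tensor(10), sample))  # tensor([10, 10, 10, 10])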

3. pre-process

    # 2. pre-process
    sample = self.conv_in(sample)

    # 2.5 GLIGEN position net
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        gligen_args = cross_attention_kwargs.pop("gligen")
        cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

This code does two things:

  1. Pre-processing of the input features:

    • sample = self.conv_in(sample) passes the input through the input convolution (changing the channel count / extracting features) to prepare it for the UNet body.
  2. GLIGEN position network (optional):

    • If cross_attention_kwargs contains a "gligen" key, the GLIGEN position network is used.
    • The code copies cross_attention_kwargs, pops the gligen arguments (gligen_args), computes position-aware object features with self.position_net(**gligen_args), and stores them back under cross_attention_kwargs["gligen"].
    • This provides spatial/positional conditioning (e.g. bounding boxes or segmentation hints) to the subsequent cross-attention layers, typically for multimodal or spatially controlled generation.

4. down

    # 3. down
    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # For t2i-adapter CrossAttnDownBlock2D
            additional_residuals = {}
            if is_adapter and len(down_block_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            if is_adapter and len(down_block_additional_residuals) > 0:
                sample += down_block_additional_residuals.pop(0)

        down_block_res_samples += res_samples

    if is_controlnet:
        new_down_block_res_samples = ()
        for down_block_res_sample, down_block_additional_residual in zip(down_block_res_samples, down_block_additional_residuals):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
        down_block_res_samples = new_down_block_res_samples

The pieces do the following:

  1. is_controlnet / is_adapter flags

    • is_controlnet and is_adapter tell whether the call comes from a ControlNet or a T2I-Adapter branch, which determines how the additional residuals are handled later.
  2. Initializing down_block_res_samples

    • down_block_res_samples = (sample,) starts the tuple that collects the features produced during downsampling.
  3. Main downsampling loop

    • The loop iterates over self.down_blocks and progressively downsamples the features.
    • If the block supports cross-attention (e.g. the CrossAttnDownBlock2D used by T2I-Adapter):
      • In adapter mode, any pending residual is passed into the block via additional_residuals.
      • The block returns the downsampled features sample and the intermediate residuals res_samples.
    • If the block has no cross-attention:
      • Only the features and the time embedding are passed in.
      • In adapter mode, a pending residual is added directly to sample.
    • After every block, res_samples is appended to down_block_res_samples so the up path can use them later.
  4. ControlNet residual fusion

    • In the ControlNet branch, each entry of down_block_additional_residuals is added to the corresponding entry of down_block_res_samples, steering the downsampled features.

Overall this is the main downsampling path of the UNet. It supports the plain, Adapter, and ControlNet modes, flexibly merging the various conditional residuals and producing the multi-scale features needed by the mid block and the up path. A toy version of the residual bookkeeping is sketched below.
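A toy sketch of that bookkeeping (plain Python with strings standing in for tensors, purely illustrative):

    down_block_res_samples = ("x0",)                            # the conv_in output is stored first
    per_block_res = [("a1", "a2", "d1"), ("b1", "b2", "d2")]    # what two down blocks might return

    for res_samples in per_block_res:
        down_block_res_samples += res_samples                   # every intermediate feature becomes a skip connection

    print(down_block_res_samples)
    # ('x0', 'a1', 'a2', 'd1', 'b1', 'b2', 'd2')
    # In ControlNet mode, a residual of matching shape is added element-wise to each of these entries.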

5. mid

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
        )

    # To support T2I-Adapter-XL
    if (
        is_adapter
        and len(down_block_additional_residuals) > 0
        and sample.shape == down_block_additional_residuals[0].shape
    ):
        sample += down_block_additional_residuals.pop(0)

    if is_controlnet:
        sample = sample + mid_block_additional_residual

The pieces do the following:

  1. Mid-block processing (mid_block)

    • if self.mid_block is not None: checks whether a mid block exists; it is the bridge between the down path and the up path of the UNet.
    • sample = self.mid_block(...) processes the current features together with the time embedding, the encoder hidden states, the attention masks, and so on.
  2. T2I-Adapter-XL support

    • if (is_adapter and len(down_block_additional_residuals) > 0 and sample.shape == down_block_additional_residuals[0].shape): is compatibility logic for the T2I-Adapter-XL structure.
    • In adapter mode, if a leftover residual exists and its shape matches, it is added to sample and removed from the list.
    • This lets extra information from the down path flow into the mid block and strengthens the conditional control.
  3. ControlNet support

    • if is_controlnet: checks for the ControlNet branch.
    • If so, mid_block_additional_residual is added to sample, further steering the mid features.
      Overall, this is the middle of the UNet: it performs the standard mid_block computation while remaining compatible with T2I-Adapter-XL and ControlNet conditioning, which makes the model more flexible and extensible.

6. up

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
            )

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)

The upsampling path works as follows:

  1. Loop over the up blocks:

    • enumerate(self.up_blocks) iterates over all upsample blocks.
    • is_final_block marks whether the current block is the last one.
  2. Handling the residual samples:

    • res_samples takes the last len(upsample_block.resnets) entries of down_block_res_samples, i.e. the skip connections belonging to the current block.
    • down_block_res_samples is then trimmed to drop the entries that have been consumed.
  3. Forwarding the upsample size:

    • If this is not the final block and the upsample size must be forwarded, upsample_size is taken from the spatial shape of the next remaining skip connection.
  4. Cross-attention handling:

    • The block is checked for cross-attention support.
    • If it has cross-attention, it is called with encoder_hidden_states, cross_attention_kwargs, attention_mask and encoder_attention_mask.
    • Otherwise only the required arguments are passed.
  5. Return value:

    • If return_dict is False, the method returns (sample,).
    • Otherwise it returns a UNet2DConditionOutput containing the processed sample. The skip-connection slicing is sketched below.
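
A toy sketch of how the skip connections collected during the down path are sliced off for each up block, assuming the default config (layers_per_block=2 gives 3 resnets per up block, and the down path stores 12 skip features):

    # conv_in output + (2 resnets + downsampler) for the first three down blocks + 2 resnets for the last -> 12 skips
    down_block_res_samples = tuple(f"feat{i}" for i in range(12))
    resnets_per_up_block = [3, 3, 3, 3]   # layers_per_block + 1 = 3 resnets per up block

    for i, n in enumerate(resnets_per_up_block):
        res_samples = down_block_res_samples[-n:]             # the deepest remaining skips go to the current up block
        down_block_res_samples = down_block_res_samples[:-n]
        print(f"up block {i}: {res_samples}")
    # up block 0: ('feat9', 'feat10', 'feat11')
    # up block 1: ('feat6', 'feat7', 'feat8')
    # up block 2: ('feat3', 'feat4', 'feat5')
    # up block 3: ('feat0', 'feat1', 'feat2')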
