Code Walkthrough: ReferenceNet
A Detailed Look at the ReferenceNet Code
Table of Contents
- A Detailed Look at the ReferenceNet Code
  - The __init__ method
    - 1. Definition
    - 2. Input & time
    - 3. Class embedding
    - 4. Down
    - 5. Mid
    - 6. Up
  - The forward method
    - 1. Setup
    - 2. Time
    - 3. Pre-process
    - 4. Down
    - 5. Mid
    - 6. Up
ReferenceNet is essentially the same architecture as the UNet 2D Condition Model. For background, see:
stable diffusion 中使用的 UNet 2D Condition Model 结构解析(diffusers库) - 知乎 (a Zhihu article analyzing the structure of the UNet 2D Condition Model used in Stable Diffusion, as implemented in the diffusers library)
The __init__ method
1. Definition
```python
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: int = 1.0,
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    attention_type: str = "default",
    class_embeddings_concat: bool = False,
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads=64,
):
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
        )

    if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )
```
This defines the constructor (`__init__`) of the ReferenceNet class, which sets up the model's parameters and configuration:

- Class attribute:
  - `_supports_gradient_checkpointing`: set to `True`, indicating that gradient checkpointing is supported.
- Constructor arguments:
  - `sample_size`, `in_channels`, `out_channels`, etc.: basic input/output settings.
  - `down_block_types`, `up_block_types`: the types of the downsampling and upsampling blocks.
  - `block_out_channels`: the number of output channels of each block.
  - `layers_per_block`: the number of layers in each block.
  - `cross_attention_dim`, `attention_head_dim`, etc.: attention-related settings.
  - Other options such as `act_fn` (activation function) and `norm_num_groups` (number of groups for group normalization).
- Argument checks:
  - `down_block_types` and `up_block_types` must have the same length.
  - `block_out_channels` must have the same length as `down_block_types`.
  - When passed as sequences, `only_cross_attention`, `num_attention_heads`, `attention_head_dim`, etc. must also match the length of `down_block_types`.
- Attention-head handling (see the sketch below):
  - If `num_attention_heads` is not set, it falls back to `attention_head_dim`.
  - If `num_attention_heads` is set explicitly, a `ValueError` is raised, because the diffusers version this code targets does not yet support that argument.

These settings and checks ensure the configuration is consistent and lay the groundwork for building the rest of the model.
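As a quick illustration of the fallback and the per-block length checks described above, here is a minimal standalone sketch (the argument names mirror the constructor, but this is not the actual class; note that the real constructor currently rejects an explicit `num_attention_heads` instead of using it):

```python
from typing import Optional, Tuple, Union

def check_per_block_args(
    down_block_types: Tuple[str, ...] = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",),
    attention_head_dim: Union[int, Tuple[int, ...]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
):
    # Fallback: most configs only set `attention_head_dim`, so it doubles as the head count.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Per-block arguments passed as sequences must have one entry per down block.
    if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
        raise ValueError("`attention_head_dim` must have one entry per down block")
    return num_attention_heads

print(check_per_block_args())                                    # 8
print(check_per_block_args(attention_head_dim=(5, 10, 20, 20)))  # (5, 10, 20, 20)
```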
2. Input & time
```python
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
    in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)

# time
if time_embedding_type == "fourier":
    time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
    if time_embed_dim % 2 != 0:
        raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
    self.time_proj = GaussianFourierProjection(
        time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
    )
    timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
    time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

    self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
    timestep_input_dim = block_out_channels[0]
else:
    raise ValueError(
        f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
    )

self.time_embedding = TimestepEmbedding(
    timestep_input_dim,
    time_embed_dim,
    act_fn=act_fn,
    post_act_fn=timestep_post_act,
    cond_proj_dim=time_cond_proj_dim,
)

if encoder_hid_dim_type is None and encoder_hid_dim is not None:
    encoder_hid_dim_type = "text_proj"
    self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
    logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

if encoder_hid_dim is None and encoder_hid_dim_type is not None:
    raise ValueError(
        f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
    )

if encoder_hid_dim_type == "text_proj":
    self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
    # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
    # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
    # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
    self.encoder_hid_proj = TextImageProjection(
        text_embed_dim=encoder_hid_dim,
        image_embed_dim=cross_attention_dim,
        cross_attention_dim=cross_attention_dim,
    )
elif encoder_hid_dim_type == "image_proj":
    # Kandinsky 2.2
    self.encoder_hid_proj = ImageProjection(
        image_embed_dim=encoder_hid_dim,
        cross_attention_dim=cross_attention_dim,
    )
elif encoder_hid_dim_type is not None:
    raise ValueError(
        f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
    )
else:
    self.encoder_hid_proj = None
```
This part initializes the input convolution, the time embedding, and the projection of the encoder hidden states:

- Input convolution:
  - `self.conv_in`: an `nn.Conv2d` that maps `in_channels` to the first block's output channels. The kernel size comes from `conv_in_kernel`, and the padding is `(conv_in_kernel - 1) // 2`, which keeps the spatial resolution unchanged.
- Time embedding (see the sketch after this list):
  - The embedding scheme is chosen by `time_embedding_type`:
    - `fourier`: a Gaussian Fourier projection (`GaussianFourierProjection`); `time_embed_dim` must be even.
    - `positional`: sinusoidal position embeddings (`Timesteps`).
  - `self.time_embedding`: a `TimestepEmbedding` module that maps the projected timestep into the embedding used to condition the model.
- Encoder hidden-state projection:
  - The projection is chosen by `encoder_hid_dim_type`:
    - `text_proj`: a linear layer (`nn.Linear`) that projects text embeddings to the cross-attention dimension.
    - `text_image_proj`: a `TextImageProjection` that jointly projects text and image embeddings.
    - `image_proj`: an `ImageProjection` for image embeddings.
  - If `encoder_hid_dim_type` is unset but `encoder_hid_dim` is given, it defaults to `text_proj`.
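For intuition, here is a minimal sketch of what the `positional` path computes, a simplified stand-in for diffusers' `Timesteps` (the real module also handles `flip_sin_to_cos` and `freq_shift`):

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int = 320) -> torch.Tensor:
    """Map integer diffusion steps of shape (batch,) to (batch, dim) sin/cos features."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

t_emb = sinusoidal_timestep_embedding(torch.tensor([0, 250, 999]))
print(t_emb.shape)  # torch.Size([3, 320]); this is what gets fed into TimestepEmbedding
```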
3. Class embedding
```python
# class embedding
if class_embed_type is None and num_class_embeds is not None:
    self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
    self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity":
    self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
    if projection_class_embeddings_input_dim is None:
        raise ValueError(
            "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
        )
    # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
    # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
    # 2. it projects from an arbitrary input dimension.
    #
    # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
    # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
    # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
    self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
    if projection_class_embeddings_input_dim is None:
        raise ValueError(
            "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
        )
    self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
    self.class_embedding = None

if addition_embed_type == "text":
    if encoder_hid_dim is not None:
        text_time_embedding_from_dim = encoder_hid_dim
    else:
        text_time_embedding_from_dim = cross_attention_dim

    self.add_embedding = TextTimeEmbedding(
        text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
    )
elif addition_embed_type == "text_image":
    # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
    # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
    # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
    self.add_embedding = TextImageTimeEmbedding(
        text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
    )
elif addition_embed_type == "text_time":
    self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
    self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type == "image":
    # Kandinsky 2.2
    self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type == "image_hint":
    # Kandinsky 2.2 ControlNet
    self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type is not None:
    raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

if time_embedding_act_fn is None:
    self.time_embed_act = None
else:
    self.time_embed_act = get_activation(time_embedding_act_fn)

self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

if isinstance(only_cross_attention, bool):
    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = only_cross_attention

    only_cross_attention = [only_cross_attention] * len(down_block_types)

if mid_block_only_cross_attention is None:
    mid_block_only_cross_attention = False

if isinstance(num_attention_heads, int):
    num_attention_heads = (num_attention_heads,) * len(down_block_types)

if isinstance(attention_head_dim, int):
    attention_head_dim = (attention_head_dim,) * len(down_block_types)

if isinstance(cross_attention_dim, int):
    cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

if isinstance(layers_per_block, int):
    layers_per_block = [layers_per_block] * len(down_block_types)

if isinstance(transformer_layers_per_block, int):
    transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

if class_embeddings_concat:
    # The time embeddings are concatenated with the class embeddings. The dimension of the
    # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
    # regular time embeddings
    blocks_time_embed_dim = time_embed_dim * 2
else:
    blocks_time_embed_dim = time_embed_dim
```
This part initializes the class embedding and the additional embeddings, then normalizes several per-block arguments:

- Class embedding:
  - The module depends on `class_embed_type`.
  - If `class_embed_type` is `None` and `num_class_embeds` is not `None`, an `nn.Embedding` is used.
  - `"timestep"`: uses `TimestepEmbedding`.
  - `"identity"`: uses `nn.Identity`.
  - `"projection"` / `"simple_projection"`: use `TimestepEmbedding` or `nn.Linear` respectively, after checking that `projection_class_embeddings_input_dim` is set.
- Additional embedding:
  - The module depends on `addition_embed_type`.
  - `"text"`: uses `TextTimeEmbedding`, with the input dimension taken from `encoder_hid_dim` or `cross_attention_dim`.
  - `"text_image"`: uses `TextImageTimeEmbedding`.
  - `"text_time"`: initializes `add_time_proj` and `add_embedding` (the SDXL-style path).
  - `"image"` / `"image_hint"`: use `ImageTimeEmbedding` or `ImageHintTimeEmbedding` respectively.
  - Any other non-`None` value raises a `ValueError`.
- Time-embedding activation:
  - `self.time_embed_act` is created from `time_embedding_act_fn`, if one is given.
- Module lists:
  - `down_blocks` and `up_blocks` are initialized as empty `nn.ModuleList`s.
- Argument normalization (see the sketch below):
  - Scalar values of `only_cross_attention`, `num_attention_heads`, `attention_head_dim`, `cross_attention_dim`, `layers_per_block`, and `transformer_layers_per_block` are broadcast into per-block sequences.
  - If `class_embeddings_concat` is `True`, `blocks_time_embed_dim` is twice `time_embed_dim` (the time and class embeddings are concatenated); otherwise it stays at `time_embed_dim`.
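A tiny sketch of the broadcasting step, which is why every down/up block can later index these arguments with `[i]` (standalone, hypothetical values):

```python
down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
num_attention_heads = 8      # scalar in the config
cross_attention_dim = 1280   # scalar in the config

# Broadcast scalars into one entry per down block, exactly as __init__ does.
if isinstance(num_attention_heads, int):
    num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
    cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

print(num_attention_heads)  # (8, 8, 8, 8)
print(cross_attention_dim)  # (1280, 1280, 1280, 1280)
```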
4. Down
```python
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
    input_channel = output_channel
    output_channel = block_out_channels[i]
    is_final_block = i == len(block_out_channels) - 1

    down_block = get_down_block(
        down_block_type,
        num_layers=layers_per_block[i],
        transformer_layers_per_block=transformer_layers_per_block[i],
        in_channels=input_channel,
        out_channels=output_channel,
        temb_channels=blocks_time_embed_dim,
        add_downsample=not is_final_block,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        resnet_groups=norm_num_groups,
        cross_attention_dim=cross_attention_dim[i],
        num_attention_heads=num_attention_heads[i],
        downsample_padding=downsample_padding,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention[i],
        upcast_attention=upcast_attention,
        resnet_time_scale_shift=resnet_time_scale_shift,
        attention_type=attention_type,
        resnet_skip_time_act=resnet_skip_time_act,
        resnet_out_scale_factor=resnet_out_scale_factor,
        cross_attention_norm=cross_attention_norm,
        attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
    )
    self.down_blocks.append(down_block)
```
This part builds the downsampling blocks (see the sketch after this list for how the channels chain together):

- Initialize the output channel count:
  - `output_channel` starts at the first block's channel count, `block_out_channels[0]`.
- Iterate over the down-block types:
  - `enumerate(down_block_types)` walks over all down-block types.
  - For each block, the input channels are the previous block's output channels, and the output channels come from `block_out_channels[i]`.
  - `is_final_block` marks the last block, which is the only one that does not add a downsampling layer.
- Build each block:
  - `get_down_block` constructs the block from its type and parameters: number of layers, input/output channels, time-embedding channels, attention settings, and so on.
- Append to the list:
  - Each constructed block is appended to `self.down_blocks`.
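For the default `block_out_channels = (320, 640, 1280, 1280)`, the channel pairs and downsampling flags work out as follows (standalone sketch):

```python
block_out_channels = (320, 640, 1280, 1280)

output_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
    input_channel = output_channel
    output_channel = block_out_channels[i]
    is_final_block = i == len(block_out_channels) - 1
    print(f"block {i}: {input_channel:>4} -> {output_channel:<4} add_downsample={not is_final_block}")

# block 0:  320 -> 320  add_downsample=True
# block 1:  320 -> 640  add_downsample=True
# block 2:  640 -> 1280 add_downsample=True
# block 3: 1280 -> 1280 add_downsample=False
```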
5. Mid
```python
# mid
if mid_block_type == "UNetMidBlock2DCrossAttn":
    self.mid_block = UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block[-1],
        in_channels=block_out_channels[-1],
        temb_channels=blocks_time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        resnet_time_scale_shift=resnet_time_scale_shift,
        cross_attention_dim=cross_attention_dim[-1],
        num_attention_heads=num_attention_heads[-1],
        resnet_groups=norm_num_groups,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
        attention_type=attention_type,
    )
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
    self.mid_block = UNetMidBlock2DSimpleCrossAttn(
        in_channels=block_out_channels[-1],
        temb_channels=blocks_time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        cross_attention_dim=cross_attention_dim[-1],
        attention_head_dim=attention_head_dim[-1],
        resnet_groups=norm_num_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
        skip_time_act=resnet_skip_time_act,
        only_cross_attention=mid_block_only_cross_attention,
        cross_attention_norm=cross_attention_norm,
    )
elif mid_block_type is None:
    self.mid_block = None
else:
    raise ValueError(f"unknown mid_block_type : {mid_block_type}")
```
This part initializes the UNet's middle block (mid block), choosing the implementation according to `mid_block_type`:

- UNetMidBlock2DCrossAttn:
  - If `mid_block_type` is `"UNetMidBlock2DCrossAttn"`, a `UNetMidBlock2DCrossAttn` is created.
  - Its arguments include `transformer_layers_per_block`, `in_channels`, `temb_channels`, `resnet_eps`, `resnet_act_fn`, `output_scale_factor`, `resnet_time_scale_shift`, `cross_attention_dim`, `num_attention_heads`, `resnet_groups`, `dual_cross_attention`, `use_linear_projection`, `upcast_attention`, and `attention_type`.
  - These configure the block's channels, time embedding, attention mechanism, and so on.
- UNetMidBlock2DSimpleCrossAttn:
  - If `mid_block_type` is `"UNetMidBlock2DSimpleCrossAttn"`, a `UNetMidBlock2DSimpleCrossAttn` is created.
  - Its arguments include `in_channels`, `temb_channels`, `resnet_eps`, `resnet_act_fn`, `output_scale_factor`, `cross_attention_dim`, `attention_head_dim`, `resnet_groups`, `resnet_time_scale_shift`, `skip_time_act`, `only_cross_attention`, and `cross_attention_norm`.
  - These configure the mid block with the simpler cross-attention variant.
- None:
  - If `mid_block_type` is `None`, no mid block is created and `self.mid_block` is set to `None`.
- Error handling:
  - Any other value raises a `ValueError` ("unknown mid_block_type").
6. Up
```python
# count how many layers upsample the images
self.num_upsamplers = 0

# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
only_cross_attention = list(reversed(only_cross_attention))

output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
    is_final_block = i == len(block_out_channels) - 1

    prev_output_channel = output_channel
    output_channel = reversed_block_out_channels[i]
    input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

    # add upsample block for all BUT final layer
    if not is_final_block:
        add_upsample = True
        self.num_upsamplers += 1
    else:
        add_upsample = False

    up_block = get_up_block(
        up_block_type,
        num_layers=reversed_layers_per_block[i] + 1,
        transformer_layers_per_block=reversed_transformer_layers_per_block[i],
        in_channels=input_channel,
        out_channels=output_channel,
        prev_output_channel=prev_output_channel,
        temb_channels=blocks_time_embed_dim,
        add_upsample=add_upsample,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        resnet_groups=norm_num_groups,
        cross_attention_dim=reversed_cross_attention_dim[i],
        num_attention_heads=reversed_num_attention_heads[i],
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention[i],
        upcast_attention=upcast_attention,
        resnet_time_scale_shift=resnet_time_scale_shift,
        attention_type=attention_type,
        resnet_skip_time_act=resnet_skip_time_act,
        resnet_out_scale_factor=resnet_out_scale_factor,
        cross_attention_norm=cross_attention_norm,
        attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
    )
    self.up_blocks.append(up_block)
    prev_output_channel = output_channel

self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
self.up_blocks[3].attentions[2].proj_out = Identity()

if attention_type in ["gated", "gated-text-image"]:
    positive_len = 768
    if isinstance(cross_attention_dim, int):
        positive_len = cross_attention_dim
    elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
        positive_len = cross_attention_dim[0]

    feature_type = "text-only" if attention_type == "gated" else "text-image"
    self.position_net = PositionNet(
        positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
    )
```
This part builds the upsampling (up) path and then applies the ReferenceNet-specific modifications:

- Reverse the per-block parameters:
  - `block_out_channels`, `num_attention_heads`, `layers_per_block`, `cross_attention_dim`, and `transformer_layers_per_block` are reversed so the up path is built from the deepest resolution back to the highest.
- Build the up blocks:
  - For each entry in `up_block_types`, `get_up_block` constructs an up block.
  - Each call sets the input channels, output channels, the previous block's output channels, the time-embedding channels, whether to add an upsampler, the ResNet settings, the cross-attention settings, and so on.
  - `is_final_block` marks the last block; every block except the final one adds an upsampling layer and increments `self.num_upsamplers`.
- ReferenceNet-specific surgery on the last block (see the sketch below):
  - The last attention module of the last up block is stripped down: the self-attention projections are replaced with fresh `_LoRACompatibleLinear()` placeholders, and the cross-attention (`attn2`), its norms, the feed-forward layer, and the output projection are replaced with `Identity()` or set to `None`.
  - This skips the final layer's cross-attention to save computation and avoids unused parameters, which also helps DDP training.
- Position net:
  - If `attention_type` is `"gated"` or `"gated-text-image"`, a `PositionNet` is created for GLIGEN-style positional embeddings.
  - `positive_len` is derived from `cross_attention_dim`, and `feature_type` from `attention_type`.
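The same kind of "surgery" can be shown on a toy module; this is only an illustration (not the actual diffusers classes) of why swapping submodules for `nn.Identity` disables them while leaving the forward call sites untouched:

```python
import torch
from torch import nn

class ToyTransformerBlock(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.norm = nn.LayerNorm(dim)   # stands in for norm2 / norm3
        self.ff = nn.Linear(dim, dim)   # stands in for attn2 / ff / proj_out

    def forward(self, x):
        return x + self.ff(self.norm(x))

block = ToyTransformerBlock()
x = torch.randn(2, 8, 32)

# Disable the sub-layers in place, the way ReferenceNet strips its last up block:
block.norm = nn.Identity()
block.ff = nn.Identity()
print(torch.allclose(block(x), x + x))  # True: the residual branch is now a pass-through
```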
The forward method
1. Setup
```python
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    The [`UNet2DConditionModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.FloatTensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
            `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
            which adds large negative values to the attention scores corresponding to "discard" tokens.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
        added_cond_kwargs: (`dict`, *optional*):
            A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
            are passed along to the UNet blocks.

    Returns:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
            a `tuple` is returned where the first element is the sample tensor.
    """
    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0
```
This code handles input/output size constraints, the attention masks, and input normalization:

- Upsampling-factor check:
  - `default_overall_up_factor = 2**self.num_upsamplers` is the overall upsampling factor, 2 raised to the number of upsampling layers (with the default four up blocks, three of them upsample, so the factor is 2³ = 8). The input height and width should be multiples of this factor; otherwise the upsampled features will not line up with the skip connections.
  - If the input size is not such a multiple, `forward_upsample_size` is set to `True` and the interpolation output size is forced explicitly later in the up path.
- Attention-mask handling (see the sketch below):
  - `attention_mask` and `encoder_attention_mask` are converted into additive biases for the attention layers: positions with mask value 1 become 0, and positions with mask value 0 become -10000, so masked-out positions receive essentially zero attention weight.
  - `unsqueeze(1)` adds a broadcast dimension so the bias matches the shape expected by the attention layers.
- Input centering:
  - If `center_input_sample` is configured, `sample` is mapped linearly from [0, 1] to [-1, 1] via `2 * sample - 1.0`, which helps numerical stability and convergence.

Overall, this makes the input compatible with the UNet's down/upsampling and attention machinery and normalizes the data before processing.
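A small standalone sketch of the additive-bias conversion, just to make the numbers concrete:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])   # 1 = keep, 0 = discard
bias = (1 - attention_mask.float()) * -10000.0
bias = bias.unsqueeze(1)                            # (batch, 1, seq_len), broadcastable over queries

print(bias.shape)  # torch.Size([1, 1, 5])
print(bias)        # kept positions -> 0, masked positions -> -10000 (≈ zero weight after softmax)
```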
2. Time
```python
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
    # This would be a good case for the `match` statement (Python 3.10+)
    is_mps = sample.device.type == "mps"
    if isinstance(timestep, float):
        dtype = torch.float32 if is_mps else torch.float64
    else:
        dtype = torch.int32 if is_mps else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)

# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

if self.class_embedding is not None:
    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    if self.config.class_embed_type == "timestep":
        class_labels = self.time_proj(class_labels)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # there might be better ways to encapsulate this.
        class_labels = class_labels.to(dtype=sample.dtype)

    class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

    if self.config.class_embeddings_concat:
        emb = torch.cat([emb, class_emb], dim=-1)
    else:
        emb = emb + class_emb

if self.config.addition_embed_type == "text":
    aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image":
    # Kandinsky 2.1 - style
    if "image_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
        )

    image_embs = added_cond_kwargs.get("image_embeds")
    text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
    aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
    # SDXL - style
    if "text_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
        )
    text_embeds = added_cond_kwargs.get("text_embeds")
    if "time_ids" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
        )
    time_ids = added_cond_kwargs.get("time_ids")
    time_embeds = self.add_time_proj(time_ids.flatten())
    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

    add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(emb.dtype)
    aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
    # Kandinsky 2.2 - style
    if "image_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
        )
    image_embs = added_cond_kwargs.get("image_embeds")
    aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
    # Kandinsky 2.2 - style
    if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
        )
    image_embs = added_cond_kwargs.get("image_embeds")
    hint = added_cond_kwargs.get("hint")
    aug_emb, hint = self.add_embedding(image_embs, hint)
    sample = torch.cat([sample, hint], dim=1)

emb = emb + aug_emb if aug_emb is not None else emb

if self.time_embed_act is not None:
    emb = self.time_embed_act(emb)

if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
    encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
    # Kadinsky 2.1 - style
    if "image_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
        )

    image_embeds = added_cond_kwargs.get("image_embeds")
    encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
    # Kandinsky 2.2 - style
    if "image_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
        )
    image_embeds = added_cond_kwargs.get("image_embeds")
    encoder_hidden_states = self.encoder_hid_proj(image_embeds)
```
This code prepares the conditioning embeddings for the conditional UNet:

- Timestep embedding (`timesteps`, `t_emb`), shown shape by shape in the sketch after this list:
  - The diffusion timestep is converted to a tensor, broadcast over the batch dimension, and projected to a high-dimensional feature via `self.time_proj`.
  - Dtype and device are kept consistent with `sample` (the projection always returns fp32, so it is cast back), which keeps the code working across hardware and precisions.
- Main conditioning embedding (`emb`):
  - `self.time_embedding` further encodes the timestep features together with the optional `timestep_cond`.
- Class embedding (`class_embedding`):
  - If the model is class-conditional, `class_labels` must be provided.
  - Labels are either embedded directly or first passed through the timestep projection (when `class_embed_type == "timestep"`).
  - The class embedding is concatenated with or added to the main embedding, depending on `class_embeddings_concat`.
- Additional conditioning (`addition_embed_type`):
  - Several fusion modes are supported, depending on the config:
    - `text`: text conditioning
    - `text_image`: text + image (Kandinsky 2.1)
    - `text_time`: text + time ids (SDXL)
    - `image`: image conditioning (Kandinsky 2.2)
    - `image_hint`: image + hint (Kandinsky 2.2 ControlNet)
  - All of them go through `self.add_embedding`; in the `image_hint` case the hint is also concatenated onto `sample`.
- Fusion (`emb = emb + aug_emb`):
  - The main embedding and the additional conditioning embedding are summed.
- Activation (`time_embed_act`):
  - If configured, a nonlinearity is applied to the fused embedding.
- Encoder hidden-state projection (`encoder_hidden_states`):
  - Depending on `config.encoder_hid_dim_type`, one of three projections is applied:
    - `text_proj`: projects text features
    - `text_image_proj`: jointly projects text and image features
    - `image_proj`: projects image features
  - This yields conditioning features of the right dimension for the cross-attention layers.
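A shape-level sketch of the timestep path under an SD-style config (`block_out_channels[0] = 320`, so `time_embed_dim = 4 * 320 = 1280`), assuming diffusers' `Timesteps` and `TimestepEmbedding` helpers are importable as below:

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

timesteps = torch.tensor([10])    # a single timestep, as passed by the pipeline
timesteps = timesteps.expand(4)   # broadcast to batch size 4, as forward() does

t_emb = time_proj(timesteps)      # (4, 320) sinusoidal features, always fp32
emb = time_embedding(t_emb)       # (4, 1280) conditioning embedding used throughout the UNet
print(t_emb.shape, emb.shape)
```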
3. Pre-process
```python
# 2. pre-process
sample = self.conv_in(sample)

# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
    cross_attention_kwargs = cross_attention_kwargs.copy()
    gligen_args = cross_attention_kwargs.pop("gligen")
    cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
```
This code does two things:

- Pre-process the input features:
  - `sample = self.conv_in(sample)` pushes the input through the input convolution (channel expansion / feature extraction) to prepare it for the UNet body; a shape example follows this list.
- GLIGEN position net (optional):
  - If `cross_attention_kwargs` contains a `"gligen"` entry, the GLIGEN position network is used.
  - The code copies `cross_attention_kwargs`, pops the `gligen` arguments, runs them through `self.position_net(**gligen_args)`, and stores the resulting object features back under `cross_attention_kwargs["gligen"]`.
  - This supplies spatial/positional conditioning (e.g., bounding boxes or segmentation hints) to the downstream cross-attention layers, which is used for multi-modal or spatially controlled generation.
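With the default settings, the input convolution only changes the channel count, not the spatial size (minimal standalone sketch; the 64×64 latent corresponds to a 512×512 image under the usual 8× VAE downsampling):

```python
import torch
from torch import nn

# in_channels=4 and block_out_channels[0]=320, with padding (3 - 1) // 2 = 1
conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)

latents = torch.randn(1, 4, 64, 64)
print(conv_in(latents).shape)  # torch.Size([1, 320, 64, 64]) -- spatial size preserved
```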
4. Down
```python
# 3. down
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
        # For t2i-adapter CrossAttnDownBlock2D
        additional_residuals = {}
        if is_adapter and len(down_block_additional_residuals) > 0:
            additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

        sample, res_samples = downsample_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
            **additional_residuals,
        )
    else:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        if is_adapter and len(down_block_additional_residuals) > 0:
            sample += down_block_additional_residuals.pop(0)

    down_block_res_samples += res_samples

if is_controlnet:
    new_down_block_res_samples = ()

    for down_block_res_sample, down_block_additional_residual in zip(
        down_block_res_samples, down_block_additional_residuals
    ):
        down_block_res_sample = down_block_res_sample + down_block_additional_residual
        new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

    down_block_res_samples = new_down_block_res_samples
```
The pieces work as follows:

- is_controlnet / is_adapter:
  - These flags tell whether the extra residuals come from a ControlNet or a T2I-Adapter branch, which determines how they are applied (a tiny sketch follows this list).
- down_block_res_samples initialization:
  - `down_block_res_samples = (sample,)` starts the tuple that collects the down-path features (the skip connections).
- Main downsampling loop:
  - Iterates over `self.down_blocks`, downsampling the features step by step.
  - If the block supports cross-attention (e.g., a `CrossAttnDownBlock2D`):
    - In adapter mode, if residuals remain, the next one is passed into the block via `additional_residuals`.
    - The block returns the downsampled features `sample` and the residual features `res_samples`.
  - If the block has no cross-attention:
    - Only the features and the time embedding are passed in.
    - In adapter mode, if residuals remain, the next one is added directly to `sample`.
  - Each iteration appends `res_samples` to `down_block_res_samples` for later use in the up path.
- ControlNet residual fusion:
  - In ControlNet mode, each entry of `down_block_additional_residuals` is added to the corresponding entry of `down_block_res_samples`, modulating the skip connections.

Overall, this is the UNet's downsampling path: it supports the plain, T2I-Adapter, and ControlNet modes, and it collects the multi-scale features that the mid block and the up path consume.
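The mode detection is just a pair of `None` checks; a tiny sketch with placeholder values:

```python
def detect_mode(mid_block_additional_residual, down_block_additional_residuals):
    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
    return is_controlnet, is_adapter

print(detect_mode(None, None))           # (False, False) -> plain UNet forward
print(detect_mode(None, ["r0", "r1"]))   # (False, True)  -> T2I-Adapter: residuals on the down path only
print(detect_mode("mid", ["r0", "r1"]))  # (True, False)  -> ControlNet: down residuals plus a mid residual
```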
5. Mid
```python
# 4. mid
if self.mid_block is not None:
    sample = self.mid_block(
        sample,
        emb,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        cross_attention_kwargs=cross_attention_kwargs,
        encoder_attention_mask=encoder_attention_mask,
    )
    # To support T2I-Adapter-XL
    if (
        is_adapter
        and len(down_block_additional_residuals) > 0
        and sample.shape == down_block_additional_residuals[0].shape
    ):
        sample += down_block_additional_residuals.pop(0)

if is_controlnet:
    sample = sample + mid_block_additional_residual
```
The pieces work as follows:

- Mid-block processing:
  - `if self.mid_block is not None:` checks whether a mid block exists; it is the bridge between the down and up paths of the UNet.
  - `sample = self.mid_block(...)` processes the current features together with the conditioning embedding, encoder hidden states, and attention masks.
- T2I-Adapter-XL support:
  - The check `is_adapter and len(down_block_additional_residuals) > 0 and sample.shape == down_block_additional_residuals[0].shape` is compatibility logic for T2I-Adapter-XL.
  - If the model is in adapter mode, residuals remain, and the shapes match, the first remaining residual is added to `sample` and removed from the list.
  - This lets adapter information reach the mid block as well, strengthening the conditioning.
- ControlNet support:
  - In ControlNet mode, `mid_block_additional_residual` is added to `sample`, modulating the mid-block features.

Overall, this is the middle of the UNet: it supports the standard mid block and remains compatible with T2I-Adapter-XL and ControlNet style conditioning, which keeps the model flexible and extensible.
6. Up
```python
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
    is_final_block = i == len(self.up_blocks) - 1

    res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

    # if we have not reached the final block and need to forward the
    # upsample size, we do it here
    if not is_final_block and forward_upsample_size:
        upsample_size = down_block_res_samples[-1].shape[2:]

    if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
        sample = upsample_block(
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=res_samples,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            upsample_size=upsample_size,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
        )
    else:
        sample = upsample_block(
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=res_samples,
            upsample_size=upsample_size,
        )

if not return_dict:
    return (sample,)

return UNet2DConditionOutput(sample=sample)
```
The up path works as follows (the sketch after this list shows the skip-connection bookkeeping):

- Loop over the up blocks:
  - `enumerate(self.up_blocks)` iterates over all upsample blocks.
  - `is_final_block` marks the last up block.
- Consume the skip connections:
  - `res_samples` takes the last `len(upsample_block.resnets)` entries of `down_block_res_samples`, i.e., the skip features that belong to this block.
  - `down_block_res_samples` is then truncated to drop the consumed entries.
- Forward the upsample size:
  - If this is not the final block and `forward_upsample_size` is set, `upsample_size` is read from the next skip connection's spatial shape so the interpolation output matches it.
- Cross-attention handling:
  - If the up block has cross-attention, it is called with `encoder_hidden_states`, `cross_attention_kwargs`, `attention_mask`, and `encoder_attention_mask`.
  - Otherwise, only the features, the time embedding, the skip connections, and the upsample size are passed.
- Return value:
  - If `return_dict` is `False`, the tuple `(sample,)` is returned.
  - Otherwise, a `UNet2DConditionOutput` wrapping the processed `sample` is returned.
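A toy sketch of the skip-connection bookkeeping, using strings instead of tensors. The counts below assume the default config (`layers_per_block = 2`, four down/up blocks), where by my reading each downsampling block contributes one feature per resnet plus one per downsampler, giving 1 + 3 + 3 + 3 + 2 = 12 collected features, and each up block has `layers_per_block + 1 = 3` resnets:

```python
down_block_res_samples = tuple(f"d{i}" for i in range(12))  # skip features collected on the way down
resnets_per_up_block = [3, 3, 3, 3]                          # len(upsample_block.resnets) for each up block

for i, n_resnets in enumerate(resnets_per_up_block):
    res_samples = down_block_res_samples[-n_resnets:]
    down_block_res_samples = down_block_res_samples[:-n_resnets]
    print(f"up block {i} consumes {res_samples}")

# up block 0 consumes ('d9', 'd10', 'd11')   (deepest features first)
# up block 1 consumes ('d6', 'd7', 'd8')
# up block 2 consumes ('d3', 'd4', 'd5')
# up block 3 consumes ('d0', 'd1', 'd2')
```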