当前位置: 首页 > news >正文

【深度学习的灵魂】图片布局生成模型LayoutPrompt(2)·布局序列化模块

🌈 个人主页:十二月的猫-CSDN博客
🔥 系列专栏: 🏀《深度学习理论直觉三十讲》_十二月的猫的博客-CSDN博客

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光

目录

1. 前言

2. LayoutPrompt介绍

3. 序列化复习 

4. LayoutPrompt · 动态选择模块

4.1 动态选择模块的父类

4.2 基于元素类型相似度的动态选择

4.3 基于类型+尺寸相似度的的动态选择

5. LayoutPrompt ·排名模块

6. 总结


1. 前言

        猫猫不知道大家有没有思考过图片布局生成模型,这算是生成模型的一个非常小的子任务了。前面带大家学习过生成模型,包括GAN、Diffusion等。这些都算是生成模型的研究子领域,利用这些子领域的知识,我们可以来研究具体的任务,例如生成图片、按照语言提示生成图片、按照布局提示生成图片等。

        同样生成布局也是生成模型中的一个具体任务,其实思路也是非常简单。生成图片这个任务中更具体的任务是生成海报、生成照片、生成动漫图片等。因为直接生成一个海报难度太大,我们就先去生成布局,然后在具体布局的约束下去具体生成完整图片。可以简单给大家看一下布局是什么:

        猫猫研究这个呢,主要还是因为创新实训和软件创新大赛两个需要。同时由于猫猫也研究过一点生成模型,因此这个学期就研究一下这个领域啦~~如果大家也对这个领域感兴趣,可以关注我们团队的专栏(布局生成模型是我们海报生成系统中的一个子模块,希望更多猫友参与到我们的开发当中哦):大模型实战训练营_十二月的猫的博客-CSDN博客

2. LayoutPrompt介绍

总的来说,为了完成LayoutPrompt模型,我们需要完成的子模块有:

  • 数据预处理模块:该模块主要就是利用基础数据预处理方法对数据集中的所有数据样本进行预处理。
  • 动态样本选择模块:从训练集中检索最相关的样本,然后作为最直接的上下文(约束信息)送给大语言模型。
  • 布局序列化模块:用于将上面所选的样本布局转化为序列表述(因为大语言模型对序列化输入有更好的效果)。序列化数据就是类似 自然语言、代码等
  • 大语言模型模块:将序列化处理后的所有样本一起送给大语言模型,让大语言模型参考的情况下给出自己的答案。
  • 大语言模型解析模块:用于将大语言模型给出的布局结果解析为标准化的输出。
  • 布局排序模块(布局评价模块):评价大语言模型生成的布局的质量分数,并做一个排序。 

具体的模型图如下:

模型运行具体流程如下:

  1. 用户输入前导信息,例如画布大小,任务类型等,数据预处理模块预处理数据库(用户根据自己的任务可以选择数据库,如海报布局数据库、手机UI设计数据库)中的所有数据。
  2. 动态样本选择模块得到处理后的数据库数据,然后根据用户输入的前导信息选择合适的example样本。
  3. 将样本+前导信息+测试样本送给大语言模型。
  4. 由大语言模型生成最终的layout,送给Rank模块。
  5. Rank模块排序后分数最高的就是最终输出。

Layout生成模式选择(任务类型):

  1. 元素类型限制
  2. 元素类型+尺寸限制
  3. 元素相对位置限制
  4. 布局补全
  5. 布局修正+补全
  6. 内容避让
  7. 文本描述转布局结构

上面这些都是我们layout生成可以选择的模式,这些模式适应不同的任务背景,可以供大家选择。

3. 序列化复习 

        简单复习一下序列化处理的部分,序列化处理的本质:一、固定输入形式;二、为不同任务设定不同的Prompt。更加局限狭义的说,序列化模块最关键的代码如下(构建Prompt):

prompt分为:前导任务信息Prompt、提供example信息Prompt、

def build_prompt(serializer,exemplars,test_data,dataset,max_length=8000,separator_in_samples="\n",separator_between_samples="\n\n",
):# 前置任务信息prompt = [PREAMBLE.format(serializer.task_type, LAYOUT_DOMAIN[dataset], *CANVAS_SIZE[dataset])]# 具体prompt限制信息:layout类型、尺寸等(论文例子是用seq限制,更好理解,但是也有seq限制的)# 输入输入限制的是prompt用seq形式还是用html形式。输入限制的是constraint_type后面的东西,输出限制的是example的prompt输入形式。这两个都是完整Prompt的一部分for i in range(len(exemplars)):_prompt = (serializer.build_input(exemplars[i])+ separator_in_samples+ serializer.build_output(exemplars[i]))# 前导信息+限制信息if len(separator_between_samples.join(prompt) + _prompt) <= max_length:prompt.append(_prompt)else:breakprompt.append(serializer.build_input(test_data) + separator_in_samples)return separator_between_samples.join(prompt)
  • 前置任务信息Prompt部分:任务类型+layout应用区域+画布尺寸。如下
    PREAMBLE = ("Please generate a layout based on the given information. ""You need to ensure that the generated layout looks realistic, with elements well aligned and avoiding unnecessary overlap.\n""Task Description: {}\n""Layout Domain: {} layout\n""Canvas Size: canvas width is {}px, canvas height is {}px"
    )
  • 任务类型例如:
1. generation conditioned on given element types
2. generation conditioned on given element types and sizes
3. "generation conditioned on given element relationships\n""'A left B' means that the center coordinate of A is to the left of the center coordinate of B. ""'A right B' means that the center coordinate of A is to the right of the center coordinate of B. ""'A top B' means that the center coordinate of A is above the center coordinate of B. ""'A bottom B' means that the center coordinate of A is below the center coordinate of B. ""'A center B' means that the center coordinate of A and the center coordinate of B are very close. ""'A smaller B' means that the area of A is smaller than the ares of B. ""'A larger B' means that the area of A is larger than the ares of B. ""'A equal B' means that the area of A and the ares of B are very close. ""Here, center coordinate = (left + width / 2, top + height / 2), ""area = width * height"
  • layout应用区域如下:
LAYOUT_DOMAIN = {"rico": "android","publaynet": "thesis poster",   # 适应任务做的修改"posterlayout": "poster","webui": "web",
}
  • 画布尺寸:
CANVAS_SIZE = {"rico": (90, 160),"publaynet": (120, 160),"posterlayout": (102, 150),"webui": (120, 120),
}

4. LayoutPrompt · 动态选择模块

        动态选择模块:从layout数据库中选择最符合要求的一系列layout。

        既然是选择,那么就一定存在选择标准,不同任务的选择标准也肯定是不同的。比如对于Poster任务最重要的是layout中元素类型要相同;对于产品海报最重要的是要把中间放图片的部分给留出来用来展示产品照片,因此需要做到内容避让。基于这样一个前提,我们写了很多任务背景下的layout动态选择。当然这里的任务背景和前面的序列化生成的任务背景是相互对应的,用户选择任务类型后,动态选择模块和序列化模块都会选择相对应的任务类型(Layout生成模式)

        下面我们举代码例子,具体完整代码后续更新在Gitee以及CSDN。

4.1 动态选择模块的父类

class ExemplarSelection:def __init__(self,train_data: list,candidate_size: int,num_prompt: int,shuffle: bool = True,):self.train_data = train_data           # 原始训练数据集self.candidate_size = candidate_size   # 候选池最大容量self.num_prompt = num_prompt           # 最终选取的示例数量self.shuffle = shuffle                 # 是否随机打乱# 构建候选池:若指定大小 > 0,则随机采样if self.candidate_size > 0:random.shuffle(self.train_data)self.train_data = self.train_data[: self.candidate_size] # 截取前N个作为候选池def __call__(self, test_data: dict):pass# 过滤无效样本def _is_filter(self, data):# 检查是否存在 width/height 为 0 的无效元素return (data["discrete_gold_bboxes"][:, 2:] == 0).sum().bool().item()# 按照分数拿到对应样本(分数计算由下面代码实现)def _retrieve_exemplars(self, scores: list):scores = sorted(scores, key=lambda x: x[1], reverse=True)exemplars = []for i in range(len(self.train_data)):if not self._is_filter(self.train_data[scores[i][0]]):exemplars.append(self.train_data[scores[i][0]])# 达到目标数量时提前终止if len(exemplars) == self.num_prompt:breakif self.shuffle:random.shuffle(exemplars)return exemplars
  • 原始的examle就是从Train_data中拿。
  • 候选池就是用来放example的池子,有一个上限。
  • 父类中的核心就是一个根据分数list从train_data拿example的函数。

4.2 基于元素类型相似度的动态选择

class GenTypeExemplarSelection(ExemplarSelection):def __call__(self, test_data: dict):scores = []test_labels = test_data["labels"]for i in range(len(self.train_data)):train_labels = self.train_data[i]["labels"]score = labels_similarity(train_labels, test_labels)  # 核心算法scores.append([i, score])return self._retrieve_exemplars(scores)def labels_similarity(labels_1, labels_2):def _intersection(labels_1, labels_2):cnt = 0x = Counter(labels_1)y = Counter(labels_2)for k in x:if k in y:cnt += 2 * min(x[k], y[k])return cntdef _union(labels_1, labels_2):return len(labels_1) + len(labels_2)if isinstance(labels_1, torch.Tensor):labels_1 = labels_1.tolist()if isinstance(labels_2, torch.Tensor):labels_2 = labels_2.tolist()return _intersection(labels_1, labels_2) / _union(labels_1, labels_2)
  • 通过看输入前导信息中的label和example中哪些东西的label类型相近,则选择其作为参考。
  • 核心的分数计算方法为labels_similarity。用来比较Train_data中元素和Test_data中元素的label相似度

4.3 基于类型+尺寸相似度的的动态选择

class GenTypeSizeExemplarSelection(ExemplarSelection):labels_weight = 0.5bboxes_weight = 0.5def __call__(self, test_data: dict):scores = []test_labels = test_data["labels"]test_bboxes = test_data["bboxes"][:, 2:]for i in range(len(self.train_data)):train_labels = self.train_data[i]["labels"]train_bboxes = self.train_data[i]["bboxes"][:, 2:]score = labels_bboxes_similarity(  # 核心算法train_labels,train_bboxes,test_labels,test_bboxes,self.labels_weight,self.bboxes_weight,)scores.append([i, score])return self._retrieve_exemplars(scores)def labels_bboxes_similarity(labels_1, bboxes_1, labels_2, bboxes_2, labels_weight, bboxes_weight
):labels_sim = labels_similarity(labels_1, labels_2)bboxes_sim = bboxes_similarity(labels_1, bboxes_1, labels_2, bboxes_2)return labels_weight * labels_sim + bboxes_weight * bboxes_simdef labels_similarity(labels_1, labels_2):def _intersection(labels_1, labels_2):cnt = 0x = Counter(labels_1)y = Counter(labels_2)for k in x:if k in y:cnt += 2 * min(x[k], y[k])return cntdef _union(labels_1, labels_2):return len(labels_1) + len(labels_2)if isinstance(labels_1, torch.Tensor):labels_1 = labels_1.tolist()if isinstance(labels_2, torch.Tensor):labels_2 = labels_2.tolist()return _intersection(labels_1, labels_2) / _union(labels_1, labels_2)def bboxes_similarity(labels_1, bboxes_1, labels_2, bboxes_2, times=2):"""bboxes_1: M x 4bboxes_2: N x 4distance: M x N"""distance = torch.cdist(bboxes_1, bboxes_2) * timesdistance = torch.pow(0.5, distance)mask = labels_1.unsqueeze(-1) == labels_2.unsqueeze(0)distance = distance * maskrow_ind, col_ind = linear_sum_assignment(-distance)return distance[row_ind, col_ind].sum().item() / len(row_ind)
  • 标签相似度计算使用bboxes_similarity。相似度的比较非常简单:相同/总数
  • 尺寸相似度使用labels_similarity计算。距离计算公式较为复杂,见下面
     
  1. torch.cdist(bboxes_1, bboxes_2)

    • 计算两组边界框之间的欧几里得距离。bboxes_1 和 bboxes_2 通常是形状为 (M, 4) 和 (N, 4) 的张量,其中每个边界框由 (x_min, y_min, x_max, y_max) 来表示。
    • 这将返回一个形状为 (M, N) 的距离矩阵,每个元素表示一对边界框之间的欧几里得距离。
  2. distance = distance * times

    • 将每个距离乘以一个缩放因子 times,可能是为了调整距离的影响力。
  3. distance = torch.pow(0.5, distance)

    • 对距离进行指数衰减,也就是将距离取 0.5 次方,实际上是对距离做了平方根处理,可能是为了减少较大距离对结果的影响。
  4. mask = labels_1.unsqueeze(-1) == labels_2.unsqueeze(0)

    1. 创建一个形状为 (M, N) 的布尔掩码,比较 bboxes_1 和 bboxes_2 中的标签 labels_1 和 labels_2 是否相同。
    2. labels_1 和 labels_2 假设是形状为 (M,) 和 (N,) 的张量,表示每个边界框的标签(例如类别 ID)。
    3. 通过 unsqueeze(-1) 和 unsqueeze(0) 增加维度,使得两个标签张量可以进行逐元素比较,最终得到一个布尔掩码。标签相同的位置会为 True,不同的位置为 False
  5. distance = distance * mask

    1. 将掩码应用到距离矩阵上。对于标签不匹配的边界框对,其对应的距离值会被置为零。这样只有标签相同的边界框对才会参与到后续的匹配中。
  6. row_ind, col_ind = linear_sum_assignment(-distance)

    1. 使用匈牙利算法(通过 linear_sum_assignment 函数)来求解最优匹配问题。由于 linear_sum_assignment 是最小化代价的,而我们想要最小化距离,所以传入的是距离的负值(即 -distance)。
    2. 函数返回两个数组,row_ind 和 col_ind,分别表示最优匹配的行和列索引,即边界框在 bboxes_1 和 bboxes_2 中的匹配关系。
  7. return distance[row_ind, col_ind].sum().item() / len(row_ind)

    1. 根据最优匹配的行和列索引选择相应的距离值,求和后除以匹配的数量(len(row_ind)),得到匹配的平均距离。
    2. .item() 将张量转为 Python 标量值。

5. LayoutPrompt ·排名模块

class Ranker:lambda_1 = 0.2lambda_2 = 0.2lambda_3 = 0.6def __init__(self, val_path=None):self.val_path = val_pathif self.val_path:self.val_data = read_pt(val_path)self.val_labels = [vd["labels"] for vd in self.val_data]self.val_bboxes = [vd["bboxes"] for vd in self.val_data]def __call__(self, predictions: list):metrics = []for pred_labels, pred_bboxes in predictions:metric = []_pred_labels = pred_labels.unsqueeze(0)_pred_bboxes = convert_ltwh_to_ltrb(pred_bboxes).unsqueeze(0)_pred_padding_mask = torch.ones_like(_pred_labels).bool()metric.append(compute_alignment(_pred_bboxes, _pred_padding_mask))metric.append(compute_overlap(_pred_bboxes, _pred_padding_mask))if self.val_path:metric.append(compute_maximum_iou(pred_labels,pred_bboxes,self.val_labels,self.val_bboxes,))metrics.append(metric)metrics = torch.tensor(metrics)min_vals, _ = torch.min(metrics, 0, keepdim=True)max_vals, _ = torch.max(metrics, 0, keepdim=True)scaled_metrics = (metrics - min_vals) / (max_vals - min_vals)if self.val_path:quality = (scaled_metrics[:, 0] * self.lambda_1+ scaled_metrics[:, 1] * self.lambda_2+ (1 - scaled_metrics[:, 2]) * self.lambda_3)else:quality = (scaled_metrics[:, 0] * self.lambda_1+ scaled_metrics[:, 1] * self.lambda_2)_predictions = sorted(zip(predictions, quality), key=lambda x: x[1])ranked_predictions = [item[0] for item in _predictions]return ranked_predictions
  • lambda_1lambda_2lambda_3 是三个常量,分别代表加权系数,用于在排名时对不同指标的加权。这些系数的值总和为 1,用来平衡不同的评估标准。
  • val_path 是一个可选参数,用来指定验证数据集的路径。如果提供了路径,它会读取验证数据(通过 read_pt 函数),并从中提取出标签 (val_labels) 和边界框 (val_bboxes) 信息。val_labels 和 val_bboxes 是从验证数据集中提取的标签和边界框,用于后续的计算(如 IOU 等)。
  • 对于每个预测 (pred_labels, pred_bboxes),首先进行一些预处理:

    • unsqueeze(0):增加一个维度,使得张量的形状符合后续操作的要求。
    • convert_ltwh_to_ltrb(pred_bboxes):将边界框格式从 ltwh(左上角坐标和宽高)转换为 ltrb(左上角和右下角坐标)。
    • 创建一个全为 True 的填充掩码 _pred_padding_mask
  • 之后,计算两个指标:

    • compute_alignment:计算预测边界框的对齐度(如何与真实边界框对齐)。
    • compute_overlap:计算预测边界框的重叠度(预测与真实框的交集比重)。
  • 如果提供了验证数据(self.val_path),还会计算 compute_maximum_iou,这是计算预测边界框与真实边界框的最大 IOU(交并比)值。

  • 将所有指标值(存储在 metrics 列表中)转换为张量,并对每个指标进行最小-最大标准化。这样可以确保每个指标的值都在 [0, 1] 范围内,便于后续的加权处理。
  • 如果有 IOU 指标,quality 的计算会涉及三项指标:

    • scaled_metrics[:, 0]:对齐度
    • scaled_metrics[:, 1]:重叠度
    • scaled_metrics[:, 2]:最大 IOU
    • 对 IOU 进行反向处理 (1 - scaled_metrics[:, 2]),表示较高的 IOU 值意味着较好的预测。
  • 如果没有 IOU,则只使用对齐度和重叠度。

  • 将预测结果与其质量评分进行打包,并按质量评分进行排序。
  • 返回排序后的预测列表 ranked_predictions,从质量最好的预测到最差的预测。

6. 总结

本篇文章带大家深入了解了PosterGenius项目的Layout生成部分的第一篇,后续将更新Layout系列的第二篇。欢迎大家继续支持猫猫呀!!

 【如果想学习更多深度学习文章,可以订阅一下热门专栏】

  • 《PyTorch科研加速指南:即插即用式模块开发》_十二月的猫的博客-CSDN博客
  • 《深度学习理论直觉三十讲》_十二月的猫的博客-CSDN博客
  • 《AI认知筑基三十讲》_十二月的猫的博客-CSDN博客

如果想要学习更多pyTorch/python编程的知识,大家可以点个关注并订阅,持续学习、天天进步你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~

相关文章:

  • Linux电源管理(5)_Hibernate和Sleep功能介绍
  • Centos9 安装 RocketMQ5
  • Windows 中使用dockers创建指定java web 为镜像和运行容器
  • 深度学习系统学习系列【2】之人工神经网络(ANN)
  • 长江学者答辩ppt美化_特聘教授_校企联聘学者_青年长江学者PPT案例模板
  • 设计模式简述(十七)备忘录模式
  • 使用线性表实现通讯录管理
  • AtCoder Beginner Contest 404(ABCDE)
  • C++八股--5--设计模式--适配器模式,代理模式,观察者模式
  • Maven安装配置以及Idea中的配置教程
  • ElasticSearch深入解析(十):字段膨胀(Mapping 爆炸)问题的解决思路
  • Servlet(二)
  • 安卓基础(悬浮窗和摄像)
  • 大数据引领行业革命:深度解析与未来趋势
  • 【网络原理】深入理解HTTPS协议
  • 智能家居的OneNet云平台
  • 接口测试的核心思维(基础篇)
  • C语言蓝桥杯真题代码
  • java学习之数据结构:二、链表
  • 第38课 常用快捷操作——双击“鼠标左键”进入Properties Panel
  • 李云泽:对受关税影响较大、经营暂时困难的市场主体,一企一策提供精准服务
  • 大一女生头孢过敏输液室呼救无医护响应,自行拔针仍不幸身亡
  • 日本政府强烈反对美关税政策并要求其取消
  • 德国斯图加特发生车辆冲撞人群事件,至少三人受伤
  • 海港负国安主场两连败,五强争冠卫冕冠军开始掉队
  • 旭辉控股集团:去年收入477.89亿元,长远计划逐步向轻资产业务模式转型