【深度学习的灵魂】图片布局生成模型LayoutPrompt(1)
🌈 个人主页:十二月的猫-CSDN博客
🔥 系列专栏: 🏀《深度学习理论直觉三十讲》_十二月的猫的博客-CSDN博客💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光
目录
1. 前言
2. LayoutPrompt介绍
3. LayoutPrompt · 布局序列化模块
3.1 固定化Prompt
3.2 父类序列化模块
3.3 子类序列化模块
4. 总结
1. 前言
猫猫不知道大家有没有思考过图片布局生成模型,这算是生成模型的一个非常小的子任务了。前面带大家学习过生成模型,包括GAN、Diffusion等。这些都算是生成模型的研究子领域,利用这些子领域的知识,我们可以来研究具体的任务,例如生成图片、按照语言提示生成图片、按照布局提示生成图片等。
同样生成布局也是生成模型中的一个具体任务,其实思路也是非常简单。生成图片这个任务中更具体的任务是生成海报、生成照片、生成动漫图片等。因为直接生成一个海报难度太大,我们就先去生成布局,然后在具体布局的约束下去具体生成完整图片。可以简单给大家看一下布局是什么:
猫猫研究这个呢,主要还是因为创新实训和软件创新大赛两个需要。同时由于猫猫也研究过一点生成模型,因此这个学期就研究一下这个领域啦~~如果大家也对这个领域感兴趣,可以关注我们团队的专栏(布局生成模型是我们海报生成系统中的一个子模块,希望更多猫友参与到我们的开发当中哦):大模型实战训练营_十二月的猫的博客-CSDN博客
2. LayoutPrompt介绍
总的来说,为了完成LayoutPrompt模型,我们需要完成的子模块有:
- 数据预处理模块:该模块主要就是利用基础数据预处理方法对数据集中的所有数据样本进行预处理。
- 动态样本选择模块:从训练集中检索最相关的样本,然后作为最直接的上下文(约束信息)送给大语言模型。
- 布局序列化模块:用于将上面所选的样本布局转化为序列表述(因为大语言模型对序列化输入有更好的效果)。序列化数据就是类似 自然语言、代码等。
- 大语言模型模块:将序列化处理后的所有样本一起送给大语言模型,让大语言模型参考的情况下给出自己的答案。
- 大语言模型解析模块:用于将大语言模型给出的布局结果解析为标准化的输出。
- 布局排序模块(布局评价模块):评价大语言模型生成的布局的质量分数,并做一个排序。
具体的模型图如下:
模型运行具体流程如下:
- 用户输入前导信息,例如画布大小,任务类型等,数据预处理模块预处理数据库(用户根据自己的任务可以选择数据库,如海报布局数据库、手机UI设计数据库)中的所有数据。
- 动态样本选择模块得到处理后的数据库数据,然后根据用户输入的前导信息选择合适的example样本。
- 将样本+前导信息+测试样本送给大语言模型。
- 由大语言模型生成最终的layout,送给Rank模块。
- Rank模块排序后分数最高的就是最终输出。
3. LayoutPrompt · 布局序列化模块
布局序列化模块:本质就是序列化+Prompt。两者核心都是固定。
序列化:将输入输出按照固定序列格式调整。
例如输出固定如下:
- "标题 0 0.1 0.2 0.3 0.4 | 正文 1 0.5 0.6 0.7 0.8"
# < html > # < body > # < div # style = "..." > 标题_0 < / div > # < div # style = "..." > 正文_1 < / div > # < / body > # < / htmlPrompt:根据不同任务,输入需要不同Prompt,同时序列化格式。
例如输入固定如下:
![]()
一句话来说,布局序列化模块准备的就是模型中的这一部分:
- 前导部分(PREAMBLE):固定Prompt,用户输入画布大小高度等。
- 输入限制(INPUT CONSTRAINT) :会有两种表示形式1.如上图的seq形式;2.html形式。
- 输出布局(OUTPUT LAYOUT):得到布局的坐标data后(由其他模块负责),序列化输出上图结果。
从代码角度来说分为三个部分:
- 固定化Prompt。
- 父类序列化模块(固定输出序列结构,输入序列留接口给子类实现)
- 子类序列化模块(一个子类对应一个具体的任务,不同任务输入结构不一样)
从任务角度存在以下七种:
- 元素类型任务(限制layout中的元素)
- 元素类型,元素尺寸限制任务
- 元素类型,元素之间位置关系
- 元素补全(根据部分已知布局元素,生成完整布局结构)
- 元素尺寸修正(给出类型和尺寸,模型自己修正尺寸)
- 防止遮挡的布局生成
- 文本描述下的布局生成
3.1 固定化Prompt
PREAMBLE = ("Please generate a layout based on the given information. ""You need to ensure that the generated layout looks realistic, with elements well aligned and avoiding unnecessary overlap.\n""Task Description: {}\n""Layout Domain: {} layout\n""Canvas Size: canvas width is {}px, canvas height is {}px"
)
# html的头部
HTML_PREFIX = """<html>
<body>
<div class="canvas" style="left: 0px; top: 0px; width: {}px; height: {}px"></div>
"""
# html的结尾
HTML_SUFFIX = """</body>
</html>"""# html的body
HTML_TEMPLATE = """<div class="{}" style="left: {}px; top: {}px; width: {}px; height: {}px"></div>
"""HTML_TEMPLATE_WITH_INDEX = """<div class="{}" style="index: {}; left: {}px; top: {}px; width: {}px; height: {}px"></div>
"""
- 本部分主要是为了固定输入输出的一些序列化格式,同时固定前导部分。
- 用户在这里通过前端输入自己想要的画布大小高度等信息。
3.2 父类序列化模块
# 序列化的父类,后面有具体任务不同的序列化
# 所谓序列化本质就是固定输入的prompt形式,同时固定输出的一个形式。
class Serializer:def __init__(self,input_format: str,output_format: str,index2label: dict,canvas_width: int,canvas_height: int,add_index_token: bool = True,add_sep_token: bool = True,sep_token: str = "|",add_unk_token: bool = False,unk_token: str = "<unk>",):self.input_format = input_formatself.output_format = output_formatself.index2label = index2labelself.canvas_width = canvas_widthself.canvas_height = canvas_heightself.add_index_token = add_index_tokenself.add_sep_token = add_sep_tokenself.sep_token = sep_tokenself.add_unk_token = add_unk_tokenself.unk_token = unk_tokendef build_input(self, data):if self.input_format == "seq":return self._build_seq_input(data)elif self.input_format == "html":return self._build_html_input(data)else:raise ValueError(f"Unsupported input format: {self.input_format}")# check value is not nulldef _build_seq_input(self, data):raise NotImplementedErrordef _build_html_input(self, data):raise NotImplementedErrordef build_output(self, data, label_key="labels", bbox_key="discrete_gold_bboxes"):if self.output_format == "seq":return self._build_seq_output(data, label_key, bbox_key)elif self.output_format == "html":return self._build_html_output(data, label_key, bbox_key)# # 输入数据结构示例# data = {# "labels": [0, 1], # 标签索引# "discrete_gold_bboxes": [ # 坐标列表(假设已离散化)# [0.1, 0.2, 0.3, 0.4],# [0.5, 0.6, 0.7, 0.8]# ]# }# "标题 0 0.1 0.2 0.3 0.4 | 正文 1 0.5 0.6 0.7 0.8"def _build_seq_output(self, data, label_key, bbox_key):# 在字典中存储的标签信息,和边框信息(不是存储具体字,而是存储信息)labels = data[label_key]bboxes = data[bbox_key]tokens = []for idx in range(len(labels)):label = self.index2label[int(labels[idx])]bbox = bboxes[idx].tolist()tokens.append(label)if self.add_index_token:tokens.append(str(idx))tokens.extend(map(str, bbox)) # extend一次性添加很多值,append一次添加一个值。map(function,list):把function作用在list上。str():将其他类型转为str类型。if self.add_sep_token and idx < len(labels) - 1:tokens.append(self.sep_token) # 添加隔离符号return " ".join(tokens)# # 输出结构# < html ># < body ># < div# style = "..." > 标题_0 < / div ># < div# style = "..." > 正文_1 < / div ># < / body ># < / html >def _build_html_output(self, data, label_key, bbox_key):labels = data[label_key]bboxes = data[bbox_key]htmls = [HTML_PREFIX.format(self.canvas_width, self.canvas_height)] # 使用 HTML_PREFIX 作为 HTML 页面的开头,并将画布宽度和高度传递进去。_TEMPLATE = HTML_TEMPLATE_WITH_INDEX if self.add_index_token else HTML_TEMPLATE # 根据 add_index_token 决定使用哪种模板:带索引的模板或普通模板。for idx in range(len(labels)):label = self.index2label[int(labels[idx])]bbox = bboxes[idx].tolist()element = [label]if self.add_index_token:element.append(str(idx))element.extend(map(str, bbox))htmls.append(_TEMPLATE.format(*element))htmls.append(HTML_SUFFIX)return "".join(htmls)
- 主要定义两种输出方式的序列化。第一种是以seq的形式输出;第二种是以html的形式输出。
- 输入形式的序列化(prompt+序列化)仅仅定义了模板交给子类(具体任务)来实现。
- expand和append都是列表后追加,但expand一次追加很多个元素,append一次加一个。
- map(function,list):将function作用在所有的list元素上。
- 作用:1. 得到其他模块给的Layout坐标data后,将其转化为html或seq的固定序列化格式输出。2.作为父类将输入序列化交给子类实现。
3.3 子类序列化模块
前面也说了,子类序列化模块需要根据不同的任务要求给出不同的prompt以及序列化输入。因此有多少个任务就会有多少个子类序列化模块。在这里,猫猫仅仅展示两个模块的代码,完整的代码等专栏更新结束后,会同步放在Gitee以及CSDN账号下。
任务一:限制元素类型的布局生成
class GenTypeSerializer(Serializer):task_type = "generation conditioned on given element types"constraint_type = ["Element Type Constraint: "]HTML_TEMPLATE_WITHOUT_ANK = '<div class="{}"></div>\n'HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEX = '<div class="{}" style="index: {}"></div>\n'def _build_seq_input(self, data):labels = data["labels"]tokens = []for idx in range(len(labels)):label = self.index2label[int(labels[idx])]tokens.append(label)if self.add_index_token:tokens.append(str(idx))if self.add_unk_token:tokens += [self.unk_token] * 4if self.add_sep_token and idx < len(labels) - 1:tokens.append(self.sep_token)return " ".join(tokens)def _build_html_input(self, data):labels = data["labels"]htmls = [HTML_PREFIX.format(self.canvas_width, self.canvas_height)]if self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATE_WITH_INDEXelif self.add_index_token and not self.add_unk_token:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEXelif not self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATEelse:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANKfor idx in range(len(labels)):label = self.index2label[int(labels[idx])]element = [label]if self.add_index_token:element.append(str(idx))if self.add_unk_token:element += [self.unk_token] * 4htmls.append(_TEMPLATE.format(*element))htmls.append(HTML_SUFFIX)return "".join(htmls)def build_input(self, data):return self.constraint_type[0] + super().build_input(data)
任务二:限制元素类型,以及元素关系的布局生成
class GenRelationSerializer(Serializer):task_type = ("generation conditioned on given element relationships\n""'A left B' means that the center coordinate of A is to the left of the center coordinate of B. ""'A right B' means that the center coordinate of A is to the right of the center coordinate of B. ""'A top B' means that the center coordinate of A is above the center coordinate of B. ""'A bottom B' means that the center coordinate of A is below the center coordinate of B. ""'A center B' means that the center coordinate of A and the center coordinate of B are very close. ""'A smaller B' means that the area of A is smaller than the ares of B. ""'A larger B' means that the area of A is larger than the ares of B. ""'A equal B' means that the area of A and the ares of B are very close. ""Here, center coordinate = (left + width / 2, top + height / 2), ""area = width * height")constraint_type = ["Element Type Constraint: ", "Element Relationship Constraint: "]HTML_TEMPLATE_WITHOUT_ANK = '<div class="{}"></div>\n'HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEX = '<div class="{}" style="index: {}"></div>\n'def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.index2type = RelationTypes.index2type()def _build_seq_input(self, data):labels = data["labels"]relations = data["relations"]tokens = []for idx in range(len(labels)):label = self.index2label[int(labels[idx])]tokens.append(label)if self.add_index_token:tokens.append(str(idx))if self.add_unk_token:tokens += [self.unk_token] * 4if self.add_sep_token and idx < len(labels) - 1:tokens.append(self.sep_token)type_cons = " ".join(tokens)if len(relations) == 0:return self.constraint_type[0] + type_constokens = []for idx in range(len(relations)):label_i = relations[idx][2]index_i = relations[idx][3]if label_i != 0:tokens.append("{} {}".format(self.index2label[int(label_i)], index_i))else:tokens.append("canvas")tokens.append(self.index2type[int(relations[idx][4])])label_j = relations[idx][0]index_j = relations[idx][1]if label_j != 0:tokens.append("{} {}".format(self.index2label[int(label_j)], index_j))else:tokens.append("canvas")if self.add_sep_token and idx < len(relations) - 1:tokens.append(self.sep_token)relation_cons = " ".join(tokens)return (self.constraint_type[0]+ type_cons+ "\n"+ self.constraint_type[1]+ relation_cons)def _build_html_input(self, data):labels = data["labels"]relations = data["relations"]htmls = [HTML_PREFIX.format(self.canvas_width, self.canvas_height)]if self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATE_WITH_INDEXelif self.add_index_token and not self.add_unk_token:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEXelif not self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATEelse:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANKfor idx in range(len(labels)):label = self.index2label[int(labels[idx])]element = [label]if self.add_index_token:element.append(str(idx))if self.add_unk_token:element += [self.unk_token] * 4htmls.append(_TEMPLATE.format(*element))htmls.append(HTML_SUFFIX)type_cons = "".join(htmls)if len(relations) == 0:return self.constraint_type[0] + type_constokens = []for idx in range(len(relations)):label_i = relations[idx][2]index_i = relations[idx][3]if label_i != 0:tokens.append("{} {}".format(self.index2label[int(label_i)], index_i))else:tokens.append("canvas")tokens.append(self.index2type[int(relations[idx][4])])label_j = relations[idx][0]index_j = relations[idx][1]if label_j != 0:tokens.append("{} {}".format(self.index2label[int(label_j)], index_j))else:tokens.append("canvas")if self.add_sep_token and idx < len(relations) - 1:tokens.append(self.sep_token)relation_cons = " ".join(tokens)return (self.constraint_type[0]+ type_cons+ "\n"+ self.constraint_type[1]+ relation_cons)
- 针对此任务设计了具体的Prompt形式。
- 输入数据序列化有两个部分:1、seq序列化;2、html序列化
- seq序列化:用户输入限制后,需要结合Prompt序列化到INPUT_CONSTRAINT(
- (注意:这个seq并不是完整的,仅仅是将用户的限制要求填入):
- html序列化:根据用户输入限制生成对应的html格式(注意:这个html并不是完整的,仅仅是将用户的限制要求填入)
- 作用:1.将用户输入的限制转化为某种序列化(seq序列化 或 html序列化)。2.结合固定Prompt包装成完整的INPUT CONSTRAINT
4. 总结
本篇文章带大家深入了解了PosterGenius项目的Layout生成部分的第一篇,后续将更新Layout系列的第二篇。欢迎大家继续支持猫猫呀!!
【如果想学习更多深度学习文章,可以订阅一下热门专栏】
- 《PyTorch科研加速指南:即插即用式模块开发》_十二月的猫的博客-CSDN博客
- 《深度学习理论直觉三十讲》_十二月的猫的博客-CSDN博客
- 《AI认知筑基三十讲》_十二月的猫的博客-CSDN博客
如果想要学习更多pyTorch/python编程的知识,大家可以点个关注并订阅,持续学习、天天进步你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~