当前位置: 首页 > news >正文

语言模型transformers调用部分 (To be continue...

什么?!!!原来自回归模型的model.generate不能用于训练!!??

只能用法forward一次生成,但一次性只能得到一个tensor
就是在这里取最大值导致模型梯度断了,所以不能用model.generate来训练,要训练只能用model.forward

next_tokens = torch.argmax(next_tokens_scores, dim=-1)#返回指定维度最大值的序号
inverted_mask = 1.0 - attention_mask
attention_mask = inverted_mask.masked_fill( #用value填充tensor中与mask中值为1位置相对应的元素
           inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min)

model.generatemain.py的主函数中调用,然后转到transformers/generation/utils.py
这个文件的Class GenerationMixindef generate函数–>然后进入def greedy_search
其中while True: # 在这个while循环里实现自回归
其中还有一个self.prepare_inputs_for_generation(input_ids, **model_kwargs)用于处理model运行的输入,# 有的在main.py中重写这个函数。

Class GenerationMixin:中用来得到每个预测token的得分函数 # 有的会重写或调用这个函数
def compute_transition_scores()# 用法示例

        >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
        >>> import numpy as np

        >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer.pad_token_id = tokenizer.eos_token_id
        >>> inputs = tokenizer(["Today is"], return_tensors="pt")

        >>> # Example 1: 打印每个token的得分 with Greedy Search
        >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
        >>> transition_scores = model.compute_transition_scores(# 有时候会在外面重写这个函数
        ...     outputs.sequences, outputs.scores, normalize_logits=True
        ... )
        >>> # decoder-only models, like the GPT family, and 1;
        >>> # encoder-decoder models, like BART or T5.
        # 也就是说encoder-decoder就是答案从头开始; decoder-only是答案在输入后边接着
        >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
        >>> generated_tokens = outputs.sequences[:, input_length:]
        
        >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
        ...     # | token | token string | log probability | probability
        ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
        |   262 |  the     | -1.414 | 24.33%
        |  1110 |  day     | -2.609 | 7.36%
        |   618 |  when    | -2.010 | 13.40%
        |   356 |  we      | -1.859 | 15.58%
        |   460 |  can     | -2.508 | 8.14%

       

相关文章:

  • 数据库中冗余字段
  • Java 自定义线程池实现
  • [运维] 可视化爬虫易采集-EasySpider(笔记)
  • CSS案例-2.简单版侧边栏练习
  • 通过dbeaver链接dm8数据库
  • redis优化token校验主动失效
  • 基于UDP的网络聊天室
  • rtt的io设备框架面向对象学习-内部调用流程
  • Linux相关命令(2)
  • jackson:JSON字符串(String)类型的成员序列化和反序列化
  • 【运维笔记】VM 记录一次centos虚拟机和宿主机之间ping不通的问题
  • 基于 HBase Phoenix 构建实时数仓(5)—— 用 Kafka Connect 做实时数据同步
  • 【Redis知识点总结】(七)——缓存雪崩、缓存穿透、缓存击穿、Redis高级用法
  • ReaLTaiizor开源.NET winform控件库学习使用
  • Redis 不再“开源”,对中国的影响及应对方案
  • docker仓库登录及配置insecure-registries的方法
  • python基础——数据容器总结、通用方法和相互转换
  • (一)Linux+Windows下安装ffmpeg
  • 【Golang星辰图】创造美丽图表,洞察数据:解析Go语言中的数据可视化和数据分析库
  • 一次完整的 HTTP 请求所经历的步骤
  • 上海发布首份直播电商行业自律公约,禁止虚假宣传、商业诋毁
  • “浦东时刻”在京展出:沉浸式体验海派风情
  • 代理销售保险存在误导行为,农业银行重庆市分行相关负责人被罚款0.1万元
  • 援藏博士张兴堂已任西藏农牧学院党委书记、副校长
  • 万达电影:股东杭州臻希拟减持不超1.3927%公司股份
  • 上海“随申兑”服务平台有哪些功能?已归集800余个惠企政策