当前位置: 首页 > news >正文

flow-matching 之学习matcha-tts cosyvoice

文章目录

  • matcha 实现
  • cosyvoice 实现
  • chunk_fm
    • chunk_mask
    • cache_attn
  • stream token2wav

  • 关于flow-matching 很好的原理性解释文章, 值得仔细读,多读几遍,关于文章Flow Straight and Fast:
    Learning to Generate and Transfer Data with Rectified Flow 的讲解梳理。

matcha 实现

def fm_comput_loss()# x1 是target_mel# random timestept = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)# sample noise p(x_0)z = torch.randn_like(x1)y = (1 - (1 - self.sigma_min) * t) * z + t * x1u = x1 - (1 - self.sigma_min) * zpred_y = self.estimator(y, mask, mu, t.squeeze(), spks)loss = F.mse_loss(pred_y, u, reduction="sum") / (torch.sum(mask) * u.shape[1])return loss, y
def estimator_forward():x = pack(y, mu)x = pack(x, spks)q,k,v = x, x, xx = slf_attn(q,k,v)outputs = linear(x)return outputs

cosyvoice 实现

def fm_forward():# mu: encoder_outputs# x1: target_mel# cond: prompt_mel 随机取的部分conds = torch.zeros(feat.shape, device=token.device)for i, j in enumerate(feat_len):if random.random() < 0.5:continueindex = random.randint(0, int(0.3 * j))conds[i, :index] = feat[i, :index]conds = conds.transpose(1, 2)b, _, t = mu.shape# random timestept = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)if self.t_scheduler == 'cosine':t = 1 - torch.cos(t * 0.5 * torch.pi)# sample noise p(x_0)z = torch.randn_like(x1)y = (1 - (1 - self.sigma_min) * t) * z + t * x1u = x1 - (1 - self.sigma_min) * z# during training, we randomly drop condition to trade off mode coverage and sample fidelity# inference 的时候实际不需要condition, 给zero就可以if self.training_cfg_rate > 0:cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_ratemu = mu * cfg_mask.view(-1, 1, 1)spks = spks * cfg_mask.view(-1, 1)cond = cond * cfg_mask.view(-1, 1, 1)pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])return loss, ydef estimator(x, mu, spks, cond):x = pack(x, mu, spks, cond)x = slf_attn(x)outputs = linear(x)return outputs

chunk_fm

  • 训练的时候将特征进行chunk_mask,推理的时候只准备chunk的部分,pre_chunk 存为kv_cache,
  • cache 初始seq_len为0;每次得到的cache,只留下[-chunk_len:] 的长度,作为下一次的输入;特征x 的pos 按照真的来算;

chunk_mask

在这里插入图片描述

  • 训练阶段样本按照seq_len 维度被mask 成不同的可见部分;chunk_mask 和长度mask 都会出现,为了加速收敛;

cache_attn

def slf_attn_cache(x, cache):k_in, v_in, q_in = x, x, xkey_cache = linear1(k_in)value_cache = linear2(v_in)# NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)if cache.size(0) != 0:# step into this branchkey = torch.concat([cache[:, :, :, 0], key_cache], dim=1)value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)else:key, value = key_cache, value_cachecache = torch.stack([key_cache, value_cache], dim=3)outputs = scale_dot_production(key, value)return outputs, cache

stream token2wav

在这里插入图片描述

  • 第一个包没有kv_cache, 卷积cache 有,但是值为0;first chunk 推理完就可以存下kv_cache & cnn_cache;
  • 输入token+cache_token,得到token 对应的mel;
  • mel2wav 阶段也是,第一次没有hift cache,直接退出mel 对应的wav,最后8帧存下来作为hift_cache,用于hift_wav 预测以及输出的音频片段间平滑;前n-8 帧的音频输出;
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)):# mel->f0 mel:[1,80,T]#print('hift inference speech_feat', speech_feat.size())f0 = self.f0_predictor(speech_feat)# f0->sources = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,ts, _, _ = self.m_source(s)s = s.transpose(1, 2) #sample#print('f0 s', s.size()) # sample level# use cache_source to avoid glitchif cache_source.shape[2] != 0:# s[1,1,t*480]s[:, :, :cache_source.shape[2]] = cache_sourceprint('cache_source s', s.size())else:print('cache_source shape2 is 0')#print('hift inference s', s.size())generated_speech = self.decode(x=speech_feat, s=s)return generated_speech, s

相关文章:

  • 企业级UI测试的“双保险”:TestComplete的智能对象识别与详细报告功能
  • 本地聊天机器人部署方案
  • 安卓基础(静态方法)
  • 网络字节序 - 大端
  • Java的对象头:原理与源码详解
  • 定时任务分布式锁SchedulerLock
  • iptables 访问控制列表使用记录
  • Oracle免费认证来袭
  • 国际数字影像产业园,打造金牛区数字文创新地标
  • 堡塔云WAF免费WEB防火墙,从搭建到应用
  • 【Science Advances】北京邮电大学突破:基于MEMS-超表面的多阶涡旋光束高速切换技术
  • 枚举 · 例8扩展-校门外的树:hard
  • Java:跨越时代的编程语言,持续引领技术变革
  • 每日学习Java之一万个为什么(待完善)
  • ABP vNext + Dapr 实现云原生微服务治理
  • NGINX `ngx_http_gzip_static_module` 零时延送出预压缩文件
  • 没有Mac,我是怎么上传IPA到App Store的?
  • 15.thinkphp的上传功能
  • CAP理论:分布式系统的权衡
  • K8S - 蓝绿发布实战 - Argo Rollouts 零停机方案解析
  • 游客称在网红雪山勒多曼因峰需救援被开价2.8万,康定文旅:封闭整改
  • 赵作海因病离世,妻子李素兰希望过平静生活
  • 美国与胡塞武装达成停火协议,美伊相向而行?
  • 从“重规模”向“重回报”转变,公募基金迎系统性改革
  • 李云泽:支持设立新的金融资产投资公司,今天即将批复一家
  • 云南昆磨高速发生交通事故致3辆车起火,昆明消防:幸无人员伤亡