In modelscope-funasr, will the hotword-boosting task added at the model-training stage affect the speech token-count prediction?

def calc_predictor(self, encoder_out, encoder_out_lens):
    # Mask out padded encoder frames before running the CIF length predictor.
    encoder_out_mask = (
        ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
    ).to(encoder_out.device)
    predictor_outs = self.predictor(
        encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
    )
    return predictor_outs[:4]
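The "token-count prediction" the question refers to comes from this CIF predictor: the per-frame weights (alphas) are summed over the valid frames to give the expected number of output tokens. A minimal, self-contained toy sketch of that idea (illustrative tensors only, not the actual FunASR predictor):

import torch

# Toy CIF-style token-count prediction: alphas are per-frame firing weights,
# and their sum over non-padded frames is the predicted token count per utterance.
alphas = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.5, 0.5],
                       [0.3, 0.7, 1.0, 0.0, 0.0, 0.0]])   # (batch, frames)
frame_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                           [1, 1, 1, 0, 0, 0]], dtype=torch.float32)
pre_token_length = (alphas * frame_mask).sum(dim=-1)       # (batch,)
print(pre_token_length)  # tensor([3., 2.]) -> predicted number of tokens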
def _calc_seaco_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_lengths: torch.Tensor,
    hotword_pad: torch.Tensor,
    hotword_lengths: torch.Tensor,
    seaco_label_pad: torch.Tensor,
):
    # predictor forward: acoustic embeddings aligned to the target tokens
    encoder_out_mask = (
        ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
    ).to(encoder_out.device)
    pre_acoustic_embeds = self.predictor(
        encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
    )[0]
    # decoder forward: hidden states for the same target sequence
    decoder_out, _ = self.decoder(
        encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True
    )
    # encode the hotword list and broadcast it to every utterance in the batch
    selected = self._hotword_representation(hotword_pad, hotword_lengths)
    contextual_info = (
        selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
    )
    num_hot_word = contextual_info.shape[1]
    _contextual_length = (
        torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
    )
    # dha core: attend to the hotwords from both the CIF embeddings and the decoder states
    cif_attended, _ = self.seaco_decoder(
        contextual_info, _contextual_length, pre_acoustic_embeds, ys_lengths
    )
    dec_attended, _ = self.seaco_decoder(
        contextual_info, _contextual_length, decoder_out, ys_lengths
    )
    merged = self._merge(cif_attended, dec_attended)
    dha_output = self.hotword_output_layer(
        merged[:, :-1]
    )  # remove the last token in loss
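The excerpt ends here; for orientation only, a hedged sketch of how dha_output might then be scored against seaco_label_pad (toy shapes and a plain cross-entropy stand in for whatever criterion the actual FunASR code uses):

import torch
import torch.nn.functional as F

# Hypothetical illustration (not the FunASR source): score the DHA logits
# against per-position hotword labels, ignoring padded positions.
batch, tokens, vocab = 2, 5, 11
dha_output = torch.randn(batch, tokens, vocab)              # stand-in for hotword_output_layer logits
seaco_label_pad = torch.randint(0, vocab, (batch, tokens))  # hotword index per token position
seaco_label_pad[1, 3:] = -1                                  # padded positions are ignored
loss_seaco = F.cross_entropy(dha_output.transpose(1, 2), seaco_label_pad, ignore_index=-1)
print(loss_seaco)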