提交 17f2944a 编写于 作者: Z zoooo0820

fix error in tts/st

上级 5f53e902
......@@ -252,7 +252,7 @@ class STExecutor(BaseExecutor):
norm_feat = dict(kaldiio.load_ark(process.stdout))[utt_name]
self._inputs["audio"] = paddle.to_tensor(norm_feat).unsqueeze(0)
self._inputs["audio_len"] = paddle.to_tensor(
self._inputs["audio"].shape[1], dtype="int64")
self._inputs["audio"].shape[1:2], dtype="int64")
else:
raise ValueError("Wrong model type.")
......
......@@ -491,7 +491,7 @@ class TTSExecutor(BaseExecutor):
# multi speaker
if am_dataset in {'aishell3', 'vctk', 'mix', 'canton'}:
mel = self.am_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
part_phone_ids, spk_id=paddle.to_tensor([spk_id]))
else:
mel = self.am_inference(part_phone_ids)
self.am_time += (time.time() - am_st)
......
......@@ -783,7 +783,7 @@ class FastSpeech2(nn.Layer):
x = paddle.cast(text, 'int64')
d, p, e = durations, pitch, energy
# setup batch axis
ilens = paddle.shape(x)[0]
ilens = paddle.shape(x)[0:1]
xs = x.unsqueeze(0)
......
......@@ -181,7 +181,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
bs = paddle.shape(lengths)[0]
bs = paddle.shape(lengths)
if xs is None:
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册