提交 9b8fd9f9 编写于 作者: Y Yibing Liu

Upgrade waveflow to 1.8.0

上级 8716a184
......@@ -348,7 +348,7 @@ class WaveFlowModule(dg.Layer):
mel = self.conditioner(mel)
assert mel.shape[2] >= audio.shape[1]
# Prune out the tail of audio/mel so that time/n_group == 0.
pruned_len = audio.shape[1] // self.n_group * self.n_group
pruned_len = int(audio.shape[1] // self.n_group * self.n_group)
if audio.shape[1] > pruned_len:
audio = audio[:, :pruned_len]
......
......@@ -87,7 +87,14 @@ def compute_l2_normalized_weight(v, g, dim):
def compute_weight(v, g, dim, power):
assert len(g.shape) == 1, "magnitude should be a vector"
if power == 2:
return compute_l2_normalized_weight(v, g, dim)
in_dtype = v.dtype
if in_dtype == fluid.core.VarDesc.VarType.FP16:
v = F.cast(v, "float32")
g = F.cast(g, "float32")
weight = compute_l2_normalized_weight(v, g, dim)
if in_dtype == fluid.core.VarDesc.VarType.FP16:
weight = F.cast(weight, "float16")
return weight
else:
v_normalized = F.elementwise_div(
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册