Commit 45af3a43 authored by chenfeiyu

fix WeightNormWrapper, stop using CacheDataset for deep voice 3, pin numba version to 0.47.0

Parent b7c584e2
@@ -230,7 +230,7 @@ def make_data_loader(data_root, config):
         ref_level_db=c["ref_level_db"],
         max_norm=c["max_norm"],
         clip_norm=c["clip_norm"])
-    ljspeech = CacheDataset(TransformDataset(meta, transform))
+    ljspeech = TransformDataset(meta, transform)
     # use meta data's text length as a sort key for the sampler
     batch_size = config["train"]["batch_size"]
...
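The hunk above drops the in-memory caching wrapper, so feature extraction now runs lazily on every access instead of being memoized. A minimal sketch of that trade-off, assuming a map-style dataset with __getitem__; the class names below are hypothetical illustrations, not the parakeet implementations:

class LazyTransform(object):
    """Apply `transform` on every access (what TransformDataset does conceptually)."""
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # recomputed each epoch; low, constant memory footprint
        return self.transform(self.dataset[i])

class CachedTransform(LazyTransform):
    """Memoize transformed examples (what the removed CacheDataset layer adds)."""
    def __init__(self, dataset, transform):
        super(CachedTransform, self).__init__(dataset, transform)
        self._cache = {}

    def __getitem__(self, i):
        if i not in self._cache:
            # cache grows with the dataset; spectrogram features at LJSpeech scale
            # can exhaust host memory, presumably why caching was removed here
            self._cache[i] = self.transform(self.dataset[i])
        return self._cache[i]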
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy as np
 from paddle import fluid
 import paddle.fluid.dygraph as dg
 import paddle.fluid.layers as F
@@ -44,10 +43,10 @@ def norm_except(param, dim, power):
     if dim is None:
         return norm(param, dim, power)
     elif dim == 0:
-        param_matrix = F.reshape(param, (shape[0], np.prod(shape[1:])))
+        param_matrix = F.reshape(param, (shape[0], -1))
         return norm(param_matrix, dim=1, power=power)
     elif dim == -1 or dim == ndim - 1:
-        param_matrix = F.reshape(param, (np.prod(shape[:-1]), shape[-1]))
+        param_matrix = F.reshape(param, (-1, shape[-1]))
         return norm(param_matrix, dim=0, power=power)
     else:
         perm = list(range(ndim))
@@ -62,24 +61,26 @@ def compute_l2_normalized_weight(v, g, dim):
     ndim = len(shape)
     if dim is None:
-        v_normalized = v / (F.reduce_sum(F.square(v)) + 1e-12)
+        v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
     elif dim == 0:
-        param_matrix = F.reshape(v, (shape[0], np.prod(shape[1:])))
+        param_matrix = F.reshape(v, (shape[0], -1))
         v_normalized = F.l2_normalize(param_matrix, axis=1)
+        v_normalized = F.reshape(v_normalized, shape)
     elif dim == -1 or dim == ndim - 1:
-        param_matrix = F.reshape(v, (np.prod(shape[:-1]), shape[-1]))
+        param_matrix = F.reshape(v, (-1, shape[-1]))
         v_normalized = F.l2_normalize(param_matrix, axis=0)
+        v_normalized = F.reshape(v_normalized, shape)
     else:
         perm = list(range(ndim))
         perm[0] = dim
         perm[dim] = 0
         transposed_param = F.transpose(v, perm)
-        param_matrix = F.reshape(
-            transposed_param,
-            (transposed_param.shape[0], np.prod(transposed_param.shape[1:])))
+        transposed_shape = transposed_param.shape
+        param_matrix = F.reshape(transposed_param,
+                                 (transposed_param.shape[0], -1))
         v_normalized = F.l2_normalize(param_matrix, axis=1)
+        v_normalized = F.reshape(v_normalized, transposed_shape)
         v_normalized = F.transpose(v_normalized, perm)
-        v_normalized = F.reshape(v_normalized, shape)
     weight = F.elementwise_mul(v_normalized, g, axis=dim)
     return weight
...
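The substance of the WeightNormWrapper fix is visible above: compute_l2_normalized_weight now reshapes the flattened, L2-normalized matrix back to the parameter's original shape before scaling by g, and the dim-is-None branch divides by the square root of the sum of squares rather than the sum itself. A plain-NumPy reference of the intended result, w = g * v / ||v|| with the norm taken over every axis except dim, can be used to sanity-check the paddle version; the function name and shapes below are illustrative only:

import numpy as np

def weight_norm_ref(v, g, dim=None, eps=1e-12):
    # w = g * v / ||v||_2, norm over all axes except `dim` (all axes if dim is None)
    if dim is None:
        return g * v / (np.sqrt(np.sum(v ** 2)) + eps)
    dim = dim % v.ndim
    reduce_axes = tuple(i for i in range(v.ndim) if i != dim)
    norms = np.sqrt(np.sum(v ** 2, axis=reduce_axes, keepdims=True))
    # broadcast g (length v.shape[dim]) along `dim`, like elementwise_mul(..., axis=dim)
    g = np.reshape(g, [v.shape[dim] if i == dim else 1 for i in range(v.ndim)])
    return g * v / (norms + eps)

v = np.random.randn(4, 3, 5).astype("float32")
g = np.random.randn(4).astype("float32")
w = weight_norm_ref(v, g, dim=0)
assert w.shape == v.shape  # the reshape added in this commit restores this invariant

Before the change, the dim == 0 and last-dimension branches returned a tensor with the flattened 2-D shape instead of the parameter's original shape, which is exactly the invariant checked above.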
@@ -55,7 +55,7 @@ setup_info = dict(
         'inflect',
         'librosa',
         'unidecode',
-        'numba==0.48.0',
+        'numba==0.47.0',
         'tqdm==4.19.8',
         'matplotlib',
         'tensorboardX',
...
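The last hunk tightens the dependency pin from numba 0.48.0 to 0.47.0. In an installed environment the pin can be verified with a quick check; this snippet is a convenience for local verification only, not part of the package:

import numba
# setup.py now requires exactly this version
assert numba.__version__ == "0.47.0", "unexpected numba version: %s" % numba.__version__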