Unverified commit 171ec35f, authored by Feiyu Chan, committed by GitHub

Dv3: update for the Variable & Embedding changes in dygraph

1. use `set_value()` to set weights (see the sketch below).
2. use `weight` in PositionEmbedding instead of the protected member `_w`.
3. fix the data layout for Embedding: the trailing size-1 dimension is removed; update the documentation accordingly.
4. fix tensorboard tags: spaces are replaced with `_`.
Parent b3a14f46
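In code, items 1–3 amount to the following usage change. This is a minimal sketch, not code from the commit; it assumes the updated dygraph `Embedding` API that exposes a public `weight` parameter and accepts index tensors without the trailing size-1 dimension:

```python
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with fluid.dygraph.guard():
    # hypothetical sizes, for illustration only
    emb = dg.Embedding(size=[100, 16])

    # old style: emb._w.value().get_tensor().set(array, place)
    # new style: assign through the public `weight` parameter
    array = np.random.randn(100, 16).astype("float32")
    assert emb.weight.shape == list(array.shape), "shape does not match"
    emb.weight.set_value(array)

    # ids now have shape (B, T); the trailing size-1 dim is gone
    ids = dg.to_variable(np.array([[3, 1, 4]], dtype="int64"))
    out = emb(ids)  # shape (B, T, 16)
```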
@@ -10,7 +10,7 @@ A Paddle implementation of Deepvoice3, a convolutional-network-based text-to-speech (Tex
 ### Install the paddlepaddle framework
-For faster training speed and better support, we recommend using the latest development version of paddle. Users can download the latest pre-built development whl package, or build Paddle from source.
+For faster training speed and better support, we recommend using the latest development version of Paddle. Users can either download the latest pre-built development whl package or build Paddle from source.
 1. Download the latest pre-built development whl package. A suitable version can be chosen from the [**multi-version wheel package list (dev)**](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/Tables.html#whl-dev) page.
......
@@ -31,7 +31,7 @@ class Conv1D(dg.Layer):
     def __init__(self,
                  name_scope,
-                 in_cahnnels,
+                 in_channels,
                  num_filters,
                  filter_size=3,
                  dilation=1,
@@ -49,7 +49,7 @@ class Conv1D(dg.Layer):
         else:
             padding = (dilation * (filter_size - 1)) // 2
-        self.in_channels = in_cahnnels
+        self.in_channels = in_channels
         self.num_filters = num_filters
         self.filter_size = filter_size
         self.dilation = dilation
......
@@ -294,7 +294,6 @@ def create_batch(batch):
     text_positions = np.array(
         [_pad(np.arange(1, len(x[0]) + 1), max_input_len) for x in batch],
         dtype=np.int64)
-    text_positions = np.expand_dims(text_positions, axis=-1)
     max_decoder_target_len = max_target_len // r // downsample_step
@@ -304,7 +303,6 @@ def create_batch(batch):
         np.expand_dims(
             np.arange(
                 s, e, dtype=np.int64), axis=0), (len(batch), 1))
-    frame_positions = np.expand_dims(frame_positions, axis=-1)
     # done flags
     done = np.array([
......
@@ -591,10 +591,10 @@ class Decoder(dg.Layer):
                 of text inputs for each example.
             inputs (Variable): Shape(B, C_mel, 1, T_mel), ground truth
                 mel-spectrogram, which is used as decoder inputs when training.
-            text_positions (Variable): Shape(B, T_enc, 1), dtype: int64.
+            text_positions (Variable): Shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
-            frame_positions (Variable): Shape(B, T_dec // r, 1), dtype:
+            frame_positions (Variable): Shape(B, T_dec // r), dtype:
                 int64. Positions indices for each decoder time steps.
             speaker_embed: shape(batch_size, speaker_dim), speaker embedding,
                 only used for multispeaker model.
@@ -717,7 +717,7 @@ class Decoder(dg.Layer):
             values (Variable): shape(B, C_emb, 1, T_enc), the value
                 representation from an encoder, where C_emb means
                 text embedding size.
-            text_positions (Variable): Shape(B, T_enc, 1), dtype: int64.
+            text_positions (Variable): Shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
@@ -789,7 +789,7 @@ class Decoder(dg.Layer):
         while True:
             frame_pos = fluid.layers.fill_constant(
-                shape=[B, 1, 1], value=t + 1, dtype="int64")
+                shape=[B, 1], value=t + 1, dtype="int64")
             w = self.query_position_rate
             if self.n_speakers > 1:
                 w = w * fluid.layers.reshape(
@@ -1222,7 +1222,7 @@ class DeepVoiceTTS(dg.Layer):
         Encode text sequence and decode with ground truth mel spectrogram.
         Args:
-            text_sequences (Variable): Shape(B, T_enc, 1), dtype: int64. The
+            text_sequences (Variable): Shape(B, T_enc), dtype: int64. The
                 input text indices. T_enc means the timesteps of text_sequences.
             valid_lengths (Variable): shape(batch_size,), dtype: int64,
                 valid lengths for each example in text_sequences.
@@ -1231,10 +1231,10 @@ class DeepVoiceTTS(dg.Layer):
             speaker_indices (Variable, optional): Shape(Batch_size),
                 dtype: int64. Speaker index for each example. This arg is not
                 None only when the model is a multispeaker model.
-            text_positions (Variable): Shape(B, T_enc, 1), dtype: int64.
+            text_positions (Variable): Shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
-            frame_positions (Variable): Shape(B, T_dec // r, 1), dtype:
+            frame_positions (Variable): Shape(B, T_dec // r), dtype:
                 int64. Positions indices for each decoder time steps.
         Returns:
@@ -1295,12 +1295,12 @@ class DeepVoiceTTS(dg.Layer):
         Encode text sequence and decode without ground truth mel spectrogram.
         Args:
-            text_sequences (Variable): Shape(B, T_enc, 1), dtype: int64. The
+            text_sequences (Variable): Shape(B, T_enc), dtype: int64. The
                 input text indices. T_enc means the timesteps of text_sequences.
-            text_positions (Variable): Shape(B, T_enc, 1), dtype: int64.
+            text_positions (Variable): Shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
-            speaker_indices (Variable, optional): Shape(Batch_size, 1),
+            speaker_indices (Variable, optional): Shape(Batch_size),
                 dtype: int64. Speaker index for each example. This arg is not
                 None only when the model is a multispeaker model.
@@ -1423,7 +1423,7 @@ class ConvS2S(dg.Layer):
         Encode text sequence and decode with ground truth mel spectrogram.
         Args:
-            text_sequences (Variable): Shape(B, T_enc, 1), dtype: int64. The
+            text_sequences (Variable): Shape(B, T_enc), dtype: int64. The
                 input text indices. T_enc means the timesteps of text_sequences.
             valid_lengths (Variable): shape(batch_size,), dtype: int64,
                 valid lengths for each example in text_sequences.
@@ -1432,10 +1432,10 @@ class ConvS2S(dg.Layer):
             speaker_embed (Variable, optional): Shape(Batch_size, speaker_dim),
                 dtype: float32. Speaker embeddings. This arg is not None only
                 when the model is a multispeaker model.
-            text_positions (Variable): Shape(B, T_enc, 1), dtype: int64.
+            text_positions (Variable): Shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
-            frame_positions (Variable): Shape(B, T_dec // r, 1), dtype:
+            frame_positions (Variable): Shape(B, T_dec // r), dtype:
                 int64. Positions indices for each decoder time steps.
         Returns:
@@ -1466,9 +1466,9 @@ class ConvS2S(dg.Layer):
         Encode text sequence and decode without ground truth mel spectrogram.
         Args:
-            text_sequences (Variable): Shape(B, T_enc, 1), dtype: int64. The
+            text_sequences (Variable): Shape(B, T_enc), dtype: int64. The
                 input text indices. T_enc means the timesteps of text_sequences.
-            text_positions (Variable): Shape(B, T_enc, 1), dtype: int64.
+            text_positions (Variable): Shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
             speaker_embed (Variable, optional): Shape(Batch_size, speaker_dim),
......
@@ -48,7 +48,7 @@ def dry_run(model):
     mel_dim = hparams.num_mels
     x = np.random.randint(
-        low=0, high=n_vocab, size=(batch_size, enc_length, 1), dtype="int64")
+        low=0, high=n_vocab, size=(batch_size, enc_length), dtype="int64")
     input_lengths = np.arange(
         enc_length - batch_size + 1, enc_length + 1, dtype="int64")
     mel = np.random.randn(batch_size, mel_dim, 1, mel_length).astype("float32")
@@ -60,18 +60,16 @@ def dry_run(model):
             0, enc_length, dtype="int64"), (batch_size, 1))
     text_mask = text_positions > np.expand_dims(input_lengths, 1)
     text_positions[text_mask] = 0
-    text_positions = np.expand_dims(text_positions, axis=-1)
     frame_positions = np.tile(
         np.arange(
             1, decoder_length + 1, dtype="int64"), (batch_size, 1))
-    frame_positions = np.expand_dims(frame_positions, axis=-1)
     done = np.zeros(shape=(batch_size, 1, 1, decoder_length), dtype="float32")
     target_lengths = np.array([snd_sample_length] * batch_size).astype("int64")
     speaker_ids = np.random.randint(
-        low=0, high=n_speakers, size=(batch_size, 1),
+        low=0, high=n_speakers, size=(batch_size),
         dtype="int64") if n_speakers > 1 else None
     ismultispeaker = speaker_ids is not None
......
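The position-masking logic kept in `dry_run` above is easiest to see on a tiny array. A worked example with hypothetical values (lengths and positions chosen only for illustration, not taken from the commit):

```python
import numpy as np

# two sequences with valid lengths 3 and 5; positions 1..5, shape (B, T)
input_lengths = np.array([3, 5], dtype="int64")
text_positions = np.tile(np.arange(1, 6, dtype="int64"), (2, 1))

# zero out positions past each sequence's valid length
text_mask = text_positions > np.expand_dims(input_lengths, 1)
text_positions[text_mask] = 0
print(text_positions)
# [[1 2 3 0 0]
#  [1 2 3 4 5]]
```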
@@ -366,14 +366,14 @@ class PositionEmbedding(dg.Layer):
         self._dtype = dtype
     def set_weight(self, array):
-        assert self.embed._w.shape == list(array.shape), "shape does not match"
-        self.embed._w.value().get_tensor().set(
-            array, fluid.framework._current_expected_place())
+        assert self.embed.weight.shape == list(
+            array.shape), "shape does not match"
+        self.embed.weight.set_value(array)
     def forward(self, indices, speaker_position_rate=None):
         """
         Args:
-            indices (Variable): Shape (B, T, 1), dtype: int64, position
+            indices (Variable): Shape (B, T), dtype: int64, position
                 indices, where B means the batch size, T means the time steps.
             speaker_position_rate (Variable | float, optional), position
                 rate. It can be a float point number or a Variable with
@@ -391,7 +391,7 @@ class PositionEmbedding(dg.Layer):
         weight = compute_position_embedding(rad)
         out = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
-            type="lookup_table",
+            type="lookup_table_v2",
             inputs={"Ids": indices,
                     "W": weight},
             outputs={"Out": out},
@@ -417,7 +417,7 @@ class PositionEmbedding(dg.Layer):
         weight = compute_position_embedding(scaled_rad)
         out = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
-            type="lookup_table",
+            type="lookup_table_v2",
             inputs={"Ids": indices,
                     "W": weight},
             outputs={"Out": out},
@@ -441,7 +441,7 @@ class PositionEmbedding(dg.Layer):
                 self._dtype)
             sequence = indices[i]
             self._helper.append_op(
-                type="lookup_table",
+                type="lookup_table_v2",
                 inputs={"Ids": sequence,
                         "W": weight},
                 outputs={"Out": out},
......
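The switch from `lookup_table` to `lookup_table_v2` above is what lets the position indices drop their trailing size-1 dimension: the v1 op expected ids of shape `(..., 1)`, while v2 gathers rows for ids of shape `(...)`. A rough numpy analogue of the gather (an illustration, not the Paddle op itself):

```python
import numpy as np

weight = np.random.randn(50, 8).astype("float32")  # (num_embeddings, embed_dim)

ids_v1 = np.array([[[1], [4], [7]]], dtype="int64")  # lookup_table:    (B, T, 1)
ids_v2 = np.array([[1, 4, 7]], dtype="int64")        # lookup_table_v2: (B, T)

out = weight[ids_v2]  # gather rows -> shape (1, 3, 8)
assert out.shape == (1, 3, 8)
assert np.allclose(out, weight[ids_v1.squeeze(-1)])
```

This is also why the `expand_dims(..., axis=-1)` calls and the `(..., 1)` shapes are deleted throughout the rest of the diff.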
@@ -67,9 +67,9 @@ def tts(model, text, p=0., speaker_id=None):
     model.eval()
     sequence = np.array(_frontend.text_to_sequence(text, p=p)).astype("int64")
-    sequence = np.reshape(sequence, (1, -1, 1))
+    sequence = np.reshape(sequence, (1, -1))
     text_positions = np.arange(1, sequence.shape[1] + 1, dtype="int64")
-    text_positions = np.reshape(text_positions, (1, -1, 1))
+    text_positions = np.reshape(text_positions, (1, -1))
     sequence = dg.to_variable(sequence)
     text_positions = dg.to_variable(text_positions)
@@ -191,8 +191,8 @@ def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker):
         # Mel
         writer.add_image(
-            "(Eval) Predicted mel spectrogram text{}_{}".format(
-                idx, speaker_str),
+            "Eval_Predicted_mel_spectrogram_text{}_{}".format(idx,
+                                                              speaker_str),
             prepare_spec_image(mel),
             global_step,
             dataformats='HWC')
@@ -205,7 +205,7 @@ def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker):
         try:
             writer.add_audio(
-                "(Eval) Predicted audio signal {}_{}".format(idx,
-                                                             speaker_str),
+                "Eval_Predicted_audio_signal_{}_{}".format(idx,
+                                                           speaker_str),
                 signal,
                 global_step,
@@ -273,7 +273,7 @@ def save_states(global_step,
         mel_output = mel_outputs[idx].numpy().squeeze().T
         mel_output = prepare_spec_image(audio._denormalize(mel_output))
         writer.add_image(
-            "Predicted mel spectrogram",
+            "Predicted_mel_spectrogram",
             mel_output,
             global_step,
             dataformats="HWC")
@@ -282,7 +282,7 @@ def save_states(global_step,
         linear_output = linear_outputs[idx].numpy().squeeze().T
         spectrogram = prepare_spec_image(audio._denormalize(linear_output))
         writer.add_image(
-            "Predicted linear spectrogram",
+            "Predicted_linear_spectrogram",
             spectrogram,
             global_step,
             dataformats="HWC")
@@ -293,7 +293,7 @@ def save_states(global_step,
             "step{:09d}_predicted.wav".format(global_step))
         try:
             writer.add_audio(
-                "Predicted audio signal",
+                "Predicted_audio_signal",
                 signal,
                 global_step,
                 sample_rate=hparams.sample_rate)
@@ -306,7 +306,7 @@ def save_states(global_step,
         mel_output = mel[idx].numpy().squeeze().T
         mel_output = prepare_spec_image(audio._denormalize(mel_output))
         writer.add_image(
-            "Target mel spectrogram",
+            "Target_mel_spectrogram",
             mel_output,
             global_step,
             dataformats="HWC")
@@ -315,7 +315,7 @@ def save_states(global_step,
         linear_output = y[idx].numpy().squeeze().T
         spectrogram = prepare_spec_image(audio._denormalize(linear_output))
         writer.add_image(
-            "Target linear spectrogram",
+            "Target_linear_spectrogram",
             spectrogram,
             global_step,
             dataformats="HWC")