提交 e5641ca4 编写于 作者: H Hui Zhang

fix bugs, refactor collator, add pad_sequence, fix ckpt bugs

上级 944457d6
......@@ -13,6 +13,9 @@
# limitations under the License.
import logging
from typing import Union
from typing import Optional
from typing import List
from typing import Tuple
from typing import Any
import paddle
......@@ -83,6 +86,20 @@ if not hasattr(paddle.Tensor, 'numel'):
paddle.Tensor.numel = paddle.numel
def new_full(x: paddle.Tensor,
size: Union[List[int], Tuple[int], paddle.Tensor],
fill_value: Union[float, int, bool, paddle.Tensor],
dtype=None):
return paddle.full(size, fill_value, dtype=x.dtype)
if not hasattr(paddle.Tensor, 'new_full'):
logger.warn(
"override new_full of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.new_full = new_full
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))
......@@ -279,6 +296,7 @@ if not hasattr(paddle.nn, 'Module'):
logger.warn("register user Module to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'Module', paddle.nn.Layer)
# maybe cause assert isinstance(sublayer, core.Layer)
if not hasattr(paddle.nn, 'ModuleList'):
logger.warn(
"register user ModuleList to paddle.nn, remove this when fixed!")
......@@ -332,3 +350,78 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):
logger.warn(
"register user ConstantPad2d to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)
########### hcak paddle.jit #############
if not hasattr(paddle.jit, 'export'):
logger.warn("register user export to paddle.jit, remove this when fixed!")
setattr(paddle.jit, 'export', paddle.jit.to_static)
########### hcak paddle.nn.utils #############
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
padding_value: float=0.0) -> paddle.Tensor:
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> from paddle.nn.utils.rnn import pad_sequence
>>> a = paddle.ones(25, 300)
>>> b = paddle.ones(22, 300)
>>> c = paddle.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
paddle.Tensor([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size = sequences[0].size()
trailing_dims = max_size[1:]
max_len = max([s.size(0) for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences):
length = tensor.size(0)
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
return out_tensor
if not hasattr(paddle.nn.utils, 'rnn.pad_sequence'):
logger.warn(
"register user rnn.pad_sequence to paddle.nn.utils, remove this when fixed!"
)
setattr(paddle.nn.utils, 'rnn.pad_sequence', pad_sequence)
......@@ -16,15 +16,15 @@ import logging
import numpy as np
from collections import namedtuple
from deepspeech.io.utility import pad_sequence
logger = logging.getLogger(__name__)
__all__ = [
"SpeechCollator",
]
__all__ = ["SpeechCollator"]
class SpeechCollator():
def __init__(self, padding_to=-1, is_training=True):
def __init__(self, is_training=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
......@@ -32,42 +32,51 @@ class SpeechCollator():
If ``padding_to`` is -1, the maximun shape in the batch will be used
as the target shape for padding. Otherwise, `padding_to` will be the
target shape (only refers to the second axis).
if ``is_training`` is True, text is token ids else is raw string.
"""
self._padding_to = padding_to
self._is_training = is_training
def __call__(self, batch):
new_batch = []
# get target shape
max_length = max([audio.shape[1] for audio, _ in batch])
if self._padding_to != -1:
if self._padding_to < max_length:
raise ValueError("If padding_to is not -1, it should be larger "
"than any instance's shape in the batch")
max_length = self._padding_to
max_text_length = max([len(text) for _, text in batch])
# padding
padded_audios = []
"""batch examples
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
text (List[int] or str): shape (U,)
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
text : (B, Umax)
audio_lens: (B)
text_lens: (B)
"""
audios = []
audio_lens = []
texts, text_lens = [], []
texts = []
text_lens = []
for audio, text in batch:
# audio
padded_audio = np.zeros([audio.shape[0], max_length])
padded_audio[:, :audio.shape[1]] = audio
padded_audios.append(padded_audio)
audios.append(audio.T) # [T, D]
audio_lens.append(audio.shape[1])
# text
padded_text = np.zeros([max_text_length])
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = []
if self._is_training:
padded_text[:len(text)] = text # token ids
tokens = text # token ids
else:
padded_text[:len(text)] = [ord(t)
for t in text] # string, unicode ord
texts.append(padded_text)
assert isinstance(text, str)
tokens = [ord(t) for t in text]
tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
tokens, dtype=np.int64)
texts.append(tokens)
text_lens.append(len(text))
padded_audios = np.array(padded_audios).astype('float32')
audio_lens = np.array(audio_lens).astype('int64')
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
padded_texts = pad_sequence(texts, padding_value=-1).astype(np.int32)
audio_lens = np.array(audio_lens).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return padded_audios, padded_texts, audio_lens, text_lens
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from collections import namedtuple
from typing import List
logger = logging.getLogger(__name__)
__all__ = ["pad_sequence"]
def pad_sequence(sequences: List[np.ndarray],
batch_first: bool=True,
padding_value: float=0.0) -> np.ndarray:
r"""Pad a list of variable length Tensors with ``padding_value``
``pad_sequence`` stacks a list of Tensors along a new dimension,
and pads them to equal length. For example, if the input is list of
sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
otherwise.
`B` is batch size. It is equal to the number of elements in ``sequences``.
`T` is length of the longest sequence.
`L` is length of the sequence.
`*` is any number of trailing dimensions, including none.
Example:
>>> a = np.ones([25, 300])
>>> b = np.ones([22, 300])
>>> c = np.ones([15, 300])
>>> pad_sequence([a, b, c]).shape
[25, 3, 300]
Note:
This function returns a np.ndarray of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
Args:
sequences (list[np.ndarray]): list of variable length sequences.
batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
``T x B x *`` otherwise
padding_value (float, optional): value for padded elements. Default: 0.
Returns:
np.ndarray of size ``T x B x *`` if :attr:`batch_first` is ``False``.
np.ndarray of size ``B x T x *`` otherwise
"""
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size = sequences[0].shape
trailing_dims = max_size[1:]
max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
out_dims = (max_len, len(sequences)) + trailing_dims
out_tensor = np.full(out_dims, padding_value, dtype=sequences[0].dtype)
for i, tensor in enumerate(sequences):
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
return out_tensor
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Model"""
import math
import collections
import numpy as np
......@@ -67,23 +67,19 @@ class CRNNEncoder(nn.Layer):
return self.rnn_size * 2
def forward(self, audio, audio_len):
"""
audio: shape [B, D, T]
text: shape [B, T]
audio_len: shape [B]
text_len: shape [B]
"""
"""Compute Encoder outputs
Args:
audio (Tensor): [B, D, T]
text (Tensor): [B, T]
audio (Tensor): [B, Tmax, D]
text (Tensor): [B, Umax]
audio_len (Tensor): [B]
text_len (Tensor): [B]
Returns:
x (Tensor): encoder outputs, [B, T, D]
x_lens (Tensor): encoder length, [B]
"""
# [B, T, D] -> [B, D, T]
audio = audio.transpose([0, 2, 1])
# [B, D, T] -> [B, C=1, D, T]
x = audio.unsqueeze(1)
x_lens = audio_len
......
此差异已折叠。
......@@ -145,7 +145,7 @@ class ConvStack(nn.Layer):
act='brelu')
out_channel = 32
self.conv_stack = nn.Sequential([
convs = [
ConvBn(
num_channels_in=32,
num_channels_out=out_channel,
......@@ -153,7 +153,8 @@ class ConvStack(nn.Layer):
stride=(2, 1),
padding=(10, 5),
act='brelu') for i in range(num_stacks - 1)
])
]
self.conv_stack = nn.LayerList(convs)
# conv output feat_dim
output_height = (feat_size - 1) // 2 + 1
......
......@@ -298,7 +298,7 @@ class RNNStack(nn.Layer):
share_weights=share_rnn_weights))
i_size = h_size * 2
self.rnn_stacks = nn.Sequential(rnn_stacks)
self.rnn_stacks = nn.ModuleList(rnn_stacks)
def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
"""
......
......@@ -128,9 +128,10 @@ class Trainer():
dist.init_parallel_env()
@mp_tools.rank_zero_only
def save(self):
def save(self, infos=None):
"""Save checkpoint (model parameters and optimizer states).
"""
if infos is None:
infos = {
"step": self.iteration,
"epoch": self.epoch,
......@@ -151,6 +152,7 @@ class Trainer():
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path)
if infos:
self.iteration = infos["step"]
self.epoch = infos["epoch"]
......
......@@ -36,11 +36,11 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int:
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
Returns:
int: the latest iteration number.
int: the latest iteration number. -1 for no checkpoint to load.
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
if not os.path.isfile(checkpoint_record):
return 0
return -1
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
......@@ -79,11 +79,15 @@ def load_parameters(model,
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
configs = {}
if checkpoint_path is not None:
iteration = int(os.path.basename(checkpoint_path).split(":")[-1])
elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir)
checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
if iteration == -1:
return configs
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
......@@ -104,7 +108,6 @@ def load_parameters(model,
rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
configs = {}
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = json.load(fin)
......@@ -128,7 +131,7 @@ def save_parameters(checkpoint_dir: str,
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration))
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams"
......
......@@ -16,6 +16,7 @@
import math
import numpy as np
import distutils.util
from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册