Commit 756be8fb authored by Hui Zhang

fix dist batch sampler set_epoch call

Parent 12f540cd
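The fix guards `DistributedBatchSampler.set_epoch` behind hasattr/isinstance checks, so the same training loop also works when the loader uses a plain batch sampler (or none at all). A minimal sketch of the resulting pattern; the toy dataset and loop below are illustrative, not code from this repo:

    import paddle
    from paddle.io import DataLoader, Dataset, DistributedBatchSampler

    class ToyDataset(Dataset):
        # Illustrative stand-in for the real speech dataset.
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return paddle.zeros([4])

    dataset = ToyDataset()
    sampler = DistributedBatchSampler(dataset, batch_size=2, shuffle=True)
    loader = DataLoader(dataset, batch_sampler=sampler)

    for epoch in range(3):
        # Guarded call: only DistributedBatchSampler implements set_epoch
        # (it reseeds the per-rank shuffle); other samplers are skipped.
        if hasattr(loader, "batch_sampler"):
            batch_sampler = loader.batch_sampler
            if isinstance(batch_sampler, DistributedBatchSampler):
                batch_sampler.set_epoch(epoch)
        for batch in loader:
            pass  # training step goes here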
@@ -24,6 +24,7 @@ import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
from deepspeech.io.collator import SpeechCollator
@@ -162,8 +163,10 @@ class U2Trainer(Trainer):
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
-if self.parallel:
-    self.train_loader.batch_sampler.set_epoch(self.epoch)
+if hasattr(self.train_loader, "batch_sampler"):
+    batch_sampler = self.train_loader.batch_sampler
+    if isinstance(batch_sampler, DistributedBatchSampler):
+        batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
@@ -476,13 +479,6 @@ class U2Tester(U2Trainer):
})
f.write(data + '\n')
-# def run_test(self):
-#     self.resume_or_scratch()
-#     try:
-#         self.test()
-#     except KeyboardInterrupt:
-#         sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
@@ -491,7 +487,7 @@
List[paddle.static.InputSpec]: input spec.
"""
from deepspeech.models.u2 import U2InferModel
-infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
+infer_model = U2InferModel.from_pretrained(self.test_loader,
self.config.model.clone(),
self.args.checkpoint_path)
feat_dim = self.test_loader.dataset.feature_size
@@ -511,37 +507,3 @@
static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
-# def run_export(self):
-#     try:
-#         self.export()
-#     except KeyboardInterrupt:
-#         sys.exit(-1)
-# def setup(self):
-#     """Setup the experiment.
-#     """
-#     paddle.set_device(self.args.device)
-#     self.setup_output_dir()
-#     self.setup_checkpointer()
-#     self.setup_dataloader()
-#     self.setup_model()
-#     self.iteration = 0
-#     self.epoch = 0
-# def setup_output_dir(self):
-#     """Create a directory used for output.
-#     """
-#     # output dir
-#     if self.args.output:
-#         output_dir = Path(self.args.output).expanduser()
-#         output_dir.mkdir(parents=True, exist_ok=True)
-#     else:
-#         output_dir = Path(
-#             self.args.checkpoint_path).expanduser().parent.parent
-#         output_dir.mkdir(parents=True, exist_ok=True)
-#     self.output_dir = output_dir
@@ -25,10 +25,10 @@ class SpecAugmentor(AugmentorBase):
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
https://arxiv.org/abs/1904.08779
SpecAugment on Large Scale Datasets
https://arxiv.org/abs/1912.05533
"""
def __init__(self,
@@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase):
W=40,
adaptive_number_ratio=0,
adaptive_size_ratio=0,
-max_n_time_masks=20):
+max_n_time_masks=20,
+**kwargs):
"""SpecAugment class.
Args:
rng (random.Random): random generator object.
@@ -121,7 +122,7 @@ class SpecAugmentor(AugmentorBase):
def time_mask(self):
return self._time_mask
-def time_warp(xs, W=40):
+def time_warp(self, xs, W=40):
raise NotImplementedError
def mask_freq(self, xs, replace_with_zero=False):
......
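A note on the `time_warp` change above: defined without `self`, the method still receives the instance as its first positional argument, so `xs` was bound to the augmentor object and the spectrogram shifted into `W`. A minimal illustration; class and variable names are hypothetical:

    class Broken:
        def time_warp(xs, W=40):   # missing `self`: `xs` receives the instance
            return xs, W

    class Fixed:
        def time_warp(self, xs, W=40):
            return xs, W

    b, f = Broken(), Fixed()
    spec = "spectrogram"
    assert b.time_warp(spec) == (b, spec)    # instance in xs, spec pushed into W
    assert f.time_warp(spec) == (spec, 40)   # arguments land where intended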
@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the text featurizer class."""
import sentencepiece as spm
+from pprint import pformat
from ..utility import EOS
from ..utility import SPACE
@@ -206,7 +207,7 @@ class TextFeaturizer():
"""Load vocabulary from file."""
vocab_list = load_dict(vocab_filepath, maskctc)
assert vocab_list is not None
logger.info(f"Vocab: {vocab_list}")
logger.info(f"Vocab: {pformat(vocab_list)}")
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
@@ -220,10 +221,10 @@ class TextFeaturizer():
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"BLANK id: {blank_id}")
logger.info(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}")
logger.info(f"BLANK id: {blank_id}")
logger.info(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id
@@ -911,8 +911,10 @@ class U2Model(U2BaseModel):
DeepSpeech2Model: The model built from pretrained result.
"""
with UpdateConfig(config):
-config.input_dim = dataloader.collate_fn.feature_size
-config.output_dim = dataloader.collate_fn.vocab_size
+#config.input_dim = dataloader.collate_fn.feature_size
+#config.output_dim = dataloader.collate_fn.vocab_size
+config.input_dim = dataloader.dataset.feature_size
+config.output_dim = dataloader.dataset.vocab_size
model = cls.from_config(config)
......
@@ -17,6 +17,7 @@ from pathlib import Path
import paddle
from paddle import distributed as dist
+from paddle.io import DistributedBatchSampler
from tensorboardX import SummaryWriter
from deepspeech.utils import mp_tools
@@ -179,8 +180,10 @@ class Trainer():
"""Reset the train loader seed and increment `epoch`.
"""
self.epoch += 1
-if self.parallel:
-    self.train_loader.batch_sampler.set_epoch(self.epoch)
+if hasattr(self.train_loader, "batch_sampler"):
+    batch_sampler = self.train_loader.batch_sampler
+    if isinstance(batch_sampler, DistributedBatchSampler):
+        batch_sampler.set_epoch(self.epoch)
def train(self):
"""The training process control by epoch."""
@@ -190,8 +193,10 @@
self.save(tag='init')
self.lr_scheduler.step(self.iteration)
-if self.parallel:
-    self.train_loader.batch_sampler.set_epoch(self.epoch)
+if hasattr(self.train_loader, "batch_sampler"):
+    batch_sampler = self.train_loader.batch_sampler
+    if isinstance(batch_sampler, DistributedBatchSampler):
+        batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.training.n_epoch:
......
[
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 0.0
},
{
"type": "shift",
"params": {
@@ -6,5 +15,22 @@
"max_shift_ms": 5
},
"prob": 1.0
-}
+},
+{
+"type": "specaug",
+"params": {
+"W": 5,
+"warp_mode": "PIL",
+"F": 30,
+"n_freq_masks": 2,
+"T": 40,
+"n_time_masks": 2,
+"p": 1.0,
+"adaptive_number_ratio": 0,
+"adaptive_size_ratio": 0,
+"max_n_time_masks": 20,
+"replace_with_zero": true
+},
+"prob": 1.0
+}
]
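The `**kwargs` added to `SpecAugmentor.__init__` above is what lets this "specaug" entry carry keys such as "warp_mode" and "replace_with_zero" that are not named constructor parameters. A rough sketch of how a config list like this might be dispatched; the registry dict, helper functions, and `transform_feature` hook are illustrative assumptions, not this repo's actual AugmentationPipeline:

    import json
    import random

    # Assumed import path at this commit; the registry below is a sketch,
    # not the real augmentor dispatch table.
    from deepspeech.frontend.augmentor.spec_augment import SpecAugmentor

    AUGMENTOR_REGISTRY = {"specaug": SpecAugmentor}

    def build_augmentors(config_text, rng):
        built = []
        for entry in json.loads(config_text):
            cls = AUGMENTOR_REGISTRY.get(entry["type"])
            if cls is None:
                continue  # "speed"/"shift" omitted from this sketch
            # **entry["params"] forwards every key; unknown ones such as
            # "warp_mode" are absorbed by the newly added **kwargs.
            built.append((cls(rng, **entry["params"]), entry["prob"]))
        return built

    def apply_feature_augment(built, xs, rng):
        # transform_feature is assumed to be the feature-side augmentor hook.
        for aug, prob in built:
            if rng.uniform(0., 1.) < prob:
                xs = aug.transform_feature(xs)
        return xs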
-#! /usr/bin/env bash
+#!/bin/bash
stage=-1
stop_stage=100
......
-export MAIN_ROOT=${PWD}/../../../
+export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
......