Commit 756be8fb authored by Hui Zhang

fix dist batch sampler set_epoch call

Parent 12f540cd
@@ -24,6 +24,7 @@ import numpy as np
 import paddle
 from paddle import distributed as dist
 from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
 from yacs.config import CfgNode
 from deepspeech.io.collator import SpeechCollator
@@ -162,8 +163,10 @@ class U2Trainer(Trainer):
             self.save(tag='init')

         self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        if hasattr(self.train_loader, "batch_sampler"):
+            batch_sampler = self.train_loader.batch_sampler
+            if isinstance(batch_sampler, DistributedBatchSampler):
+                batch_sampler.set_epoch(self.epoch)

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
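For context on the change above: DistributedBatchSampler derives its per-epoch shuffle order from the epoch number, so skipping set_epoch makes every epoch iterate the data in the same order; the hasattr/isinstance guard also keeps the call a safe no-op when the loader wraps a plain sampler. A minimal standalone sketch of the same pattern, using only public paddle.io APIs (the helper name is ours, not the repo's):

    from paddle.io import DataLoader, DistributedBatchSampler

    def reseed_epoch_shuffle(loader: DataLoader, epoch: int) -> None:
        """Re-seed the per-epoch shuffle; no-op for non-distributed samplers."""
        sampler = getattr(loader, "batch_sampler", None)
        if isinstance(sampler, DistributedBatchSampler):
            # The sampler seeds its shuffle with the epoch number; without
            # this call each epoch would see the same sample order.
            sampler.set_epoch(epoch)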
@@ -476,13 +479,6 @@ class U2Tester(U2Trainer):
                 })
             f.write(data + '\n')

-    # def run_test(self):
-    #     self.resume_or_scratch()
-    #     try:
-    #         self.test()
-    #     except KeyboardInterrupt:
-    #         sys.exit(-1)
-
     def load_inferspec(self):
         """infer model and input spec.
@@ -491,7 +487,7 @@ class U2Tester(U2Trainer):
             List[paddle.static.InputSpec]: input spec.
         """
         from deepspeech.models.u2 import U2InferModel
-        infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
+        infer_model = U2InferModel.from_pretrained(self.test_loader,
                                                    self.config.model.clone(),
                                                    self.args.checkpoint_path)
         feat_dim = self.test_loader.dataset.feature_size
@@ -511,37 +507,3 @@ class U2Tester(U2Trainer):
         static_model = paddle.jit.to_static(infer_model, input_spec=input_spec)
         logger.info(f"Export code: {static_model.forward.code}")
         paddle.jit.save(static_model, self.args.export_path)
-
-    # def run_export(self):
-    #     try:
-    #         self.export()
-    #     except KeyboardInterrupt:
-    #         sys.exit(-1)
-
-    # def setup(self):
-    #     """Setup the experiment.
-    #     """
-    #     paddle.set_device(self.args.device)
-    #     self.setup_output_dir()
-    #     self.setup_checkpointer()
-    #     self.setup_dataloader()
-    #     self.setup_model()
-    #     self.iteration = 0
-    #     self.epoch = 0
-
-    # def setup_output_dir(self):
-    #     """Create a directory used for output.
-    #     """
-    #     # output dir
-    #     if self.args.output:
-    #         output_dir = Path(self.args.output).expanduser()
-    #         output_dir.mkdir(parents=True, exist_ok=True)
-    #     else:
-    #         output_dir = Path(
-    #             self.args.checkpoint_path).expanduser().parent.parent
-    #         output_dir.mkdir(parents=True, exist_ok=True)
-    #     self.output_dir = output_dir
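A note on the export path kept above: paddle.jit.to_static traces infer_model against an input_spec, and paddle.jit.save writes the static program to export_path. A hedged sketch of what such a spec can look like; the shapes and variable axes here are assumptions for illustration, not the spec this file actually builds:

    import paddle

    feat_dim = 80  # hypothetical feature size; the real one comes from the dataset
    input_spec = [
        # -1 marks axes left variable at trace time (batch, time)
        paddle.static.InputSpec(shape=[1, -1, feat_dim], dtype='float32'),  # feats
        paddle.static.InputSpec(shape=[1], dtype='int64'),  # feat lengths
    ]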
@@ -25,10 +25,10 @@ class SpecAugmentor(AugmentorBase):
     SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
     https://arxiv.org/abs/1904.08779
     SpecAugment on Large Scale Datasets
     https://arxiv.org/abs/1912.05533
     """

     def __init__(self,
@@ -41,7 +41,8 @@ class SpecAugmentor(AugmentorBase):
                  W=40,
                  adaptive_number_ratio=0,
                  adaptive_size_ratio=0,
-                 max_n_time_masks=20):
+                 max_n_time_masks=20,
+                 **kwargs):
         """SpecAugment class.
         Args:
             rng (random.Random): random generator object.
@@ -121,7 +122,7 @@ class SpecAugmentor(AugmentorBase):
     def time_mask(self):
         return self._time_mask

-    def time_warp(xs, W=40):
+    def time_warp(self, xs, W=40):
         raise NotImplementedError

     def mask_freq(self, xs, replace_with_zero=False):
...
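The **kwargs added to __init__ above lets an augmentor's params dict carry keys the class does not consume directly (the specaug entry in the JSON config later in this commit includes such keys). A short sketch under that assumption, with a hypothetical class name:

    import random

    class AugmentorSketch:  # hypothetical stand-in for SpecAugmentor
        def __init__(self, rng, W=40, **kwargs):
            self.rng = rng
            self.W = W
            # unknown config keys (e.g. "warp_mode") land in kwargs instead
            # of raising TypeError when the params dict is splatted in

    params = {"W": 5, "warp_mode": "PIL"}              # mirrors the config below
    aug = AugmentorSketch(random.Random(0), **params)  # no TypeError on extras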
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Contains the text featurizer class."""
 import sentencepiece as spm
+from pprint import pformat

 from ..utility import EOS
 from ..utility import SPACE
@@ -206,7 +207,7 @@ class TextFeaturizer():
         """Load vocabulary from file."""
         vocab_list = load_dict(vocab_filepath, maskctc)
         assert vocab_list is not None
-        logger.info(f"Vocab: {vocab_list}")
+        logger.info(f"Vocab: {pformat(vocab_list)}")

         id2token = dict(
             [(idx, token) for (idx, token) in enumerate(vocab_list)])
@@ -220,10 +221,10 @@ class TextFeaturizer():
         sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
         space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1

-        logger.info(f"BLANK id: {blank_id}")
         logger.info(f"UNK id: {unk_id}")
         logger.info(f"EOS id: {eos_id}")
         logger.info(f"SOS id: {sos_id}")
         logger.info(f"SPACE id: {space_id}")
+        logger.info(f"BLANK id: {blank_id}")
         logger.info(f"MASKCTC id: {maskctc_id}")
         return token2id, id2token, vocab_list, unk_id, eos_id
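The pformat switch above is a log-readability fix: a vocab list with thousands of entries logged raw becomes one enormous line, while pprint.pformat wraps it. A quick illustration:

    from pprint import pformat

    vocab_list = [f"tok{i}" for i in range(100)]  # hypothetical vocab
    print(f"Vocab: {vocab_list}")            # one very long log line
    print(f"Vocab: {pformat(vocab_list)}")   # wrapped at ~80 columns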
@@ -911,8 +911,10 @@ class U2Model(U2BaseModel):
             DeepSpeech2Model: The model built from pretrained result.
         """
         with UpdateConfig(config):
-            config.input_dim = dataloader.collate_fn.feature_size
-            config.output_dim = dataloader.collate_fn.vocab_size
+            #config.input_dim = dataloader.collate_fn.feature_size
+            #config.output_dim = dataloader.collate_fn.vocab_size
+            config.input_dim = dataloader.dataset.feature_size
+            config.output_dim = dataloader.dataset.vocab_size

         model = cls.from_config(config)
...
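This hunk moves the model's I/O dimensions from the collator to the dataset, matching the from_pretrained(self.test_loader, ...) change earlier in this commit. A sketch of the contract it assumes (the function name is ours):

    def infer_dims(dataloader):
        # Previously read from dataloader.collate_fn.feature_size / .vocab_size;
        # the dataset is now expected to expose both attributes directly.
        return dataloader.dataset.feature_size, dataloader.dataset.vocab_size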
@@ -17,6 +17,7 @@ from pathlib import Path

 import paddle
 from paddle import distributed as dist
+from paddle.io import DistributedBatchSampler
 from tensorboardX import SummaryWriter

 from deepspeech.utils import mp_tools
@@ -179,8 +180,10 @@ class Trainer():
         """Reset the train loader seed and increment `epoch`.
         """
         self.epoch += 1
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        if hasattr(self.train_loader, "batch_sampler"):
+            batch_sampler = self.train_loader.batch_sampler
+            if isinstance(batch_sampler, DistributedBatchSampler):
+                batch_sampler.set_epoch(self.epoch)

     def train(self):
         """The training process control by epoch."""
@@ -190,8 +193,10 @@ class Trainer():
             self.save(tag='init')

         self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        if hasattr(self.train_loader, "batch_sampler"):
+            batch_sampler = self.train_loader.batch_sampler
+            if isinstance(batch_sampler, DistributedBatchSampler):
+                batch_sampler.set_epoch(self.epoch)

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
...
 [
+    {
+        "type": "speed",
+        "params": {
+            "min_speed_rate": 0.9,
+            "max_speed_rate": 1.1,
+            "num_rates": 3
+        },
+        "prob": 0.0
+    },
     {
         "type": "shift",
         "params": {
@@ -6,5 +15,22 @@
             "max_shift_ms": 5
         },
         "prob": 1.0
+    },
+    {
+        "type": "specaug",
+        "params": {
+            "W": 5,
+            "warp_mode": "PIL",
+            "F": 30,
+            "n_freq_masks": 2,
+            "T": 40,
+            "n_time_masks": 2,
+            "p": 1.0,
+            "adaptive_number_ratio": 0,
+            "adaptive_size_ratio": 0,
+            "max_n_time_masks": 20,
+            "replace_with_zero": true
+        },
+        "prob": 1.0
     }
 ]
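Each entry's "prob" gates whether that augmentor is applied to a given utterance, so the new "speed" entry is staged but disabled ("prob": 0.0) while "specaug" always fires ("prob": 1.0). A hedged sketch of consuming such a pipeline; the file path and dispatch function are hypothetical:

    import json
    import random

    rng = random.Random(0)

    def apply_augmentor(kind, params):
        # hypothetical dispatch; the real repo instantiates augmentor objects
        print(f"applying {kind} with {params}")

    with open("augmentation.json") as f:  # path assumed
        pipeline = json.load(f)
    for entry in pipeline:
        # each augmentor fires independently with probability `prob`
        if rng.uniform(0.0, 1.0) < entry["prob"]:
            apply_augmentor(entry["type"], entry["params"])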
-#! /usr/bin/env bash
+#!/bin/bash

 stage=-1
 stop_stage=100
...
-export MAIN_ROOT=${PWD}/../../../
+export MAIN_ROOT=`realpath ${PWD}/../../../`
 export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}

 export LC_ALL=C
...