提交 38433729 编写于 作者: H Hui Zhang

u2 with chianer updater

上级 91bc5959
......@@ -80,23 +80,23 @@ def convert_dtype_to_string(tensor_dtype):
if not hasattr(paddle, 'softmax'):
logger.warn("register user softmax to paddle, remove this when fixed!")
logger.debug("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle, 'log_softmax'):
logger.warn("register user log_softmax to paddle, remove this when fixed!")
logger.debug("register user log_softmax to paddle, remove this when fixed!")
setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)
if not hasattr(paddle, 'sigmoid'):
logger.warn("register user sigmoid to paddle, remove this when fixed!")
logger.debug("register user sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle, 'log_sigmoid'):
logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
logger.debug("register user log_sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)
if not hasattr(paddle, 'relu'):
logger.warn("register user relu to paddle, remove this when fixed!")
logger.debug("register user relu to paddle, remove this when fixed!")
setattr(paddle, 'relu', paddle.nn.functional.relu)
......@@ -105,7 +105,7 @@ def cat(xs, dim=0):
if not hasattr(paddle, 'cat'):
logger.warn(
logger.debug(
"override cat of paddle if exists or register, remove this when fixed!")
paddle.cat = cat
......@@ -116,7 +116,7 @@ def item(x: paddle.Tensor):
if not hasattr(paddle.Tensor, 'item'):
logger.warn(
logger.debug(
"override item of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.item = item
......@@ -127,13 +127,13 @@ def func_long(x: paddle.Tensor):
if not hasattr(paddle.Tensor, 'long'):
logger.warn(
logger.debug(
"override long of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.long = func_long
if not hasattr(paddle.Tensor, 'numel'):
logger.warn(
logger.debug(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.numel = paddle.numel
......@@ -147,7 +147,7 @@ def new_full(x: paddle.Tensor,
if not hasattr(paddle.Tensor, 'new_full'):
logger.warn(
logger.debug(
"override new_full of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.new_full = new_full
......@@ -162,13 +162,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'eq'):
logger.warn(
logger.debug(
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
if not hasattr(paddle, 'eq'):
logger.warn(
logger.debug(
"override eq of paddle if exists or register, remove this when fixed!")
paddle.eq = eq
......@@ -178,7 +178,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'contiguous'):
logger.warn(
logger.debug(
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
......@@ -195,7 +195,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
#`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
logger.warn(
logger.debug(
"override size of paddle.Tensor "
"(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
......@@ -207,7 +207,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'view'):
logger.warn("register user view to paddle.Tensor, remove this when fixed!")
logger.debug("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
......@@ -216,7 +216,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'view_as'):
logger.warn(
logger.debug(
"register user view_as to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view_as = view_as
......@@ -242,7 +242,7 @@ def masked_fill(xs: paddle.Tensor,
if not hasattr(paddle.Tensor, 'masked_fill'):
logger.warn(
logger.debug(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
......@@ -260,7 +260,7 @@ def masked_fill_(xs: paddle.Tensor,
if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.warn(
logger.debug(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
......@@ -272,7 +272,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'fill_'):
logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
logger.debug(
"register user fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.fill_ = fill_
......@@ -281,22 +282,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'repeat'):
logger.warn(
logger.debug(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
if not hasattr(paddle.Tensor, 'softmax'):
logger.warn(
logger.debug(
"register user softmax to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle.Tensor, 'sigmoid'):
logger.warn(
logger.debug(
"register user sigmoid to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle.Tensor, 'relu'):
logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
logger.debug("register user relu to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
......@@ -305,7 +306,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'type_as'):
logger.warn(
logger.debug(
"register user type_as to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'type_as', type_as)
......@@ -321,7 +322,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'to'):
logger.warn("register user to to paddle.Tensor, remove this when fixed!")
logger.debug("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
......@@ -330,7 +331,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'float'):
logger.warn("register user float to paddle.Tensor, remove this when fixed!")
logger.debug(
"register user float to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'float', func_float)
......@@ -339,7 +341,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'int'):
logger.warn("register user int to paddle.Tensor, remove this when fixed!")
logger.debug("register user int to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'int', func_int)
......@@ -348,6 +350,6 @@ def tolist(x: paddle.Tensor) -> List[Any]:
if not hasattr(paddle.Tensor, 'tolist'):
logger.warn(
logger.debug(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
......@@ -18,10 +18,12 @@ import os
from paddle import distributed as dist
from deepspeech.exps.u2.config import get_cfg_defaults
from deepspeech.exps.u2.model import U2Trainer as Trainer
# from deepspeech.exps.u2.trainer import U2Trainer as Trainer
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments
from deepspeech.exps.u2.model import U2Trainer as Trainer
def main_sp(config, args):
exp = Trainer(config, args)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains U2 model."""
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.models.u2 import U2Evaluator
from deepspeech.models.u2 import U2Model
from deepspeech.models.u2 import U2Updater
from deepspeech.training.extensions.snapshot import Snapshot
from deepspeech.training.extensions.visualizer import VisualDL
from deepspeech.training.optimizer import OptimizerFactory
from deepspeech.training.scheduler import LRSchedulerFactory
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.training.updaters.trainer import Trainer as NewTrainer
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def setup_dataloader(self):
config = self.config.clone()
config.defrost()
config.collator.keep_transcription_text = False
# train/valid dataset, return token ids
config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config)
config.data.manifest = config.data.dev_manifest
dev_dataset = ManifestDataset.from_config(config)
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
if self.parallel:
batch_sampler = SortagradDistributedBatchSampler(
train_dataset,
batch_size=config.collator.batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
else:
batch_sampler = SortagradBatchSampler(
train_dataset,
shuffle=True,
batch_size=config.collator.batch_size,
drop_last=True,
sortagrad=config.collator.sortagrad,
shuffle_method=config.collator.shuffle_method)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn_train,
num_workers=config.collator.num_workers, )
self.valid_loader = DataLoader(
dev_dataset,
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
# filter test examples, will cause less examples, but no mismatch with training
# and can use large batch size , save training time, so filter test egs now.
config.data.min_input_len = 0.0 # second
config.data.max_input_len = float('inf') # second
config.data.min_output_len = 0.0 # tokens
config.data.max_output_len = float('inf') # tokens
config.data.min_output_input_ratio = 0.00
config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config)
# return text ord id
config.collator.keep_transcription_text = True
config.collator.augmentation_config = ""
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config.model
model_conf.defrost()
model_conf.input_dim = self.train_loader.collate_fn.feature_size
model_conf.output_dim = self.train_loader.collate_fn.vocab_size
model_conf.freeze()
model = U2Model.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
model.train()
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
train_config = config.training
optim_type = train_config.optim
optim_conf = train_config.optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
scheduler_args = {
"learning_rate": optim_conf.lr,
"verbose": False,
"warmup_steps": scheduler_conf.warmup_steps,
"gamma": scheduler_conf.lr_decay,
"d_model": model_conf.encoder_conf.output_size,
}
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
scheduler_args)
def optimizer_args(
config,
parameters,
lr_scheduler=None, ):
train_config = config.training
optim_type = train_config.optim
optim_conf = train_config.optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"weight_decay": optim_conf.weight_decay,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
"parameters": parameters,
"epsilon": 1e-9 if optim_type == 'noam' else None,
"beta1": 0.9 if optim_type == 'noam' else None,
"beat2": 0.98 if optim_type == 'noam' else None,
}
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup model/optimizer/lr_scheduler!")
def setup_updater(self):
output_dir = self.output_dir
config = self.config.training
updater = U2Updater(
model=self.model,
optimizer=self.optimizer,
scheduler=self.lr_scheduler,
dataloader=self.train_loader,
output_dir=output_dir,
accum_grad=config.accum_grad)
trainer = NewTrainer(updater, (config.n_epoch, 'epoch'), output_dir)
evaluator = U2Evaluator(self.model, self.valid_loader)
trainer.extend(evaluator, trigger=(1, "epoch"))
if dist.get_rank() == 0:
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
num_snapshots = config.checkpoint.kbest_n
trainer.extend(
Snapshot(
mode='kbest',
max_size=num_snapshots,
indicator='VALID/LOSS',
less_better=True),
trigger=(1, 'epoch'))
# print(trainer.extensions)
# trainer.run()
self.trainer = trainer
def run(self):
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
self.setup_updater()
with Timer("Training Done: {}"):
self.trainer.run()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2 import U2InferModel
from .u2 import U2Model
from .updater import U2Evaluator
from .updater import U2Updater
__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import paddle
from paddle import distributed as dist
from deepspeech.training.extensions.evaluator import StandardEvaluator
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.training.updaters.standard_updater import StandardUpdater
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class U2Evaluator(StandardEvaluator):
def __init__(self, model, dataloader):
super().__init__(model, dataloader)
self.msg = ""
self.num_seen_utts = 0
self.total_loss = 0.0
def evaluate_core(self, batch):
self.msg = "Valid: Rank: {}, ".format(dist.get_rank())
losses_dict = {}
loss, attention_loss, ctc_loss = self.model(*batch[1:])
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
self.num_seen_utts += num_utts
self.total_loss += float(loss) * num_utts
losses_dict['loss'] = float(loss)
if attention_loss:
losses_dict['att_loss'] = float(attention_loss)
if ctc_loss:
losses_dict['ctc_loss'] = float(ctc_loss)
for k, v in losses_dict.items():
report("eval/" + k, v)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
logger.info(self.msg)
return self.total_loss, self.num_seen_utts
class U2Updater(StandardUpdater):
def __init__(self,
model,
optimizer,
scheduler,
dataloader,
init_state=None,
accum_grad=1,
**kwargs):
super().__init__(
model, optimizer, scheduler, dataloader, init_state=init_state)
self.accum_grad = accum_grad
self.forward_count = 0
self.msg = ""
def update_core(self, batch):
"""One Step
Args:
batch (List[Object]): utts, xs, xlens, ys, ylens
"""
losses_dict = {}
self.msg = "Rank: {}, ".format(dist.get_rank())
# forward
batch_size = batch[1].shape[0]
loss, attention_loss, ctc_loss = self.model(*batch[1:])
# loss div by `batch_size * accum_grad`
loss /= self.accum_grad
# loss backward
if (self.forward_count + 1) != self.accum_grad:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
context = self.model.no_sync
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# loss info
losses_dict['loss'] = float(loss) * self.accum_grad
if attention_loss:
losses_dict['att_loss'] = float(attention_loss)
if ctc_loss:
losses_dict['ctc_loss'] = float(ctc_loss)
# report loss
for k, v in losses_dict.items():
report("train/" + k, v)
# loss msg
self.msg += "batch size: {}, ".format(batch_size)
self.msg += "accum: {}, ".format(self.accum_grad)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
# Truncate the graph
loss.detach()
# update parameters
self.forward_count += 1
if self.forward_count != self.accum_grad:
return
self.forward_count = 0
self.optimizer.step()
self.optimizer.clear_grad()
self.scheduler.step()
def update(self):
# model is default in train mode
# training for a step is implemented here
with Timer("data time cost:{}"):
batch = self.read_batch()
with Timer("step time cost:{}"):
self.update_core(batch)
# #iterations with accum_grad > 1
# Ref.: https://github.com/espnet/espnet/issues/777
if self.forward_count == 0:
self.state.iteration += 1
if self.updates_per_epoch is not None:
if self.state.iteration % self.updates_per_epoch == 0:
self.state.epoch += 1
......@@ -46,7 +46,7 @@ class CTCLoss(nn.Layer):
if grad_norm_type == 'instance':
self.norm_by_times = True
if grad_norm_type == 'batch':
self.norm_by_times = True
self.norm_by_batchsize = True
if grad_norm_type == 'frame':
self.norm_by_total_logits_len = True
......
......@@ -13,14 +13,18 @@
# limitations under the License.
from typing import Dict
import extension
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from . import extension
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope
from ..timer import Timer
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class StandardEvaluator(extension.Extension):
......@@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension):
def evaluate_core(self, batch):
# compute
self.model(batch) # you may report here
return
def evaluate_sync(self, data):
# dist sync `evaluate_core` outputs
if data is None:
return
numerator, denominator = data
if dist.get_world_size() > 1:
numerator = paddle.to_tensor(numerator)
denominator = paddle.to_tensor(denominator)
# the default operator in all_reduce function is sum.
dist.all_reduce(numerator)
dist.all_reduce(denominator)
value = numerator / denominator
value = float(value)
else:
value = numerator / denominator
# used for `snapshort` to do kbest save.
report("VALID/LOSS", value)
logger.info(f"Valid: all-reduce loss {value}")
def evaluate(self):
# switch to eval mode
......@@ -56,9 +81,13 @@ class StandardEvaluator(extension.Extension):
with scope(observation):
# main evaluation computation here.
with paddle.no_grad():
self.evaluate_core(batch)
self.evaluate_sync(self.evaluate_core(batch))
summary.add(observation)
summary = summary.compute_mean()
# switch to train mode
for model in self.models.values():
model.train()
return summary
def __call__(self, trainer=None):
......@@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension):
# if it is used to extend a trainer, the metrics is reported to
# to observation of the trainer
# or otherwise, you can use your own observation
summary = self.evaluate()
with Timer("Eval Time Cost: {}"):
summary = self.evaluate()
for k, v in summary.items():
report(k, v)
......@@ -20,8 +20,9 @@ from typing import List
import jsonlines
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer
from . import extension
from ..reporter import get_observations
from ..updaters.trainer import Trainer
from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only
......@@ -52,8 +53,19 @@ class Snapshot(extension.Extension):
priority = -100
default_name = "snapshot"
def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
def __init__(self,
mode='latest',
max_size: int=5,
indicator=None,
less_better=True,
snapshot_on_error: bool=False):
self.records: List[Dict[str, Any]] = []
assert mode in ('latest', 'kbest'), mode
if mode == 'kbest':
assert indicator is not None
self.mode = mode
self.indicator = indicator
self.less_is_better = less_better
self.max_size = max_size
self._snapshot_on_error = snapshot_on_error
self._save_all = (max_size == -1)
......@@ -66,16 +78,17 @@ class Snapshot(extension.Extension):
# load existing records
record_path: Path = self.checkpoint_dir / "records.jsonl"
if record_path.exists():
logger.debug("Loading from an existing checkpoint dir")
self.records = load_records(record_path)
trainer.updater.load(self.records[-1]['path'])
ckpt_path = self.records[-1]['path']
logger.info(f"Loading from an existing checkpoint {ckpt_path}")
trainer.updater.load(ckpt_path)
def on_error(self, trainer, exc, tb):
if self._snapshot_on_error:
self.save_checkpoint_and_update(trainer)
self.save_checkpoint_and_update(trainer, 'latest')
def __call__(self, trainer: Trainer):
self.save_checkpoint_and_update(trainer)
self.save_checkpoint_and_update(trainer, self.mode)
def full(self):
"""Whether the number of snapshots it keeps track of is greater
......@@ -83,7 +96,7 @@ class Snapshot(extension.Extension):
return (not self._save_all) and len(self.records) > self.max_size
@rank_zero_only
def save_checkpoint_and_update(self, trainer: Trainer):
def save_checkpoint_and_update(self, trainer: Trainer, mode: str):
"""Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch
......@@ -97,11 +110,17 @@ class Snapshot(extension.Extension):
'path': str(path.resolve()), # use absolute path
'iteration': iteration,
'epoch': epoch,
'indicator': get_observations()[self.indicator]
}
self.records.append(record)
# remove the earist
if self.full():
if mode == 'kbest':
self.records = sorted(
self.records,
key=lambda record: record['indicator'],
reverse=not self.less_is_better)
eariest_record = self.records[0]
os.remove(eariest_record["path"])
self.records.pop(0)
......
......@@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer
from visualdl import LogWriter
from . import extension
from ..updaters.trainer import Trainer
class VisualDL(extension.Extension):
......@@ -26,8 +28,8 @@ class VisualDL(extension.Extension):
default_name = 'visualdl'
priority = extension.PRIORITY_READER
def __init__(self, writer):
self.writer = writer
def __init__(self, output_dir):
self.writer = LogWriter(str(output_dir))
def __call__(self, trainer: Trainer):
for k, v in trainer.observation.items():
......
......@@ -171,7 +171,7 @@ class Trainer():
self.iteration = 0
self.epoch = 0
scratch = True
logger.info("Restore/Init checkpoint!")
return scratch
def new_epoch(self):
......
......@@ -14,12 +14,12 @@
from typing import Dict
from typing import Optional
from paddle import Tensor
import paddle
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer
from paddle.optimizer.lr import LRScheduler
from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
......@@ -39,8 +39,10 @@ class StandardUpdater(UpdaterBase):
def __init__(self,
model: Layer,
optimizer: Optimizer,
scheduler: LRScheduler,
dataloader: DataLoader,
init_state: Optional[UpdaterState]=None):
super().__init__(init_state)
# it is designed to hold multiple models
models = {"main": model}
self.models: Dict[str, Layer] = models
......@@ -51,15 +53,14 @@ class StandardUpdater(UpdaterBase):
self.optimizer = optimizer
self.optimizers: Dict[str, Optimizer] = optimizers
# it is designed to hold multiple scheduler
schedulers = {"main": scheduler}
self.scheduler = scheduler
self.schedulers: Dict[str, LRScheduler] = schedulers
# dataloaders
self.dataloader = dataloader
# init state
if init_state is None:
self.state = UpdaterState()
else:
self.state = init_state
self.train_iterator = iter(dataloader)
def update(self):
......@@ -103,8 +104,10 @@ class StandardUpdater(UpdaterBase):
model.train()
# training for a step is implemented here
batch = self.read_batch()
self.update_core(batch)
with Timier("data time cost:{}"):
batch = self.read_batch()
with Timier("step time cost:{}"):
self.update_core(batch)
self.state.iteration += 1
if self.updates_per_epoch is not None:
......@@ -115,13 +118,14 @@ class StandardUpdater(UpdaterBase):
"""A simple case for a training step. Basic assumptions are:
Single model;
Single optimizer;
Single scheduler, and update learning rate each step;
A batch from the dataloader is just the input of the model;
The model return a single loss, or a dict containing serval losses.
Parameters updates at every batch, no gradient accumulation.
"""
loss = self.model(*batch)
if isinstance(loss, Tensor):
if isinstance(loss, paddle.Tensor):
loss_dict = {"main": loss}
else:
# Dict[str, Tensor]
......@@ -135,14 +139,15 @@ class StandardUpdater(UpdaterBase):
for name, loss_item in loss_dict.items():
report(name, float(loss_item))
self.optimizer.clear_gradient()
self.optimizer.clear_grad()
loss_dict["main"].backward()
self.optimizer.update()
self.optimizer.step()
self.scheduler.step()
@property
def updates_per_epoch(self):
"""Number of updater per epoch, determined by the length of the
dataloader."""
"""Number of steps per epoch,
determined by the length of the dataloader."""
length_of_dataloader = None
try:
length_of_dataloader = len(self.dataloader)
......@@ -163,18 +168,16 @@ class StandardUpdater(UpdaterBase):
def read_batch(self):
"""Read a batch from the data loader, auto renew when data is exhausted."""
with timer() as t:
try:
batch = next(self.train_iterator)
except StopIteration:
self.new_epoch()
batch = next(self.train_iterator)
logger.debug(
f"Read a batch takes {t.elapse}s.") # replace it with logger
try:
batch = next(self.train_iterator)
except StopIteration:
self.new_epoch()
batch = next(self.train_iterator)
return batch
def state_dict(self):
"""State dict of a Updater, model, optimizer and updater state are included."""
"""State dict of a Updater, model, optimizers/schedulers
and updater state are included."""
state_dict = super().state_dict()
for name, model in self.models.items():
state_dict[f"{name}_params"] = model.state_dict()
......@@ -184,7 +187,7 @@ class StandardUpdater(UpdaterBase):
def set_state_dict(self, state_dict):
"""Set state dict for a Updater. Parameters of models, states for
optimizers and UpdaterState are restored."""
optimizers/schedulers and UpdaterState are restored."""
for name, model in self.models.items():
model.set_state_dict(state_dict[f"{name}_params"])
for name, optim in self.optimizers.items():
......
......@@ -140,8 +140,8 @@ class Trainer():
try:
while not stop_trigger(self):
self.observation = {}
# set observation as the report target
# you can use report freely in Updater.update()
# set observation as the `report` target
# you can use `report` freely in Updater.update()
# updating parameters and state
with scope(self.observation):
......
......@@ -52,6 +52,7 @@ class UpdaterBase():
"""
def __init__(self, init_state=None):
# init state
if init_state is None:
self.state = UpdaterState()
else:
......
......@@ -114,13 +114,13 @@ class Checkpoint():
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
logger.info("Rank {}: loaded model from {}".format(rank, params_path))
logger.info("Rank {}: Restore model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
logger.info("Rank {}: loaded optimizer state from {}".format(
logger.info("Rank {}: Restore optimizer state from {}".format(
rank, optimizer_path))
info_path = re.sub('.pdparams$', '.json', params_path)
......
......@@ -12,19 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import getpass
import logging
import os
import socket
import sys
from loguru import logger
from paddle import inference
FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'
logging.basicConfig(
level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR)
def find_log_dir(log_dir=None):
"""Returns the most suitable directory to put log files into.
......@@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None):
class Log():
log_name = None
def __init__(self, logger=None):
self.logger = logging.getLogger(logger)
self.logger.setLevel(logging.DEBUG)
file_dir = os.getcwd() + '/log'
if not os.path.exists(file_dir):
os.mkdir(file_dir)
self.log_dir = file_dir
actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names(
program_name=None, log_dir=self.log_dir)
basename = '%s.DEBUG.%d' % (file_prefix, os.getpid())
filename = os.path.join(actual_log_dir, basename)
if Log.log_name is None:
Log.log_name = filename
# Create a symlink to the log file with a canonical name.
symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG')
try:
if os.path.islink(symlink):
os.unlink(symlink)
os.symlink(os.path.basename(Log.log_name), symlink)
except EnvironmentError:
# If it fails, we're sad but it's no error. Commonly, this
# fails because the symlink was created by another user and so
# we can't modify it
pass
if not self.logger.hasHandlers():
formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR)
fh = logging.FileHandler(Log.log_name)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
self.logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
self.logger.addHandler(ch)
# stop propagate for propagating may print
# log multiple times
self.logger.propagate = False
"""Default Logger for all."""
logger.remove()
logger.add(
sys.stdout,
level='INFO',
enqueue=True,
filter=lambda record: record['level'].no >= 20)
_, file_prefix, _ = find_log_dir_and_names()
sink_prefix = os.path.join("exp/log", file_prefix)
sink_path = sink_prefix[:-3] + "{time}.log"
logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB")
def __init__(self, name=None):
pass
def getlog(self):
return self.logger
return logger
class Autolog:
"""Just used by fullchain project"""
def __init__(self,
batch_size,
model_name="DeepSpeech",
......
......@@ -86,7 +86,7 @@ training:
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
kbest_n: 2
latest_n: 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册