diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index ecd7a8f267e536e708fc4aea5eaba695780f359d..7f4bb804832a1df7dcfabf8fb48ee39ddaed8d5a 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -64,7 +64,7 @@ def default_argument_parser(): help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") parser.add_argument("--seed", type=int, default=None, - help="seed to use for paddle, np and random. The default value is None") + help="seed to use for paddle, np and random. None or 0 for random, else set seed.") # yapd: enable return parser diff --git a/deepspeech/training/extensions/__init__.py b/deepspeech/training/extensions/__init__.py index 7ea7470eedfc4d32de9670e761438af146cc19be..6ad04155931b1071c6fe746c3befaf07bda91051 100644 --- a/deepspeech/training/extensions/__init__.py +++ b/deepspeech/training/extensions/__init__.py @@ -1,8 +1,21 @@ - +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable from .extension import Extension + def make_extension(trigger: Callable=None, default_name: str=None, priority: int=None, @@ -25,4 +38,4 @@ def make_extension(trigger: Callable=None, ext.initialize = initializer return ext - return decorator \ No newline at end of file + return decorator diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py index ffb7b3a246981cb1f1d5bb07857b652324bb442e..96ff967f53a14203d313aaf024799d95b6fd307f 100644 --- a/deepspeech/training/extensions/evaluator.py +++ b/deepspeech/training/extensions/evaluator.py @@ -1,10 +1,23 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Dict +import extension import paddle from paddle.io import DataLoader from paddle.nn import Layer -import extension from ..reporter import DictSummary from ..reporter import report from ..reporter import scope @@ -55,4 +68,4 @@ class StandardEvaluator(extension.Extension): # or otherwise, you can use your own observation summary = self.evaluate() for k, v in summary.items(): - report(k, v) \ No newline at end of file + report(k, v) diff --git a/deepspeech/training/extensions/extension.py b/deepspeech/training/extensions/extension.py index f8fcede3fe728b0eccc485e1e55e51c0d33f02be..02f924951304a5c83e4354297f12919033dc265b 100644 --- a/deepspeech/training/extensions/extension.py +++ b/deepspeech/training/extensions/extension.py @@ -1,5 +1,16 @@ -from typing import Callable - +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. PRIORITY_WRITER = 300 PRIORITY_EDITOR = 200 PRIORITY_READER = 100 @@ -38,4 +49,4 @@ class Extension(): """Action that is executed when training is done. For example, visualizers would need to be closed. """ - pass \ No newline at end of file + pass diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py index a15537a05eb27bd7ea37dff7393fa8e7e0393c61..cb4e6dfbff84d4dfed91af599b1bb040f37e9660 100644 --- a/deepspeech/training/extensions/snapshot.py +++ b/deepspeech/training/extensions/snapshot.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os from datetime import datetime from pathlib import Path @@ -7,11 +20,10 @@ from typing import List import jsonlines -from deepspeech.training.updaters.trainer import Trainer from deepspeech.training.extensions import extension -from deepspeech.utils.mp_tools import rank_zero_only - +from deepspeech.training.updaters.trainer import Trainer from deepspeech.utils.log import Log +from deepspeech.utils.mp_tools import rank_zero_only logger = Log(__name__).getlog() @@ -75,7 +87,7 @@ class Snapshot(extension.Extension): """Saving new snapshot and remove the oldest snapshot if needed.""" iteration = trainer.updater.state.iteration epoch = trainer.updater.state.epoch - num = epoch if self.trigger[1] is 'epoch' else iteration + num = epoch if self.trigger[1] == 'epoch' else iteration path = self.checkpoint_dir / f"{num}.pdz" # add the new one @@ -99,4 +111,4 @@ class Snapshot(extension.Extension): with jsonlines.open(record_path, 'w') as writer: for record in self.records: # jsonlines.open may return a Writer or a Reader - writer.write(record) # pylint: disable=no-member \ No newline at end of file + writer.write(record) # pylint: disable=no-member diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py index 92e07704a51c6aabfee18f6d092e358079f5c311..b69e94aaf4bbd42bb1bb50010af86e419d7c7ddb 100644 --- a/deepspeech/training/extensions/visualizer.py +++ b/deepspeech/training/extensions/visualizer.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from deepspeech.training.extensions import extension from deepspeech.training.updaters.trainer import Trainer @@ -21,4 +34,4 @@ class VisualDL(extension.Extension): self.writer.add_scalar(k, v, step=trainer.updater.state.iteration) def finalize(self, trainer): - self.writer.close() \ No newline at end of file + self.writer.close() diff --git a/deepspeech/training/reporter.py b/deepspeech/training/reporter.py index a5f79fb0e711d363f9158aa6382149189d03966b..66a81adef1c47f8fe55ad8d608daaa2cb97545ff 100644 --- a/deepspeech/training/reporter.py +++ b/deepspeech/training/reporter.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import contextlib import math from collections import defaultdict @@ -128,4 +141,4 @@ class DictSummary(): stats[name] = mean stats[name + '.std'] = std - return stats \ No newline at end of file + return stats diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 866be552da1537ac5a32720b1fdb3cbd35e7509b..3a922c6f4f88f03dadf20e8e978a84bcf436a58a 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import random import time from pathlib import Path -import numpy as np import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter @@ -23,6 +21,7 @@ from tensorboardX import SummaryWriter from deepspeech.utils import mp_tools from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log +from deepspeech.utils.utility import seed_all __all__ = ["Trainer"] @@ -95,13 +94,10 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 - if args.seed is not None: - self.set_seed(args.seed) - def set_seed(self, seed): - np.random.seed(seed) - random.seed(seed) - paddle.seed(seed) + if args.seed: + seed_all(args.seed) + logger.info(f"Set seed {args.seed}") def setup(self): """Setup the experiment. @@ -182,7 +178,9 @@ class Trainer(): """ self.epoch += 1 if self.parallel and hasattr(self.train_loader, "batch_sampler"): - self.train_loader.batch_sampler.set_epoch(self.epoch) + batch_sampler = self.train_loader.batch_sampler + if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): + batch_sampler.set_epoch(self.epoch) def train(self): """The training process control by epoch.""" diff --git a/deepspeech/training/triggers/__init__.py b/deepspeech/training/triggers/__init__.py index 9da7e6153a8d4cb4fd2208dd2ba02cdaac5d9dab..1a7c4292e6e1f81e4a34efb517c05f58c5d8f1fe 100644 --- a/deepspeech/training/triggers/__init__.py +++ b/deepspeech/training/triggers/__init__.py @@ -1,8 +1,23 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .interval_trigger import IntervalTrigger + def never_fail_trigger(trainer): return False + def get_trigger(trigger): if trigger is None: return never_fail_trigger @@ -10,4 +25,4 @@ def get_trigger(trigger): return trigger else: trigger = IntervalTrigger(*trigger) - return trigger \ No newline at end of file + return trigger diff --git a/deepspeech/training/triggers/interval_trigger.py b/deepspeech/training/triggers/interval_trigger.py index ef80379cb8f8aa42c41c43ea1a51b16b6b14ad0f..1e04afad8d52ba6a8a272edebfba7f09f9784723 100644 --- a/deepspeech/training/triggers/interval_trigger.py +++ b/deepspeech/training/triggers/interval_trigger.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + class IntervalTrigger(): """A Predicate to do something every N cycle.""" @@ -21,4 +35,4 @@ class IntervalTrigger(): fire = index // self.period != last_index // self.period self.last_index = index - return fire \ No newline at end of file + return fire diff --git a/deepspeech/training/triggers/limit_trigger.py b/deepspeech/training/triggers/limit_trigger.py index ce13f940a8c3510676b1ccc55a628c2c8795b39d..ecd527ac5349486fa398c00e1171d8d1b51b293b 100644 --- a/deepspeech/training/triggers/limit_trigger.py +++ b/deepspeech/training/triggers/limit_trigger.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + class LimitTrigger(): """A Predicate to decide whether to stop.""" @@ -14,4 +28,4 @@ class LimitTrigger(): state = trainer.updater.state index = getattr(state, self.unit) fire = index >= self.limit - return fire \ No newline at end of file + return fire diff --git a/deepspeech/training/triggers/time_trigger.py b/deepspeech/training/triggers/time_trigger.py index 6232a12d3bcd728ed5004ccac054619afba0fa4f..ea8fe562c7f67a6732e88fda3518fcc526596a20 100644 --- a/deepspeech/training/triggers/time_trigger.py +++ b/deepspeech/training/triggers/time_trigger.py @@ -1,3 +1,18 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + class TimeTrigger(): """Trigger based on a fixed time interval. This trigger accepts iterations with a given interval time. @@ -14,4 +29,4 @@ class TimeTrigger(): self._next_time += self._period return True else: - return False \ No newline at end of file + return False diff --git a/deepspeech/training/updaters/__init__.py b/deepspeech/training/updaters/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..185a92b8d94d3426d616c0624f0f2ee04339349e 100644 --- a/deepspeech/training/updaters/__init__.py +++ b/deepspeech/training/updaters/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py index 062029ff78eca051b0d5f596aa5ef773078cd817..fc758e93e7390694a1bbd26763db4c941ebc85dd 100644 --- a/deepspeech/training/updaters/standard_updater.py +++ b/deepspeech/training/updaters/standard_updater.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Dict from typing import Optional @@ -11,13 +24,13 @@ from timer import timer from deepspeech.training.reporter import report from deepspeech.training.updaters.updater import UpdaterBase from deepspeech.training.updaters.updater import UpdaterState - from deepspeech.utils.log import Log __all__ = ["StandardUpdater"] logger = Log(__name__).getlog() + class StandardUpdater(UpdaterBase): """An example of over-simplification. Things may not be that simple, but you can subclass it to fit your need. @@ -142,7 +155,7 @@ class StandardUpdater(UpdaterBase): """Start a new epoch.""" # NOTE: all batch sampler for distributed training should # subclass DistributedBatchSampler and implement `set_epoch` method - if hasattr(self.dataloader, "batch_sampler") + if hasattr(self.dataloader, "batch_sampler"): batch_sampler = self.dataloader.batch_sampler if isinstance(batch_sampler, DistributedBatchSampler): batch_sampler.set_epoch(self.state.epoch) @@ -176,4 +189,4 @@ class StandardUpdater(UpdaterBase): model.set_state_dict(state_dict[f"{name}_params"]) for name, optim in self.optimizers.items(): optim.set_state_dict(state_dict[f"{name}_optimizer"]) - super().set_state_dict(state_dict) \ No newline at end of file + super().set_state_dict(state_dict) diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py index c7562ff06c07123f26b49947a399229f8676119a..954ce2604d18569b34c35be4fd517f74a59fc14e 100644 --- a/deepspeech/training/updaters/trainer.py +++ b/deepspeech/training/updaters/trainer.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sys import traceback from collections import OrderedDict @@ -168,4 +181,4 @@ class Trainer(): finally: for name, entry in extensions: if hasattr(entry.extension, "finalize"): - entry.extension.finalize(self) \ No newline at end of file + entry.extension.finalize(self) diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py index 548042d6a8e0e48e8445afd514cbc290309ad8d0..66fdc2bbc7aea7b7f08f1423a58736e6ffb3b068 100644 --- a/deepspeech/training/updaters/updater.py +++ b/deepspeech/training/updaters/updater.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass + import paddle from deepspeech.utils.log import Log @@ -79,4 +80,4 @@ class UpdaterBase(): def load(self, path): logger.debug(f"Loading from {path}.") archive = paddle.load(str(path)) - self.set_state_dict(archive) \ No newline at end of file + self.set_state_dict(archive) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index a0639e0654faec299bb7353b950b122ee0103167..e18fc1f775fc1281dbcd05ecd25bc3de6d1cbef1 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -15,9 +15,19 @@ import distutils.util import math import os +import random from typing import List -__all__ = ['print_arguments', 'add_arguments', "log_add"] +import numpy as np +import paddle + +__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"] + + +def seed_all(seed: int=210329): + np.random.seed(seed) + random.seed(seed) + paddle.seed(seed) def print_arguments(args, info=None): diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index eedf92c963bbad7d06c12c5fd661ca4b18abd9c5..537496a6781840a6091931f6c5c7e30f1b109a83 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -3,11 +3,9 @@ ## Data | Data Subset | Duration in Seconds | | data/manifest.train | 1.23 ~ 14.53125 | -| data/manifest.dev | 1.645 ~ 12.533 | +| data/manifest.dev | 1.645 ~ 12.533 | | data/manifest.test | 1.859125 ~ 14.6999375 | -`jq '.feat_shape[0]' data/manifest.train | sort -un` - ## Deepspeech2 | Model | Params | Release | Config | Test set | Loss | CER | diff --git a/requirements.txt b/requirements.txt index 1ed5525ea6f5e2e964ba3cb81234b46adf46095b..7c3da37e131064d0327e728d115103d7e34da192 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ coverage gpustat +jsonlines kaldiio Pillow pre-commit @@ -15,4 +16,3 @@ tensorboardX textgrid typeguard yacs -jsonlines \ No newline at end of file