提交 673cc4a0 编写于 作者: H Hui Zhang

seed all with log; and format

上级 de98283b
...@@ -64,7 +64,7 @@ def default_argument_parser(): ...@@ -64,7 +64,7 @@ def default_argument_parser():
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--seed", type=int, default=None, parser.add_argument("--seed", type=int, default=None,
help="seed to use for paddle, np and random. The default value is None") help="seed to use for paddle, np and random. None or 0 for random, else set seed.")
# yapd: enable # yapd: enable
return parser return parser
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable from typing import Callable
from .extension import Extension from .extension import Extension
def make_extension(trigger: Callable=None, def make_extension(trigger: Callable=None,
default_name: str=None, default_name: str=None,
priority: int=None, priority: int=None,
...@@ -25,4 +38,4 @@ def make_extension(trigger: Callable=None, ...@@ -25,4 +38,4 @@ def make_extension(trigger: Callable=None,
ext.initialize = initializer ext.initialize = initializer
return ext return ext
return decorator return decorator
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict from typing import Dict
import extension
import paddle import paddle
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.nn import Layer from paddle.nn import Layer
import extension
from ..reporter import DictSummary from ..reporter import DictSummary
from ..reporter import report from ..reporter import report
from ..reporter import scope from ..reporter import scope
...@@ -55,4 +68,4 @@ class StandardEvaluator(extension.Extension): ...@@ -55,4 +68,4 @@ class StandardEvaluator(extension.Extension):
# or otherwise, you can use your own observation # or otherwise, you can use your own observation
summary = self.evaluate() summary = self.evaluate()
for k, v in summary.items(): for k, v in summary.items():
report(k, v) report(k, v)
\ No newline at end of file
from typing import Callable # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_WRITER = 300 PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200 PRIORITY_EDITOR = 200
PRIORITY_READER = 100 PRIORITY_READER = 100
...@@ -38,4 +49,4 @@ class Extension(): ...@@ -38,4 +49,4 @@ class Extension():
"""Action that is executed when training is done. """Action that is executed when training is done.
For example, visualizers would need to be closed. For example, visualizers would need to be closed.
""" """
pass pass
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
...@@ -7,11 +20,10 @@ from typing import List ...@@ -7,11 +20,10 @@ from typing import List
import jsonlines import jsonlines
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.training.extensions import extension from deepspeech.training.extensions import extension
from deepspeech.utils.mp_tools import rank_zero_only from deepspeech.training.updaters.trainer import Trainer
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -75,7 +87,7 @@ class Snapshot(extension.Extension): ...@@ -75,7 +87,7 @@ class Snapshot(extension.Extension):
"""Saving new snapshot and remove the oldest snapshot if needed.""" """Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] is 'epoch' else iteration num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.pdz" path = self.checkpoint_dir / f"{num}.pdz"
# add the new one # add the new one
...@@ -99,4 +111,4 @@ class Snapshot(extension.Extension): ...@@ -99,4 +111,4 @@ class Snapshot(extension.Extension):
with jsonlines.open(record_path, 'w') as writer: with jsonlines.open(record_path, 'w') as writer:
for record in self.records: for record in self.records:
# jsonlines.open may return a Writer or a Reader # jsonlines.open may return a Writer or a Reader
writer.write(record) # pylint: disable=no-member writer.write(record) # pylint: disable=no-member
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.extensions import extension from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer from deepspeech.training.updaters.trainer import Trainer
...@@ -21,4 +34,4 @@ class VisualDL(extension.Extension): ...@@ -21,4 +34,4 @@ class VisualDL(extension.Extension):
self.writer.add_scalar(k, v, step=trainer.updater.state.iteration) self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
def finalize(self, trainer): def finalize(self, trainer):
self.writer.close() self.writer.close()
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib import contextlib
import math import math
from collections import defaultdict from collections import defaultdict
...@@ -128,4 +141,4 @@ class DictSummary(): ...@@ -128,4 +141,4 @@ class DictSummary():
stats[name] = mean stats[name] = mean
stats[name + '.std'] = std stats[name + '.std'] = std
return stats return stats
\ No newline at end of file
...@@ -11,11 +11,9 @@ ...@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
import time import time
from pathlib import Path from pathlib import Path
import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
...@@ -23,6 +21,7 @@ from tensorboardX import SummaryWriter ...@@ -23,6 +21,7 @@ from tensorboardX import SummaryWriter
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import seed_all
__all__ = ["Trainer"] __all__ = ["Trainer"]
...@@ -95,13 +94,10 @@ class Trainer(): ...@@ -95,13 +94,10 @@ class Trainer():
self.checkpoint_dir = None self.checkpoint_dir = None
self.iteration = 0 self.iteration = 0
self.epoch = 0 self.epoch = 0
if args.seed is not None:
self.set_seed(args.seed)
def set_seed(self, seed): if args.seed:
np.random.seed(seed) seed_all(args.seed)
random.seed(seed) logger.info(f"Set seed {args.seed}")
paddle.seed(seed)
def setup(self): def setup(self):
"""Setup the experiment. """Setup the experiment.
...@@ -182,7 +178,9 @@ class Trainer(): ...@@ -182,7 +178,9 @@ class Trainer():
""" """
self.epoch += 1 self.epoch += 1
if self.parallel and hasattr(self.train_loader, "batch_sampler"): if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch) batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def train(self): def train(self):
"""The training process control by epoch.""" """The training process control by epoch."""
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer): def never_fail_trigger(trainer):
return False return False
def get_trigger(trigger): def get_trigger(trigger):
if trigger is None: if trigger is None:
return never_fail_trigger return never_fail_trigger
...@@ -10,4 +25,4 @@ def get_trigger(trigger): ...@@ -10,4 +25,4 @@ def get_trigger(trigger):
return trigger return trigger
else: else:
trigger = IntervalTrigger(*trigger) trigger = IntervalTrigger(*trigger)
return trigger return trigger
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger(): class IntervalTrigger():
"""A Predicate to do something every N cycle.""" """A Predicate to do something every N cycle."""
...@@ -21,4 +35,4 @@ class IntervalTrigger(): ...@@ -21,4 +35,4 @@ class IntervalTrigger():
fire = index // self.period != last_index // self.period fire = index // self.period != last_index // self.period
self.last_index = index self.last_index = index
return fire return fire
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LimitTrigger(): class LimitTrigger():
"""A Predicate to decide whether to stop.""" """A Predicate to decide whether to stop."""
...@@ -14,4 +28,4 @@ class LimitTrigger(): ...@@ -14,4 +28,4 @@ class LimitTrigger():
state = trainer.updater.state state = trainer.updater.state
index = getattr(state, self.unit) index = getattr(state, self.unit)
fire = index >= self.limit fire = index >= self.limit
return fire return fire
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TimeTrigger(): class TimeTrigger():
"""Trigger based on a fixed time interval. """Trigger based on a fixed time interval.
This trigger accepts iterations with a given interval time. This trigger accepts iterations with a given interval time.
...@@ -14,4 +29,4 @@ class TimeTrigger(): ...@@ -14,4 +29,4 @@ class TimeTrigger():
self._next_time += self._period self._next_time += self._period
return True return True
else: else:
return False return False
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
...@@ -11,13 +24,13 @@ from timer import timer ...@@ -11,13 +24,13 @@ from timer import timer
from deepspeech.training.reporter import report from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState from deepspeech.training.updaters.updater import UpdaterState
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["StandardUpdater"] __all__ = ["StandardUpdater"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class StandardUpdater(UpdaterBase): class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but """An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need. you can subclass it to fit your need.
...@@ -142,7 +155,7 @@ class StandardUpdater(UpdaterBase): ...@@ -142,7 +155,7 @@ class StandardUpdater(UpdaterBase):
"""Start a new epoch.""" """Start a new epoch."""
# NOTE: all batch sampler for distributed training should # NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method # subclass DistributedBatchSampler and implement `set_epoch` method
if hasattr(self.dataloader, "batch_sampler") if hasattr(self.dataloader, "batch_sampler"):
batch_sampler = self.dataloader.batch_sampler batch_sampler = self.dataloader.batch_sampler
if isinstance(batch_sampler, DistributedBatchSampler): if isinstance(batch_sampler, DistributedBatchSampler):
batch_sampler.set_epoch(self.state.epoch) batch_sampler.set_epoch(self.state.epoch)
...@@ -176,4 +189,4 @@ class StandardUpdater(UpdaterBase): ...@@ -176,4 +189,4 @@ class StandardUpdater(UpdaterBase):
model.set_state_dict(state_dict[f"{name}_params"]) model.set_state_dict(state_dict[f"{name}_params"])
for name, optim in self.optimizers.items(): for name, optim in self.optimizers.items():
optim.set_state_dict(state_dict[f"{name}_optimizer"]) optim.set_state_dict(state_dict[f"{name}_optimizer"])
super().set_state_dict(state_dict) super().set_state_dict(state_dict)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys import sys
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
...@@ -168,4 +181,4 @@ class Trainer(): ...@@ -168,4 +181,4 @@ class Trainer():
finally: finally:
for name, entry in extensions: for name, entry in extensions:
if hasattr(entry.extension, "finalize"): if hasattr(entry.extension, "finalize"):
entry.extension.finalize(self) entry.extension.finalize(self)
\ No newline at end of file
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
import paddle import paddle
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
...@@ -79,4 +80,4 @@ class UpdaterBase(): ...@@ -79,4 +80,4 @@ class UpdaterBase():
def load(self, path): def load(self, path):
logger.debug(f"Loading from {path}.") logger.debug(f"Loading from {path}.")
archive = paddle.load(str(path)) archive = paddle.load(str(path))
self.set_state_dict(archive) self.set_state_dict(archive)
\ No newline at end of file
...@@ -15,9 +15,19 @@ ...@@ -15,9 +15,19 @@
import distutils.util import distutils.util
import math import math
import os import os
import random
from typing import List from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"] import numpy as np
import paddle
__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"]
def seed_all(seed: int=210329):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def print_arguments(args, info=None): def print_arguments(args, info=None):
......
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
## Data ## Data
| Data Subset | Duration in Seconds | | Data Subset | Duration in Seconds |
| data/manifest.train | 1.23 ~ 14.53125 | | data/manifest.train | 1.23 ~ 14.53125 |
| data/manifest.dev | 1.645 ~ 12.533 | | data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.test | 1.859125 ~ 14.6999375 | | data/manifest.test | 1.859125 ~ 14.6999375 |
`jq '.feat_shape[0]' data/manifest.train | sort -un`
## Deepspeech2 ## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | CER | | Model | Params | Release | Config | Test set | Loss | CER |
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册