提交 673cc4a0 编写于 作者: H Hui Zhang

seed all with log; and format

上级 de98283b
......@@ -64,7 +64,7 @@ def default_argument_parser():
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--seed", type=int, default=None,
help="seed to use for paddle, np and random. The default value is None")
help="seed to use for paddle, np and random. None or 0 for random, else set seed.")
# yapd: enable
return parser
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from .extension import Extension
def make_extension(trigger: Callable=None,
default_name: str=None,
priority: int=None,
......@@ -25,4 +38,4 @@ def make_extension(trigger: Callable=None,
ext.initialize = initializer
return ext
return decorator
\ No newline at end of file
return decorator
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import extension
import paddle
from paddle.io import DataLoader
from paddle.nn import Layer
import extension
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope
......@@ -55,4 +68,4 @@ class StandardEvaluator(extension.Extension):
# or otherwise, you can use your own observation
summary = self.evaluate()
for k, v in summary.items():
report(k, v)
\ No newline at end of file
report(k, v)
from typing import Callable
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100
......@@ -38,4 +49,4 @@ class Extension():
"""Action that is executed when training is done.
For example, visualizers would need to be closed.
"""
pass
\ No newline at end of file
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import datetime
from pathlib import Path
......@@ -7,11 +20,10 @@ from typing import List
import jsonlines
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.training.extensions import extension
from deepspeech.utils.mp_tools import rank_zero_only
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only
logger = Log(__name__).getlog()
......@@ -75,7 +87,7 @@ class Snapshot(extension.Extension):
"""Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] is 'epoch' else iteration
num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.pdz"
# add the new one
......@@ -99,4 +111,4 @@ class Snapshot(extension.Extension):
with jsonlines.open(record_path, 'w') as writer:
for record in self.records:
# jsonlines.open may return a Writer or a Reader
writer.write(record) # pylint: disable=no-member
\ No newline at end of file
writer.write(record) # pylint: disable=no-member
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer
......@@ -21,4 +34,4 @@ class VisualDL(extension.Extension):
self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
def finalize(self, trainer):
self.writer.close()
\ No newline at end of file
self.writer.close()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
from collections import defaultdict
......@@ -128,4 +141,4 @@ class DictSummary():
stats[name] = mean
stats[name + '.std'] = std
return stats
\ No newline at end of file
return stats
......@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import time
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
......@@ -23,6 +21,7 @@ from tensorboardX import SummaryWriter
from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from deepspeech.utils.utility import seed_all
__all__ = ["Trainer"]
......@@ -95,13 +94,10 @@ class Trainer():
self.checkpoint_dir = None
self.iteration = 0
self.epoch = 0
if args.seed is not None:
self.set_seed(args.seed)
def set_seed(self, seed):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
if args.seed:
seed_all(args.seed)
logger.info(f"Set seed {args.seed}")
def setup(self):
"""Setup the experiment.
......@@ -182,7 +178,9 @@ class Trainer():
"""
self.epoch += 1
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch)
batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def train(self):
"""The training process control by epoch."""
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
return False
def get_trigger(trigger):
if trigger is None:
return never_fail_trigger
......@@ -10,4 +25,4 @@ def get_trigger(trigger):
return trigger
else:
trigger = IntervalTrigger(*trigger)
return trigger
\ No newline at end of file
return trigger
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger():
"""A Predicate to do something every N cycle."""
......@@ -21,4 +35,4 @@ class IntervalTrigger():
fire = index // self.period != last_index // self.period
self.last_index = index
return fire
\ No newline at end of file
return fire
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LimitTrigger():
"""A Predicate to decide whether to stop."""
......@@ -14,4 +28,4 @@ class LimitTrigger():
state = trainer.updater.state
index = getattr(state, self.unit)
fire = index >= self.limit
return fire
\ No newline at end of file
return fire
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TimeTrigger():
"""Trigger based on a fixed time interval.
This trigger accepts iterations with a given interval time.
......@@ -14,4 +29,4 @@ class TimeTrigger():
self._next_time += self._period
return True
else:
return False
\ No newline at end of file
return False
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Optional
......@@ -11,13 +24,13 @@ from timer import timer
from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState
from deepspeech.utils.log import Log
__all__ = ["StandardUpdater"]
logger = Log(__name__).getlog()
class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need.
......@@ -142,7 +155,7 @@ class StandardUpdater(UpdaterBase):
"""Start a new epoch."""
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
if hasattr(self.dataloader, "batch_sampler")
if hasattr(self.dataloader, "batch_sampler"):
batch_sampler = self.dataloader.batch_sampler
if isinstance(batch_sampler, DistributedBatchSampler):
batch_sampler.set_epoch(self.state.epoch)
......@@ -176,4 +189,4 @@ class StandardUpdater(UpdaterBase):
model.set_state_dict(state_dict[f"{name}_params"])
for name, optim in self.optimizers.items():
optim.set_state_dict(state_dict[f"{name}_optimizer"])
super().set_state_dict(state_dict)
\ No newline at end of file
super().set_state_dict(state_dict)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from collections import OrderedDict
......@@ -168,4 +181,4 @@ class Trainer():
finally:
for name, entry in extensions:
if hasattr(entry.extension, "finalize"):
entry.extension.finalize(self)
\ No newline at end of file
entry.extension.finalize(self)
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import paddle
from deepspeech.utils.log import Log
......@@ -79,4 +80,4 @@ class UpdaterBase():
def load(self, path):
logger.debug(f"Loading from {path}.")
archive = paddle.load(str(path))
self.set_state_dict(archive)
\ No newline at end of file
self.set_state_dict(archive)
......@@ -15,9 +15,19 @@
import distutils.util
import math
import os
import random
from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"]
import numpy as np
import paddle
__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"]
def seed_all(seed: int=210329):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def print_arguments(args, info=None):
......
......@@ -3,11 +3,9 @@
## Data
| Data Subset | Duration in Seconds |
| data/manifest.train | 1.23 ~ 14.53125 |
| data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.test | 1.859125 ~ 14.6999375 |
`jq '.feat_shape[0]' data/manifest.train | sort -un`
## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | CER |
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册