diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 000fa87bab7c3cd412355ef2d9f7f1c6c6e26723..8c5d8d6052e1cc2cdf952fea086c2a7c2dca6bdb 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -24,7 +24,6 @@ from paddle.optimizer import Optimizer from deepspeech.utils import mp_tools from deepspeech.utils.log import Log -# import operator logger = Log(__name__).getlog() @@ -38,7 +37,7 @@ class Checkpoint(object): self.kbest_n = kbest_n self.latest_n = latest_n self._save_all = (kbest_n == -1) - + def add_checkpoint(self, checkpoint_dir, tag_or_iteration, @@ -64,10 +63,10 @@ class Checkpoint(object): self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) def load_latest_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -80,14 +79,14 @@ class Checkpoint(object): Returns: configs (dict): epoch or step, lr and other meta info should be saved. """ - return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, - "checkpoint_latest") + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_latest") def load_best_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. @@ -100,8 +99,8 @@ class Checkpoint(object): Returns: configs (dict): epoch or step, lr and other meta info should be saved. """ - return self._load_parameters(model, optimizer, checkpoint_dir, checkpoint_path, - "checkpoint_best") + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") def _should_save_best(self, metric: float) -> bool: if not self._best_full(): @@ -248,7 +247,6 @@ class Checkpoint(object): configs = json.load(fin) return configs - @mp_tools.rank_zero_only def _save_parameters(self, checkpoint_dir: str, diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 54ce240e7c4eb1e53e81477994ab1ae27f0c1db3..27ede01bc60ee8634152ebc85bfccbd07cc00bbf 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -48,6 +48,9 @@ training: weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 904624c3ccd0915685d41ccece38d6478f18b5f1..1065dcb036292c1f55b8457d4df2d335fb8291a5 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -90,6 +90,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 116c919279134bf7ca7f3aa9c50171ca1488be82..4b1430c58848a5cac7303518021a6256b52d525d 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -88,6 +88,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index d1746bff398ad224f28f082ea2dc227ff7624eb7..9f06a3802f6dc4ae833f48eb5e070b1c1b3f2dc5 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -43,6 +43,9 @@ training: weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/librispeech/s1/conf/chunk_confermer.yaml b/examples/librispeech/s1/conf/chunk_confermer.yaml index ec945a188bd2f66c12b34dc8499612b38b0912c5..97912163943c1c047da4daa90080f6ab96b21d8a 100644 --- a/examples/librispeech/s1/conf/chunk_confermer.yaml +++ b/examples/librispeech/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index 3939ffc688e1de5dc66606328e48e2d69459b0b6..dc2a51f9277407ece63319af0ee59cc767cfd460 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 8f8bf45398813179db88781dcfc5c71356295934..989af22a0ec77a28f56ca716ec5fcf98cd916e7e 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index a094b0fba6088ced2252fc71963ed3afb9ca5c0f..931d7524bbe1016ab69b4e38daf6e8f3949a1616 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -82,6 +82,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 79006626408823732ba74838ebece5927b6a88f0..606300bdf30c6749fdf1b8700365f80e7a3fb008 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index aa2b145a681dff821d4695f96be8aef35d674a5e..72d368485c6fab0b0cd20e7f91d9085830d1890e 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 3813daa04a516c143d7a545cd28999518fecf2d8..a6f73050144594ccd3ebea33cef53f36b0ba7672 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 250995faadc8b4e668ed717d70b9ebadcdc67b60..71cbdde7f930baa27e88d3ff86ed72eac9d06182 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: