From d2fab3238b7082ee5a5df33d6725514cf4cceb05 Mon Sep 17 00:00:00 2001 From: Junkun Date: Mon, 29 Nov 2021 16:57:36 -0800 Subject: [PATCH] fix bugs --- paddlespeech/s2t/frontend/utility.py | 8 ++++---- paddlespeech/s2t/io/sampler.py | 2 +- paddlespeech/s2t/utils/checkpoint.py | 3 +++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddlespeech/s2t/frontend/utility.py b/paddlespeech/s2t/frontend/utility.py index 703f2127..d423a604 100644 --- a/paddlespeech/s2t/frontend/utility.py +++ b/paddlespeech/s2t/frontend/utility.py @@ -102,10 +102,10 @@ def read_manifest( manifest = [] with jsonlines.open(manifest_path, 'r') as reader: for json_data in reader: - feat_len = json_data["feat_shape"][ - 0] if 'feat_shape' in json_data else 1.0 - token_len = json_data["token_shape"][ - 0] if 'token_shape' in json_data else 1.0 + feat_len = json_data["input"][0]["shape"][ + 0] if 'shape' in json_data["input"][0] else 1.0 + token_len = json_data["output"][0]["shape"][ + 0] if 'shape' in json_data["output"][0] else 1.0 conditions = [ feat_len >= min_input_len, feat_len <= max_input_len, diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py index 35b57524..0d5a16ce 100644 --- a/paddlespeech/s2t/io/sampler.py +++ b/paddlespeech/s2t/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/paddlespeech/s2t/utils/checkpoint.py b/paddlespeech/s2t/utils/checkpoint.py index 5105f95e..4c493715 100644 --- a/paddlespeech/s2t/utils/checkpoint.py +++ b/paddlespeech/s2t/utils/checkpoint.py @@ -94,6 +94,9 @@ class Checkpoint(): """ configs = {} + if len(checkpoint_path) == 0 or checkpoint_path == "None": + checkpoint_path = None + if checkpoint_path is not None: pass elif checkpoint_dir is not None and record_file is not None: -- GitLab