diff --git a/paddlespeech/s2t/frontend/utility.py b/paddlespeech/s2t/frontend/utility.py index 703f2127d7e71b093030658f6d88c45979770c61..d423a60447b48b6be218afd06a13e6ce355a31ea 100644 --- a/paddlespeech/s2t/frontend/utility.py +++ b/paddlespeech/s2t/frontend/utility.py @@ -102,10 +102,10 @@ def read_manifest( manifest = [] with jsonlines.open(manifest_path, 'r') as reader: for json_data in reader: - feat_len = json_data["feat_shape"][ - 0] if 'feat_shape' in json_data else 1.0 - token_len = json_data["token_shape"][ - 0] if 'token_shape' in json_data else 1.0 + feat_len = json_data["input"][0]["shape"][ + 0] if 'shape' in json_data["input"][0] else 1.0 + token_len = json_data["output"][0]["shape"][ + 0] if 'shape' in json_data["output"][0] else 1.0 conditions = [ feat_len >= min_input_len, feat_len <= max_input_len, diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py index 35b57524b5906d53366ebc1c8d4b36322129bba2..0d5a16ce10a25b8234ad3bc6244b2828f350f092 100644 --- a/paddlespeech/s2t/io/sampler.py +++ b/paddlespeech/s2t/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/paddlespeech/s2t/utils/checkpoint.py b/paddlespeech/s2t/utils/checkpoint.py index 5105f95efaebf4347a0d13b4383a1db962835308..4c493715a4d25daeb2b3736c3f306443fb73503a 100644 --- a/paddlespeech/s2t/utils/checkpoint.py +++ b/paddlespeech/s2t/utils/checkpoint.py @@ -94,6 +94,9 @@ class Checkpoint(): """ configs = {} + if len(checkpoint_path) == 0 or checkpoint_path == "None": + checkpoint_path = None + if checkpoint_path is not None: pass elif checkpoint_dir is not None and record_file is not None: