diff --git a/examples/deepvoice3/data.py b/examples/deepvoice3/data.py index 0d0aaeb7f4cf9731df98e9540dd9512b93671f1a..46381744d9cf1899bd93901b3226dc859ec54ee8 100644 --- a/examples/deepvoice3/data.py +++ b/examples/deepvoice3/data.py @@ -189,11 +189,14 @@ class DataCollector(object): # text positions text_mask = (np.arange(1, 1 + max_text_length) <= np.expand_dims( text_lengths, -1)).astype(np.int64) - text_positions = np.arange(1, 1 + max_text_length) * text_mask + text_positions = np.arange( + 1, 1 + max_text_length, dtype=np.int64) * text_mask # decoder_positions decoder_positions = np.tile( - np.expand_dims(np.arange(1, 1 + max_decoder_length), 0), + np.expand_dims( + np.arange( + 1, 1 + max_decoder_length, dtype=np.int64), 0), (batch_size, 1)) return (text_sequences, text_lengths, text_positions, mel_specs,