提交 75098698 编写于 作者: H Hui Zhang

format,test=doc

上级 54341c88
...@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): ...@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False
......
...@@ -33,8 +33,6 @@ from paddlespeech.s2t.modules.decoder import TransformerDecoder ...@@ -33,8 +33,6 @@ from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder
from paddlespeech.s2t.modules.loss import LabelSmoothingLoss from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
from paddlespeech.s2t.modules.mask import mask_finished_preds
from paddlespeech.s2t.modules.mask import mask_finished_scores
from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils import layer_tools
...@@ -291,7 +289,7 @@ class U2STBaseModel(nn.Layer): ...@@ -291,7 +289,7 @@ class U2STBaseModel(nn.Layer):
device = speech.place device = speech.place
# Let's assume B = batch_size and N = beam_size # Let's assume B = batch_size and N = beam_size
# 1. Encoder and init hypothesis # 1. Encoder and init hypothesis
encoder_out, encoder_mask = self._forward_encoder( encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size, speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, num_decoding_left_chunks,
......
...@@ -36,4 +36,4 @@ def repeat(N, fn): ...@@ -36,4 +36,4 @@ def repeat(N, fn):
Returns: Returns:
MultiSequential: Repeated model instance. MultiSequential: Repeated model instance.
""" """
return MultiSequential(*[fn(n) for n in range(N)]) return MultiSequential(* [fn(n) for n in range(N)])
...@@ -11,16 +11,17 @@ ...@@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import pickle
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
import pickle
import os
from paddle import inference from paddle import inference
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
class TestDeepSpeech2ModelOnline(unittest.TestCase): class TestDeepSpeech2ModelOnline(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -185,15 +186,12 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase): ...@@ -185,15 +186,12 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
paddle.allclose(final_state_c_box, final_state_c_box_chk), True) paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
class TestDeepSpeech2StaticModelOnline(unittest.TestCase): class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
def setUp(self): def setUp(self):
export_prefix = "exp/deepspeech2_online/checkpoints/test_export" export_prefix = "exp/deepspeech2_online/checkpoints/test_export"
if not os.path.exists(os.path.dirname(export_prefix)): if not os.path.exists(os.path.dirname(export_prefix)):
os.makedirs(os.path.dirname(export_prefix), mode=0o755) os.makedirs(os.path.dirname(export_prefix), mode=0o755)
infer_model = DeepSpeech2InferModelOnline( infer_model = DeepSpeech2InferModelOnline(
feat_size=161, feat_size=161,
dict_size=4233, dict_size=4233,
num_conv_layers=2, num_conv_layers=2,
...@@ -207,27 +205,25 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase): ...@@ -207,27 +205,25 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
with open("test_data/static_ds2online_inputs.pickle", "rb") as f: with open("test_data/static_ds2online_inputs.pickle", "rb") as f:
self.data_dict = pickle.load(f) self.data_dict = pickle.load(f)
self.setup_model(export_prefix) self.setup_model(export_prefix)
def setup_model(self, export_prefix): def setup_model(self, export_prefix):
deepspeech_config = inference.Config( deepspeech_config = inference.Config(export_prefix + ".pdmodel",
export_prefix + ".pdmodel", export_prefix + ".pdiparams")
export_prefix + ".pdiparams") if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and
if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''): os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
deepspeech_config.enable_use_gpu(100, 0) deepspeech_config.enable_use_gpu(100, 0)
deepspeech_config.enable_memory_optim() deepspeech_config.enable_memory_optim()
deepspeech_predictor = inference.create_predictor(deepspeech_config) deepspeech_predictor = inference.create_predictor(deepspeech_config)
self.predictor = deepspeech_predictor self.predictor = deepspeech_predictor
def test_unit(self): def test_unit(self):
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0]) audio_handle = self.predictor.get_input_handle(input_names[0])
audio_len_handle = self.predictor.get_input_handle(input_names[1]) audio_len_handle = self.predictor.get_input_handle(input_names[1])
h_box_handle = self.predictor.get_input_handle(input_names[2]) h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3]) c_box_handle = self.predictor.get_input_handle(input_names[3])
x_chunk = self.data_dict["audio_chunk"] x_chunk = self.data_dict["audio_chunk"]
x_chunk_lens = self.data_dict["audio_chunk_lens"] x_chunk_lens = self.data_dict["audio_chunk_lens"]
...@@ -246,13 +242,9 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase): ...@@ -246,13 +242,9 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
c_box_handle.reshape(chunk_state_c_box.shape) c_box_handle.reshape(chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(chunk_state_c_box) c_box_handle.copy_from_cpu(chunk_state_c_box)
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle( output_handle = self.predictor.get_output_handle(output_names[0])
output_names[0]) output_lens_handle = self.predictor.get_output_handle(output_names[1])
output_lens_handle = self.predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.predictor.get_output_handle( output_state_h_handle = self.predictor.get_output_handle(
output_names[2]) output_names[2])
output_state_c_handle = self.predictor.get_output_handle( output_state_c_handle = self.predictor.get_output_handle(
...@@ -264,7 +256,7 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase): ...@@ -264,7 +256,7 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
chunk_state_h_box = output_state_h_handle.copy_to_cpu() chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu() chunk_state_c_box = output_state_c_handle.copy_to_cpu()
return True return True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册