提交 75098698 编写于 作者: H Hui Zhang

format,test=doc

上级 54341c88
......@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
......
......@@ -33,8 +33,6 @@ from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
from paddlespeech.s2t.modules.mask import mask_finished_preds
from paddlespeech.s2t.modules.mask import mask_finished_scores
from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
......@@ -291,7 +289,7 @@ class U2STBaseModel(nn.Layer):
device = speech.place
# Let's assume B = batch_size and N = beam_size
# 1. Encoder and init hypothesis
# 1. Encoder and init hypothesis
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
......
......@@ -36,4 +36,4 @@ def repeat(N, fn):
Returns:
MultiSequential: Repeated model instance.
"""
return MultiSequential(*[fn(n) for n in range(N)])
return MultiSequential(* [fn(n) for n in range(N)])
......@@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import unittest
import numpy as np
import paddle
import pickle
import os
from paddle import inference
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
class TestDeepSpeech2ModelOnline(unittest.TestCase):
def setUp(self):
......@@ -185,15 +186,12 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
def setUp(self):
export_prefix = "exp/deepspeech2_online/checkpoints/test_export"
if not os.path.exists(os.path.dirname(export_prefix)):
os.makedirs(os.path.dirname(export_prefix), mode=0o755)
infer_model = DeepSpeech2InferModelOnline(
infer_model = DeepSpeech2InferModelOnline(
feat_size=161,
dict_size=4233,
num_conv_layers=2,
......@@ -207,27 +205,25 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
with open("test_data/static_ds2online_inputs.pickle", "rb") as f:
self.data_dict = pickle.load(f)
self.setup_model(export_prefix)
def setup_model(self, export_prefix):
deepspeech_config = inference.Config(
export_prefix + ".pdmodel",
export_prefix + ".pdiparams")
if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
deepspeech_config = inference.Config(export_prefix + ".pdmodel",
export_prefix + ".pdiparams")
if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and
os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
deepspeech_config.enable_use_gpu(100, 0)
deepspeech_config.enable_memory_optim()
deepspeech_predictor = inference.create_predictor(deepspeech_config)
self.predictor = deepspeech_predictor
def test_unit(self):
input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0])
audio_len_handle = self.predictor.get_input_handle(input_names[1])
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])
x_chunk = self.data_dict["audio_chunk"]
x_chunk_lens = self.data_dict["audio_chunk_lens"]
......@@ -246,13 +242,9 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
c_box_handle.reshape(chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(chunk_state_c_box)
output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle(
output_names[0])
output_lens_handle = self.predictor.get_output_handle(
output_names[1])
output_handle = self.predictor.get_output_handle(output_names[0])
output_lens_handle = self.predictor.get_output_handle(output_names[1])
output_state_h_handle = self.predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.predictor.get_output_handle(
......@@ -264,7 +256,7 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu()
return True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册