From 6f3c7d9b241ae91d39fd9114ea19c2d13ec4b544 Mon Sep 17 00:00:00 2001 From: Nicky Date: Wed, 30 May 2018 14:39:55 -0700 Subject: [PATCH] Simplify and make a clear function name --- .../machine_translation/test_machine_translation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py index 1f85221a9d..d4b723d3e6 100644 --- a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py @@ -53,7 +53,7 @@ def encoder(is_sparse): return encoder_out -def decoder_train(context, is_sparse): +def train_decoder(context, is_sparse): # decoder trg_language_word = pd.data( name="target_language_word", shape=[1], dtype='int64', lod_level=1) @@ -81,7 +81,7 @@ def decoder_train(context, is_sparse): return rnn() -def decoder_decode(context, is_sparse): +def decode(context, is_sparse): init_state = context array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length) counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True) @@ -150,7 +150,7 @@ def decoder_decode(context, is_sparse): def train_program(is_sparse): context = encoder(is_sparse) - rnn_out = decoder_train(context, is_sparse) + rnn_out = train_decoder(context, is_sparse) label = pd.data( name="target_language_next_word", shape=[1], dtype='int64', lod_level=1) cost = pd.cross_entropy(input=rnn_out, label=label) @@ -201,7 +201,7 @@ def decode_main(use_cuda, is_sparse): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() context = encoder(is_sparse) - translation_ids, translation_scores = decoder_decode(context, is_sparse) + translation_ids, translation_scores = decode(context, is_sparse) exe = Executor(place) exe.run(framework.default_startup_program()) -- GitLab