未验证 提交 df300fff 编写于 作者: N Nicky Chan 提交者: GitHub

Merge pull request #11056 from nickyfantasy/refract_machine_translation_test

Simplify and make clear function names on Machine Translation example
...@@ -53,7 +53,7 @@ def encoder(is_sparse): ...@@ -53,7 +53,7 @@ def encoder(is_sparse):
return encoder_out return encoder_out
def decoder_train(context, is_sparse): def train_decoder(context, is_sparse):
# decoder # decoder
trg_language_word = pd.data( trg_language_word = pd.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1) name="target_language_word", shape=[1], dtype='int64', lod_level=1)
...@@ -81,7 +81,7 @@ def decoder_train(context, is_sparse): ...@@ -81,7 +81,7 @@ def decoder_train(context, is_sparse):
return rnn() return rnn()
def decoder_decode(context, is_sparse): def decode(context, is_sparse):
init_state = context init_state = context
array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length) array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True) counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True)
...@@ -150,7 +150,7 @@ def decoder_decode(context, is_sparse): ...@@ -150,7 +150,7 @@ def decoder_decode(context, is_sparse):
def train_program(is_sparse): def train_program(is_sparse):
context = encoder(is_sparse) context = encoder(is_sparse)
rnn_out = decoder_train(context, is_sparse) rnn_out = train_decoder(context, is_sparse)
label = pd.data( label = pd.data(
name="target_language_next_word", shape=[1], dtype='int64', lod_level=1) name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
cost = pd.cross_entropy(input=rnn_out, label=label) cost = pd.cross_entropy(input=rnn_out, label=label)
...@@ -201,7 +201,7 @@ def decode_main(use_cuda, is_sparse): ...@@ -201,7 +201,7 @@ def decode_main(use_cuda, is_sparse):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
context = encoder(is_sparse) context = encoder(is_sparse)
translation_ids, translation_scores = decoder_decode(context, is_sparse) translation_ids, translation_scores = decode(context, is_sparse)
exe = Executor(place) exe = Executor(place)
exe.run(framework.default_startup_program()) exe.run(framework.default_startup_program())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册