diff --git a/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py b/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py index fdc60861760163d2ebad3b050e551929321baafd..593d0013c9d16c0f10fbb2317ff2b0bab709609f 100644 --- a/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py +++ b/python/paddle/v2/fluid/tests/book/test_rnn_encoder_decoder.py @@ -145,7 +145,7 @@ def seq_to_seq_net(): cost = fluid.layers.cross_entropy(input=prediction, label=label) avg_cost = fluid.layers.mean(x=cost) - return avg_cost + return avg_cost, prediction def to_lodtensor(data, place): @@ -163,8 +163,8 @@ def to_lodtensor(data, place): return res -def main(): - avg_cost = seq_to_seq_net() +def train(save_dirname=None): + [avg_cost, prediction] = seq_to_seq_net() optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4) optimizer.minimize(avg_cost) @@ -196,9 +196,52 @@ def main(): print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) + " avg_cost=" + str(avg_cost_val)) if batch_id > 3: + if save_dirname is not None: + fluid.io.save_inference_model(save_dirname, [ + 'source_sequence', 'target_sequence', 'label_sequence' + ], [prediction], exe) exit(0) batch_id += 1 +def inference(save_dirname=None): + if save_dirname is None: + return + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + # Use fluid.io.load_inference_model to obtain the inference program desc, + # the feed_target_names (the names of variables that will be feeded + # data using feed operators), and the fetch_targets (variables that + # we want to obtain data from using fetch operators). + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) + + data = [[0, 1, 0, 1], [0, 1, 1, 0, 0, 1]] + word_data = to_lodtensor(data, place) + trg_word = to_lodtensor(data, place) + trg_word_next = to_lodtensor(data, place) + + # Construct feed as a dictionary of {feed_target_name: feed_target_data} + # and results will contain a list of data corresponding to fetch_targets. + print(feed_target_names) + assert feed_target_names[0] == 'source_sequence' + assert feed_target_names[1] == 'target_sequence' + assert feed_target_names[2] == 'label_sequence' + results = exe.run(inference_program, + feed={ + feed_target_names[0]: word_data, + feed_target_names[1]: trg_word, + feed_target_names[2]: trg_word_next + }, + fetch_list=fetch_targets) + + print("Inference Shape: ", results[0].shape) + print("infer results: ", results[0]) + + if __name__ == '__main__': - main() + save_dirname = "rnn_encoder_decoder.inference.model" + train(save_dirname) + infer(save_dirname)