提交 dc8390d8 编写于 作者: K Kexin Zhao

initial commit

上级 0f8dd956
......@@ -145,7 +145,7 @@ def seq_to_seq_net():
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
return avg_cost, prediction
def to_lodtensor(data, place):
......@@ -163,8 +163,8 @@ def to_lodtensor(data, place):
return res
def main():
avg_cost = seq_to_seq_net()
def train(save_dirname=None):
[avg_cost, prediction] = seq_to_seq_net()
optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
optimizer.minimize(avg_cost)
......@@ -196,9 +196,52 @@ def main():
print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
" avg_cost=" + str(avg_cost_val))
if batch_id > 3:
if save_dirname is not None:
fluid.io.save_inference_model(save_dirname, [
'source_sequence', 'target_sequence', 'label_sequence'
], [prediction], exe)
exit(0)
batch_id += 1
def inference(save_dirname=None):
if save_dirname is None:
return
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
data = [[0, 1, 0, 1], [0, 1, 1, 0, 0, 1]]
word_data = to_lodtensor(data, place)
trg_word = to_lodtensor(data, place)
trg_word_next = to_lodtensor(data, place)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
print(feed_target_names)
assert feed_target_names[0] == 'source_sequence'
assert feed_target_names[1] == 'target_sequence'
assert feed_target_names[2] == 'label_sequence'
results = exe.run(inference_program,
feed={
feed_target_names[0]: word_data,
feed_target_names[1]: trg_word,
feed_target_names[2]: trg_word_next
},
fetch_list=fetch_targets)
print("Inference Shape: ", results[0].shape)
print("infer results: ", results[0])
if __name__ == '__main__':
main()
save_dirname = "rnn_encoder_decoder.inference.model"
train(save_dirname)
infer(save_dirname)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册