diff --git a/doc/getstarted/concepts/src/infer.py b/doc/getstarted/concepts/src/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc58dfee0bd6dade0340b4fd0ee1adb49ffebf6 --- /dev/null +++ b/doc/getstarted/concepts/src/infer.py @@ -0,0 +1,18 @@ +import paddle.v2 as paddle +import numpy as np + +paddle.init(use_gpu=False) +x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(2)) +y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear()) + +# loading the model which generated by training +with open('params_pass_90.tar', 'r') as f: + parameters = paddle.parameters.Parameters.from_tar(f) + +# Input multiple sets of data,Output the infer result in a array. +i = [[[1, 2]], [[3, 4]], [[5, 6]]] +print paddle.infer(output_layer=y_predict, parameters=parameters, input=i) +# Will print: +# [[ -3.24491572] +# [ -6.94668722] +# [-10.64845848]] diff --git a/doc/getstarted/concepts/src/train.py b/doc/getstarted/concepts/src/train.py index 8aceb23406a476f08639cc6223cdf730b728a705..4bccbfca3c70c12aec564e2cae3b8ca174b68777 100644 --- a/doc/getstarted/concepts/src/train.py +++ b/doc/getstarted/concepts/src/train.py @@ -26,6 +26,11 @@ def event_handler(event): if event.batch_id % 1 == 0: print "Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id, event.cost) + # product model every 10 pass + if isinstance(event, paddle.event.EndPass): + if event.pass_id % 10 == 0: + with open('params_pass_%d.tar' % event.pass_id, 'w') as f: + trainer.save_parameter_to_tar(f) # define training dataset reader diff --git a/doc/getstarted/concepts/use_concepts_cn.rst b/doc/getstarted/concepts/use_concepts_cn.rst index c243083794bb3c4659242de99b3b2715af9d7c24..e695ff283e2e806377a51c559b37e8068360a4ff 100644 --- a/doc/getstarted/concepts/use_concepts_cn.rst +++ b/doc/getstarted/concepts/use_concepts_cn.rst @@ -147,4 +147,9 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和 .. literalinclude:: src/train.py :linenos: +使用以上训练好的模型进行预测,取其中一个模型params_pass_90.tar,输入需要预测的向量组,然后打印输出: + +.. literalinclude:: src/infer.py + :linenos: + 有关线性回归的实际应用,可以参考PaddlePaddle book的 `第一章节 `_。