未验证 提交 52e2eb65 编写于 作者: S Siddharth Goyal 提交者: GitHub

Fix function in fit-a-line with new API (#11020)

上级 fae3d8d2
...@@ -38,7 +38,7 @@ def inference_program(): ...@@ -38,7 +38,7 @@ def inference_program():
return y_predict return y_predict
def linear(): def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32') y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = inference_program() y_predict = inference_program()
...@@ -104,7 +104,7 @@ def main(use_cuda): ...@@ -104,7 +104,7 @@ def main(use_cuda):
# Directory for saving the trained model # Directory for saving the trained model
params_dirname = "fit_a_line.inference.model" params_dirname = "fit_a_line.inference.model"
train(use_cuda, linear, params_dirname) train(use_cuda, train_program, params_dirname)
infer(use_cuda, inference_program, params_dirname) infer(use_cuda, inference_program, params_dirname)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册