提交 bfab8d96 编写于 作者: L LiuChiaChi

fix example code for Model, test=document_fix, notest

上级 d28162b9
......@@ -822,7 +822,6 @@ class Model(object):
nn.Tanh(),
nn.Linear(200, 10))
# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
......@@ -980,7 +979,9 @@ class Model(object):
nn.Linear(200, 10),
nn.Softmax())
model = paddle.Model(net)
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(net, input, label)
model.prepare()
data = np.random.random(size=(4,784)).astype(np.float32)
out = model.test_batch([data])
......@@ -1096,11 +1097,15 @@ class Model(object):
device = paddle.set_device('cpu')
paddle.disable_static(device)
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(nn.Sequential(
nn.Linear(784, 200),
nn.Tanh(),
nn.Linear(200, 10),
nn.Softmax()))
nn.Softmax()),
input,
label)
model.save('checkpoint/test')
model.load('checkpoint/test')
"""
......@@ -1168,10 +1173,14 @@ class Model(object):
paddle.disable_static()
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(nn.Sequential(
nn.Linear(784, 200),
nn.Tanh(),
nn.Linear(200, 10)))
nn.Linear(200, 10)),
input,
label)
params = model.parameters()
"""
return self._adapter.parameters()
......@@ -1483,7 +1492,7 @@ class Model(object):
# imperative mode
paddle.disable_static()
model = paddle.Model(paddle.vision.models.LeNet())
model = paddle.Model(paddle.vision.models.LeNet(), input, label)
model.prepare(metrics=paddle.metric.Accuracy())
result = model.evaluate(val_dataset, batch_size=64)
print(result)
......@@ -1591,7 +1600,7 @@ class Model(object):
# imperative mode
device = paddle.set_device('cpu')
paddle.disable_static(device)
model = paddle.Model(paddle.vision.models.LeNet())
model = paddle.Model(paddle.vision.models.LeNet(), input)
model.prepare()
result = model.predict(test_dataset, batch_size=64)
print(len(result[0]), result[0][0].shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册