提交 bfab8d96 编写于 作者: L LiuChiaChi

fix example code for Model, test=document_fix, notest

上级 d28162b9
...@@ -822,7 +822,6 @@ class Model(object): ...@@ -822,7 +822,6 @@ class Model(object):
nn.Tanh(), nn.Tanh(),
nn.Linear(200, 10)) nn.Linear(200, 10))
# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x') input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label') label = InputSpec([None, 1], 'int64', 'label')
...@@ -980,7 +979,9 @@ class Model(object): ...@@ -980,7 +979,9 @@ class Model(object):
nn.Linear(200, 10), nn.Linear(200, 10),
nn.Softmax()) nn.Softmax())
model = paddle.Model(net) input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(net, input, label)
model.prepare() model.prepare()
data = np.random.random(size=(4,784)).astype(np.float32) data = np.random.random(size=(4,784)).astype(np.float32)
out = model.test_batch([data]) out = model.test_batch([data])
...@@ -1096,11 +1097,15 @@ class Model(object): ...@@ -1096,11 +1097,15 @@ class Model(object):
device = paddle.set_device('cpu') device = paddle.set_device('cpu')
paddle.disable_static(device) paddle.disable_static(device)
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(nn.Sequential( model = paddle.Model(nn.Sequential(
nn.Linear(784, 200), nn.Linear(784, 200),
nn.Tanh(), nn.Tanh(),
nn.Linear(200, 10), nn.Linear(200, 10),
nn.Softmax())) nn.Softmax()),
input,
label)
model.save('checkpoint/test') model.save('checkpoint/test')
model.load('checkpoint/test') model.load('checkpoint/test')
""" """
...@@ -1168,10 +1173,14 @@ class Model(object): ...@@ -1168,10 +1173,14 @@ class Model(object):
paddle.disable_static() paddle.disable_static()
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(nn.Sequential( model = paddle.Model(nn.Sequential(
nn.Linear(784, 200), nn.Linear(784, 200),
nn.Tanh(), nn.Tanh(),
nn.Linear(200, 10))) nn.Linear(200, 10)),
input,
label)
params = model.parameters() params = model.parameters()
""" """
return self._adapter.parameters() return self._adapter.parameters()
...@@ -1483,7 +1492,7 @@ class Model(object): ...@@ -1483,7 +1492,7 @@ class Model(object):
# imperative mode # imperative mode
paddle.disable_static() paddle.disable_static()
model = paddle.Model(paddle.vision.models.LeNet()) model = paddle.Model(paddle.vision.models.LeNet(), input, label)
model.prepare(metrics=paddle.metric.Accuracy()) model.prepare(metrics=paddle.metric.Accuracy())
result = model.evaluate(val_dataset, batch_size=64) result = model.evaluate(val_dataset, batch_size=64)
print(result) print(result)
...@@ -1591,7 +1600,7 @@ class Model(object): ...@@ -1591,7 +1600,7 @@ class Model(object):
# imperative mode # imperative mode
device = paddle.set_device('cpu') device = paddle.set_device('cpu')
paddle.disable_static(device) paddle.disable_static(device)
model = paddle.Model(paddle.vision.models.LeNet()) model = paddle.Model(paddle.vision.models.LeNet(), input)
model.prepare() model.prepare()
result = model.predict(test_dataset, batch_size=64) result = model.predict(test_dataset, batch_size=64)
print(len(result[0]), result[0][0].shape) print(len(result[0]), result[0][0].shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册