未验证 提交 12402c30 编写于 作者: L LiuChiachi 提交者: GitHub

update doc for hapi.model.save and hapi.model.Model in 2.0-RC (#2645)

* update doc for hapi.model.save and hapi.model.Model in 2.0-RC

* fix example code for Model

* fix doc for Model example code

* fix sample code, add import InputSpec
上级 7b859d23
......@@ -5,7 +5,7 @@ Model
.. py:class:: paddle.Model()
``Model`` 对象是一个具备训练、测试、推理的神经网络。该对象同时支持静态图和动态图模式,通过 ``paddle.disable_static()`` 来切换。需要注意的是,该开关需要在实例化 ``Model`` 对象之前使用。 在静态图模式下,输入需要使用 ``paddle.static.InputSpec`` 来定义。
``Model`` 对象是一个具备训练、测试、推理的神经网络。该对象同时支持静态图和动态图模式,通过 ``paddle.disable_static()`` 来切换。需要注意的是,该开关需要在实例化 ``Model`` 对象之前使用。输入需要使用 ``paddle.static.InputSpec`` 来定义。
**代码示例**
......@@ -24,7 +24,6 @@ Model
nn.Tanh(),
nn.Linear(200, 10))
# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
......@@ -137,6 +136,7 @@ Model
import numpy as np
import paddle
import paddle.nn as nn
from paddle.static import InputSpec
device = paddle.set_device('cpu') # or 'gpu'
paddle.disable_static(device)
......@@ -147,7 +147,9 @@ Model
nn.Linear(200, 10),
nn.Softmax())
model = paddle.Model(net)
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(net, input, label)
model.prepare()
data = np.random.random(size=(4,784)).astype(np.float32)
out = model.test_batch([data])
......@@ -157,7 +159,6 @@ Model
将模型的参数和训练过程中优化器的信息保存到指定的路径,以及推理所需的参数与文件。如果training=True,所有的模型参数都会保存到一个后缀为 ``.pdparams`` 的文件中。
所有的优化器信息和相关参数,比如 ``Adam`` 优化器中的 ``beta1`` ``beta2`` ``momentum`` 等,都会被保存到后缀为 ``.pdopt``。如果优化器比如SGD没有参数,则该不会产生该文件。如果training=False,则不会保存上述说的文件。只会保存推理需要的参数文件和模型文件。
需要注意的是,保存推理模型的参数文件和模型文件时,需要在 ``forward`` 上添加 ``@paddle.jit.to_static`` 函数在动态图模式下。
参数:
- **path** (str) - 保存的文件名前缀。格式如 ``dirname/file_prefix`` 或者 ``file_prefix``
......@@ -182,8 +183,6 @@ Model
nn.Linear(200, 10),
nn.Softmax())
# If save for inference in dygraph, need this
@paddle.jit.to_static
def forward(self, x):
return self.net(x)
......@@ -191,7 +190,7 @@ Model
device = paddle.set_device('cpu')
# if use static graph, do not set
paddle.disable_static(device) if dynamic else None
# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(Mnist(), input, label)
......@@ -220,15 +219,20 @@ Model
import paddle
import paddle.nn as nn
from paddle.static import InputSpec
device = paddle.set_device('cpu')
paddle.disable_static(device)
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(nn.Sequential(
nn.Linear(784, 200),
nn.Tanh(),
nn.Linear(200, 10),
nn.Softmax()))
nn.Softmax()),
input,
label)
model.save('checkpoint/test')
model.load('checkpoint/test')
......@@ -243,13 +247,18 @@ Model
.. code-block:: python
import paddle
import paddle.nn as nn
from paddle.static import InputSpec
paddle.disable_static()
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(nn.Sequential(
nn.Linear(784, 200),
nn.Tanh(),
nn.Linear(200, 10)))
nn.Linear(200, 10)),
input,
label)
params = model.parameters()
......@@ -384,7 +393,7 @@ Model
# imperative mode
paddle.disable_static()
model = paddle.Model(paddle.vision.models.LeNet())
model = paddle.Model(paddle.vision.models.LeNet(), input, label)
model.prepare(metrics=paddle.metric.Accuracy())
result = model.evaluate(val_dataset, batch_size=64)
print(result)
......@@ -439,7 +448,7 @@ Model
# imperative mode
device = paddle.set_device('cpu')
paddle.disable_static(device)
model = paddle.Model(paddle.vision.models.LeNet())
model = paddle.Model(paddle.vision.models.LeNet(), input)
model.prepare()
result = model.predict(test_dataset, batch_size=64)
print(len(result[0]), result[0][0].shape)
......@@ -479,4 +488,4 @@ Model
paddle.nn.CrossEntropyLoss())
params_info = model.summary()
print(params_info)
\ No newline at end of file
print(params_info)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册