未验证 提交 565c3d49 编写于 作者: A Aurelius84 提交者: GitHub

Enrich tutorial of InputSpec (#2682)

* enrich tutorial of InputSpec

* fix typo

* fix indent
上级 d2f73ff1
......@@ -197,4 +197,46 @@ InputSpec 初始化中的只有 ``shape`` 是必须参数, ``dtype`` 和 ``nam
其中 ``input_spec`` 参数是长度为 2 的 list ,对应 forward 函数的 x 和 bias_info 两个参数。 ``input_spec`` 的最后一个元素是包含键名为 x 的 InputSpec 对象的 dict ,对应参数 bias_info 的 Tensor 签名信息。
2.4 指定非Tensor参数类型
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
目前,``to_static`` 装饰器中的 ``input_spec`` 参数仅接收 ``InputSpec`` 类型对象。若被装饰函数的参数列表除了 Tensor 类型,还包含其他如 Int、 String 等非 Tensor 类型时,推荐在函数中使用 kwargs 形式定义非 Tensor 参数,如下述样例中的 use_act 参数。
.. code-block:: python
class SimpleNet(Layer):
def __init__(self, ):
super(SimpleNet, self).__init__()
self.linear = paddle.nn.Linear(10, 3)
self.relu = paddle.nn.ReLU()
@to_static(input_spec=[InputSpec(shape=[None, 10], name='x')])
def forward(self, x, use_act=False):
out = self.linear(x)
if use_act:
out = self.relu(out)
return out
net = SimpleNet()
adam = paddle.optimizer.Adam(parameters=net.parameters())
# train model
batch_num = 10
for step in range(batch_num):
x = paddle.rand([4, 10], 'float32')
use_act = (step%2 == 0)
out = net(x, use_act)
loss = paddle.mean(out)
loss.backward()
adam.minimize(loss)
net.clear_gradients()
# save inference model with use_act=False
paddle.jit.save(net, model_path='./simple_net')
在上述样例中,step 为奇数时,use_act 取值为 False ; step 为偶数时, use_act 取值为 True 。动转静支持非 Tensor 参数在训练时取不同的值,且保证了取值不同的训练过程都可以更新模型的网络参数,行为与动态图一致。
kwargs 参数的默认值主要用于保存推理模型。在借助 ``paddle.jit.save`` 保存预测模型时,动转静会根据 input_spec 和 kwargs 的默认值保存推理模型和网络参数。因此建议将 kwargs 参数默认值设置为预测时的取值。
更多关于动转静 ``to_static`` 搭配 ``paddle.jit.save/load`` 的使用方式,可以参考 :ref:`user_guide_model_save_load` 。
\ No newline at end of file
......@@ -194,3 +194,46 @@ If a function takes an argument of type dict, the element in the ``input_spec``
The length of ``input_spec`` is 2 corresponding to arguments x and bias_info in forward function. The last element of ``input_spec`` is a InputSpec dict with same key corresponding to signature information of bias_info.
2.4 Specify non-Tensor arguments
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Currently, the ``input_spec`` from ``to_static`` decorator only receives objects with ``InputSpec`` type. When the decorated function contains some non-Tensor arguments, such as Int, String or other python types, we recommend to use kwargs with default values as argument, see use_act in followed example.
.. code-block:: python
class SimpleNet(Layer):
def __init__(self, ):
super(SimpleNet, self).__init__()
self.linear = paddle.nn.Linear(10, 3)
self.relu = paddle.nn.ReLU()
@to_static(input_spec=[InputSpec(shape=[None, 10], name='x')])
def forward(self, x, use_act=False):
out = self.linear(x)
if use_act:
out = self.relu(out)
return out
net = SimpleNet()
adam = paddle.optimizer.Adam(parameters=net.parameters())
# train model
batch_num = 10
for step in range(batch_num):
x = paddle.rand([4, 10], 'float32')
use_act = (step%2 == 0)
out = net(x, use_act)
loss = paddle.mean(out)
loss.backward()
adam.minimize(loss)
net.clear_gradients()
# save inference model with use_act=False
paddle.jit.save(net, model_path='./simple_net')
In above example, use_act is equal to True if step is an odd number, and False if step is an even number. We support non-tensor argument applied to different values during training after conversion. Moreover, the shared parameters of the model can be updated during the training with different values. The behavior is consistent with the dynamic graph.
The default value of the kwargs is primarily used for saving inference model. The inference model and network parameters will be exported based on input_spec and the default values of kwargs. Therefore, it is recommended to set the default value of the kwargs arguments for prediction.
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册