未验证 提交 e48091db 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Static]Add param_guard in ParameterList to support @to_static

上级 b035c8b0
......@@ -15,6 +15,7 @@
from collections import OrderedDict
from ..framework import Parameter
from .layers import Layer
from .base import param_guard
__all__ = [
'Sequential',
......@@ -159,7 +160,8 @@ class ParameterList(Layer):
self.add_parameter(str(idx), param)
def __getitem__(self, idx):
return self._parameters[str(idx)]
with param_guard(self._parameters):
return self._parameters[str(idx)]
def __setitem__(self, idx, param):
assert isinstance(param, Parameter)
......@@ -169,7 +171,8 @@ class ParameterList(Layer):
return len(self._parameters)
def __iter__(self):
return iter(self._parameters.values())
with param_guard(self._parameters):
return iter(self._parameters.values())
def append(self, parameter):
"""Appends a given parameter at the end of the list.
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
from paddle.jit import to_static, ProgramTranslator
class NetWithParameterList(paddle.nn.Layer):
def __init__(self, in_size, out_size):
super(NetWithParameterList, self).__init__()
weight = self.create_parameter([in_size, out_size])
bias = self.create_parameter([out_size], is_bias=True)
self.params = paddle.nn.ParameterList([weight, bias])
@to_static
def forward(self, x):
out = paddle.matmul(x, self.params[0])
out = paddle.add(out, self.params[1])
out = paddle.tanh(out)
return out
class NetWithParameterListIter(NetWithParameterList):
def __init__(self, in_size, out_size):
super(NetWithParameterListIter, self).__init__(in_size, out_size)
@to_static
def forward(self, x):
# NOTE: manually trigger `__iter__` logic.
params = list(self.params.__iter__())
out = paddle.matmul(x, params[0])
out = paddle.add(out, params[1])
out = paddle.tanh(out)
return out
class TestParameterList(unittest.TestCase):
def setUp(self):
self.seed = 2021
self.iter_num = 5
self.prog_trans = ProgramTranslator()
def train(self, is_iter, to_static):
paddle.seed(self.seed)
np.random.seed(self.seed)
self.prog_trans.enable(to_static)
if is_iter:
net = NetWithParameterList(10, 3)
else:
net = NetWithParameterListIter(10, 3)
sgd = paddle.optimizer.SGD(0.1, parameters=net.parameters())
for batch_id in range(self.iter_num):
x = paddle.rand([4, 10], dtype='float32')
out = net(x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
return loss
def test_parameter_list(self):
static_loss = self.train(False, to_static=True)
dygraph_loss = self.train(False, to_static=False)
self.assertTrue(
np.allclose(dygraph_loss, static_loss),
msg='dygraph result is {}\nstatic result is {}'.format(dygraph_loss,
static_loss))
def test_parameter_list_iter(self):
static_loss = self.train(True, to_static=True)
dygraph_loss = self.train(True, to_static=False)
self.assertTrue(
np.allclose(dygraph_loss, static_loss),
msg='dygraph result is {}\nstatic result is {}'.format(dygraph_loss,
static_loss))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册