Commit 83037e55 authored by songyouwei, committed by liym27

named_sublayers and named_parameters (#21868)

* use snake_cased Layer name

* add named_parameters and named_sublayers api

* add include_sublayers param, add unittest
test=develop

* fix named unittests
test=develop

* fix unittest
test=develop

* add api docs
test=develop

* arg fix
test=develop

* reserve rnn_impl name_scope for static graph
test=develop

* fix load static param
test=develop

* fix load static param
test=develop
Parent ad0dfb17
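The new APIs introduced by this commit are iterator-style accessors on Layer. A minimal usage sketch, mirroring the docstring examples added below (the exact printed names depend on how the sublayers and their parameters are registered):

import paddle.fluid as fluid

with fluid.dygraph.guard():
    fc1 = fluid.Linear(10, 3)
    fc2 = fluid.Linear(3, 10, bias_attr=False)
    model = fluid.dygraph.Sequential(fc1, fc2)

    # (name, Parameter) pairs; names are joined with '.' along the sublayer path
    for name, param in model.named_parameters():
        print(name, param.shape)

    # (prefix, Layer) pairs for every registered sublayer
    for prefix, layer in model.named_sublayers():
        print(prefix, layer)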
......@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import layers
from paddle.fluid import layers, unique_name
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper
from paddle.fluid.layers.control_flow import StaticRNN
__all__ = ['BasicGRUUnit', 'basic_gru', 'BasicLSTMUnit', 'basic_lstm']
......@@ -80,6 +81,10 @@ class BasicGRUUnit(Layer):
activation=None,
dtype='float32'):
super(BasicGRUUnit, self).__init__(name_scope, dtype)
# reserve old school _full_name and _helper for static graph save load
self._full_name = unique_name.generate(name_scope + "/" +
self.__class__.__name__)
self._helper = LayerObjectHelper(self._full_name)
self._name = name_scope
self._hiden_size = hidden_size
......@@ -710,6 +715,10 @@ class BasicLSTMUnit(Layer):
forget_bias=1.0,
dtype='float32'):
super(BasicLSTMUnit, self).__init__(name_scope, dtype)
# reserve old school _full_name and _helper for static graph save load
self._full_name = unique_name.generate(name_scope + "/" +
self.__class__.__name__)
self._helper = LayerObjectHelper(self._full_name)
self._name = name_scope
self._hiden_size = hidden_size
......
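For context, the "old school" code kept above preserves the "name_scope/ClassName" prefix that static-graph checkpoints were saved with, whereas the new default would be the snake-cased class name. A hedged sketch of the difference, assuming a name_scope of "gru" (the numeric suffixes depend on the unique-name counter):

from paddle.fluid import unique_name

old_style = unique_name.generate("gru" + "/" + "BasicGRUUnit")  # e.g. "gru/BasicGRUUnit_0"
new_style = unique_name.generate("basic_gru_unit")              # e.g. "basic_gru_unit_0"
print(old_style, new_style)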
......@@ -18,6 +18,7 @@ import sys
import numpy as np
import collections
import six
import re
from . import parallel_helper
from .. import unique_name
from paddle.fluid import core
......@@ -30,6 +31,14 @@ import warnings
__all__ = ['Layer']
_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
_all_cap_re = re.compile('([a-z])([A-Z])')
def _convert_camel_to_snake(name):
s1 = _first_cap_re.sub(r'\1_\2', name)
return _all_cap_re.sub(r'\1_\2', s1).lower()
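For reference, the two regexes insert underscores at word boundaries before lowercasing, so class names map to snake case as follows (the values can be checked by calling the helper directly):

assert _convert_camel_to_snake('Linear') == 'linear'
assert _convert_camel_to_snake('BatchNorm') == 'batch_norm'
assert _convert_camel_to_snake('BasicGRUUnit') == 'basic_gru_unit'
assert _convert_camel_to_snake('PRelu') == 'p_relu'  # hence PRelu passes an explicit name_scope below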
class Layer(core.Layer):
"""Dynamic graph Layer based on OOD, includes the parameters of the layer, the structure of the forward graph and so on.
......@@ -37,9 +46,9 @@ class Layer(core.Layer):
Parameters:
name_scope (str, optional): prefix name used by the layer to name parameters.
If prefix is "my_layer", parameter name in MyLayer
can be "mylayer_0.w_n", where w is the parameter
base name and n is an unique suffix auto-generated.
If None, prefix name will be lower cased class name. Default: None.
can be "my_layer_0.w_n", where "w" is the parameter
base name and "n" is a unique, auto-generated suffix.
If None, prefix name will be snake cased class name. Default: None.
dtype(str or core.VarDesc.VarType, optional): data type of this parameter.
If set str, it can be "bool", "float16", "float32", "float64",
"int8", "int16", "int32", "int64", "uint8" or "uint16".
......@@ -51,12 +60,8 @@ class Layer(core.Layer):
def __init__(self, name_scope=None, dtype=core.VarDesc.VarType.FP32):
if name_scope is None:
name_scope = self.__class__.__name__.lower()
self._full_name = unique_name.generate(name_scope)
else:
# TODO: remove name_scope parameter and all hard-coded usages
self._full_name = unique_name.generate(name_scope + "/" +
self.__class__.__name__)
name_scope = _convert_camel_to_snake(self.__class__.__name__)
self._full_name = unique_name.generate(name_scope)
self._helper = LayerObjectHelper(self._full_name)
self._built = False
self._dtype = dtype
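With this change, a layer constructed without a name_scope gets a full name derived from the snake-cased class name plus a unique suffix, and its parameters use that prefix. A small sketch with a hypothetical subclass (the "_0" suffix assumes a fresh unique-name counter):

import paddle.fluid as fluid

class MyCustomLayer(fluid.dygraph.Layer):
    def __init__(self):
        super(MyCustomLayer, self).__init__()

with fluid.dygraph.guard():
    layer = MyCustomLayer()
    print(layer.full_name())  # e.g. "my_custom_layer_0"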
......@@ -172,6 +177,93 @@ class Layer(core.Layer):
ret.append(sub_l)
return ret
def named_parameters(self, prefix='', include_sublayers=True):
"""
Returns an iterator over all parameters in the Layer, yielding tuple of name and parameter.
Parameters:
prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
include_sublayers(bool, optional): Whether to include the parameters of sublayers.
If True, also include the named parameters from sublayers. Default: True.
Yields:
(string, Parameter): Tuple of name and Parameter
Examples:
.. code-block:: python
import paddle.fluid as fluid
with fluid.dygraph.guard():
fc1 = fluid.Linear(10, 3)
fc2 = fluid.Linear(3, 10, bias_attr=False)
model = fluid.dygraph.Sequential(fc1, fc2)
for name, param in model.named_parameters():
print(name, param)
"""
params_set = set()
named_sublayers = self.named_sublayers(
prefix=prefix,
include_sublayers=include_sublayers,
include_self=True)
for layer_prefix, sublayer in named_sublayers:
params = sublayer._parameters.items()
for key, param in params:
if param is None or param in params_set:
continue
params_set.add(param)
name = layer_prefix + ('.' if layer_prefix else '') + key
yield name, param
def named_sublayers(self,
prefix='',
include_sublayers=True,
include_self=False,
layers_set=None):
"""
Returns an iterator over all sublayers in the Layer, yielding tuple of name and sublayer.
Duplicate sublayers are yielded only once.
Parameters:
prefix(str, optional): Prefix to prepend to all sublayer names. Default: ''.
include_sublayers(bool, optional): Whether to include the sublayers. Default: True.
include_self(bool, optional): Whether to include the Layer itself. Default: False.
layers_set(set, optional): The set used to record visited sublayers. Default: None.
Yields:
(string, Layer): Tuple of name and Layer
Examples:
.. code-block:: python
import paddle.fluid as fluid
with fluid.dygraph.guard():
fc1 = fluid.Linear(10, 3)
fc2 = fluid.Linear(3, 10, bias_attr=False)
model = fluid.dygraph.Sequential(fc1, fc2)
for prefix, layer in model.named_sublayers():
print(prefix, layer)
"""
if layers_set is None:
layers_set = set()
if include_self and self not in layers_set:
layers_set.add(self)
yield prefix, self
if include_sublayers:
for key, layer in self._sub_layers.items():
if layer is None:
continue
layer_prefix = prefix + ('.' if prefix else '') + key
for p, l in layer.named_sublayers(
prefix=layer_prefix,
include_sublayers=include_sublayers,
include_self=True,
layers_set=layers_set):
yield p, l
def clear_gradients(self):
"""
Clear the gradients of all parameters for this layer.
......
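Because named_sublayers records visited layers in layers_set and named_parameters records visited parameters in params_set, a layer (or weight) registered more than once is reported only once. A hedged sketch of that behaviour, assuming Sequential simply registers each argument as a sublayer:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    shared = fluid.Linear(4, 4)
    model = fluid.dygraph.Sequential(shared, shared)  # the same Linear registered twice

    print(len(list(model.named_sublayers())))   # expected 1, not 2
    print(len(list(model.named_parameters())))  # expected 2 (weight and bias), not 4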
......@@ -1073,9 +1073,6 @@ class BatchNorm(layers.Layer):
self._bias_attr = bias_attr
self._act = act
self._full_name = unique_name.generate("batch_norm")
self._helper = LayerObjectHelper(self._full_name)
assert bias_attr is not False, "bias_attr should not be False in batch_norm."
if dtype == "float16":
......@@ -1424,9 +1421,6 @@ class LayerNorm(layers.Layer):
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = [normalized_shape]
self._full_name = unique_name.generate("layer_norm")
self._helper = LayerObjectHelper(self._full_name)
self._normalized_shape = list(normalized_shape)
self._scale = scale
self._shift = shift
......@@ -1989,7 +1983,8 @@ class PRelu(layers.Layer):
input_shape=None,
param_attr=None,
dtype='float32'):
super(PRelu, self).__init__()
# need to specify name_scope since the snake-cased 'PRelu' is 'p_relu'
super(PRelu, self).__init__(name_scope='prelu')
self._mode = mode
self._param_attr = param_attr
self._dtype = dtype
......
......@@ -19,8 +19,8 @@ import paddle.fluid as fluid
class L1(fluid.Layer):
def __init__(self, prefix):
super(L1, self).__init__(prefix)
def __init__(self):
super(L1, self).__init__()
self._param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1))
self.w1 = self.create_parameter(
......@@ -33,20 +33,20 @@ class L1(fluid.Layer):
class L2(fluid.Layer):
def __init__(self, prefix):
super(L2, self).__init__(prefix)
self.layer1 = L1(self.full_name())
self.layer2 = L1(self.full_name())
def __init__(self):
super(L2, self).__init__()
self.layer1 = L1()
self.layer2 = L1()
def forward(self):
return self.layer1() + self.layer2()
class L3(fluid.Layer):
def __init__(self, prefix):
super(L3, self).__init__(prefix)
self.layer1 = L2(self.full_name())
self.layer2 = L2(self.full_name())
def __init__(self):
super(L3, self).__init__()
self.layer1 = L2()
self.layer2 = L2()
def forward(self):
return self.layer1() + self.layer2()
......@@ -55,23 +55,33 @@ class L3(fluid.Layer):
class TestBaseLayer(unittest.TestCase):
def test_one_level(self):
with fluid.dygraph.guard():
l = L1('test_one_level')
l = L1()
ret = l()
self.assertEqual(l.w1.name, "test_one_level/L1_0.w_0")
self.assertEqual(l.w2.name, "test_one_level/L1_0.w_1")
expected_names = ['l1.w1', 'l1.w2']
idx = 0
for name, _ in l.named_parameters(prefix='l1'):
self.assertEqual(name, expected_names[idx])
idx += 1
self.assertTrue(np.allclose(ret.numpy(), 0.2 * np.ones([2, 2])))
def test_three_level(self):
with fluid.dygraph.guard():
l = L3('test_three_level')
names = [p.name for p in l.parameters()]
l = L3()
expected_names = [
'l3.layer1.layer1.w1',
'l3.layer1.layer1.w2',
'l3.layer1.layer2.w1',
'l3.layer1.layer2.w2',
'l3.layer2.layer1.w1',
'l3.layer2.layer1.w2',
'l3.layer2.layer2.w1',
'l3.layer2.layer2.w2',
]
idx = 0
for name, _ in l.named_parameters(prefix='l3'):
self.assertEqual(name, expected_names[idx])
idx += 1
ret = l()
self.assertEqual(names[0], "test_three_level/L3_0/L2_0/L1_0.w_0")
self.assertEqual(names[1], "test_three_level/L3_0/L2_0/L1_0.w_1")
self.assertEqual(names[2], "test_three_level/L3_0/L2_0/L1_1.w_0")
self.assertEqual(names[3], "test_three_level/L3_0/L2_0/L1_1.w_1")
self.assertEqual(names[4], "test_three_level/L3_0/L2_1/L1_0.w_0")
self.assertEqual(names[5], "test_three_level/L3_0/L2_1/L1_0.w_1")
self.assertTrue(np.allclose(ret.numpy(), 0.8 * np.ones([2, 2])))
......
......@@ -23,8 +23,8 @@ from test_imperative_base import new_program_scope
class MyLayer(fluid.Layer):
def __init__(self, name_scope):
super(MyLayer, self).__init__(name_scope)
def __init__(self):
super(MyLayer, self).__init__()
def forward(self, inputs):
x = fluid.layers.relu(inputs)
......@@ -60,16 +60,14 @@ class MLP(fluid.Layer):
class SimpleRNNCell(fluid.Layer):
def __init__(self, name_scope, step_input_size, hidden_size, output_size,
param_attr):
super(SimpleRNNCell, self).__init__(name_scope)
def __init__(self, step_input_size, hidden_size, output_size, param_attr):
super(SimpleRNNCell, self).__init__()
self.step_input_size = step_input_size
self.hidden_size = hidden_size
self.output_size = output_size
self._dtype = core.VarDesc.VarType.FP32
self.param_attr = param_attr
def _build_once(self, inputs, pre_hidden):
i2h_param_shape = [self.step_input_size, self.hidden_size]
h2h_param_shape = [self.hidden_size, self.hidden_size]
h2o_param_shape = [self.output_size, self.hidden_size]
......@@ -90,7 +88,6 @@ class SimpleRNNCell(fluid.Layer):
is_bias=False)
def forward(self, input, pre_hidden):
tmp_i2h = self.create_variable(dtype=self._dtype)
tmp_h2h = self.create_variable(dtype=self._dtype)
hidden = self.create_variable(dtype=self._dtype)
......@@ -147,11 +144,10 @@ class SimpleRNNCell(fluid.Layer):
class SimpleRNN(fluid.Layer):
def __init__(self, name_scope):
super(SimpleRNN, self).__init__(name_scope)
def __init__(self):
super(SimpleRNN, self).__init__()
self.seq_len = 4
self._cell = SimpleRNNCell(
self.full_name(),
3,
3,
3,
......@@ -297,7 +293,7 @@ class TestImperative(unittest.TestCase):
with fluid.dygraph.guard():
var_inp = fluid.dygraph.base.to_variable(np_inp)
var_inp.stop_gradient = False
l = MyLayer("my_layer")
l = MyLayer()
print(var_inp)
x = l(var_inp)[0]
self.assertIsNotNone(x)
......@@ -308,7 +304,7 @@ class TestImperative(unittest.TestCase):
with fluid.dygraph.guard():
var_inp2 = fluid.dygraph.base.to_variable(np_inp)
var_inp2.stop_gradient = False
l2 = MyLayer("my_layer")
l2 = MyLayer()
x2 = l2(var_inp2)[0]
self.assertIsNotNone(x2)
dy_out2 = x2.numpy()
......@@ -320,7 +316,7 @@ class TestImperative(unittest.TestCase):
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[3], append_batch_size=False)
l = MyLayer("my_layer")
l = MyLayer()
x = l(inp)[0]
param_grads = fluid.backward.append_backward(
x, parameter_list=[l._x_for_debug.name])[0]
......@@ -447,7 +443,7 @@ class TestImperative(unittest.TestCase):
with fluid.dygraph.guard():
var_inp = fluid.dygraph.base.to_variable(np_inp)
var_inp = fluid.layers.reshape(var_inp, shape=[1, 4, 3])
simple_rnn = SimpleRNN("simple_rnn")
simple_rnn = SimpleRNN()
outs, pre_hiddens = simple_rnn.forward(var_inp)
dy_out = outs[3].numpy()
outs[3].backward()
......@@ -458,7 +454,7 @@ class TestImperative(unittest.TestCase):
with fluid.dygraph.guard():
var_inp2 = fluid.dygraph.base.to_variable(np_inp)
var_inp2 = fluid.layers.reshape(var_inp2, shape=[1, 4, 3])
simple_rnn2 = SimpleRNN("simple_rnn")
simple_rnn2 = SimpleRNN()
outs2, pre_hiddens2 = simple_rnn2.forward(var_inp2)
dy_out2 = outs2[3].numpy()
backward_strategy = fluid.dygraph.BackwardStrategy()
......@@ -471,7 +467,7 @@ class TestImperative(unittest.TestCase):
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[1, 4, 3], append_batch_size=False)
simple_rnn = SimpleRNN("simple_rnn")
simple_rnn = SimpleRNN()
outs, pre_hiddens = simple_rnn(inp)
param_grads = fluid.backward.append_backward(outs[3])
exe = fluid.Executor(fluid.CPUPlace())
......
......@@ -18,8 +18,6 @@ import paddle.fluid.framework as framework
from paddle.fluid.dygraph.nn import *
import numpy as np
print("11")
class TestDygraphLoadStatic(unittest.TestCase):
def testLoadStaticModel(self):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid as fluid
class MyLayer(fluid.Layer):
def __init__(self, num_channel, dim, num_filter=5):
super(MyLayer, self).__init__()
self.fc = fluid.dygraph.Linear(dim, dim)
self.conv = fluid.dygraph.Conv2D(num_channel, num_channel, num_filter)
def forward(self, x):
x = self.fc(x)
x = self.conv(x)
return x
class TestImperativeNamedSubLayers(unittest.TestCase):
def test_named_sublayers(self):
with fluid.dygraph.guard():
fc1 = fluid.Linear(10, 3)
fc2 = fluid.Linear(3, 10, bias_attr=False)
custom = MyLayer(3, 10)
model = fluid.dygraph.Sequential(fc1, fc2, custom)
named_sublayers = model.named_sublayers()
list_named_sublayers = list(named_sublayers)
expected_sublayers = [fc1, fc2, custom, custom.fc, custom.conv]
self.assertEqual(len(list_named_sublayers), len(expected_sublayers))
for (name, sublayer), expected_sublayer in zip(list_named_sublayers,
expected_sublayers):
self.assertEqual(sublayer, expected_sublayer)
list_sublayers = list(model.sublayers())
self.assertEqual(len(list_named_sublayers), len(list_sublayers))
for (name, sublayer), expected_sublayer in zip(list_named_sublayers,
list_sublayers):
self.assertEqual(sublayer, expected_sublayer)
for name, sublayer in model.named_sublayers(
include_sublayers=False):
self.assertEqual(model[name], sublayer)
self.assertListEqual(
[l for _, l in list(model.named_sublayers(include_self=True))],
[model] + expected_sublayers)
class TestImperativeNamedParameters(unittest.TestCase):
def test_named_parameters(self):
with fluid.dygraph.guard():
fc1 = fluid.Linear(10, 3)
fc2 = fluid.Linear(3, 10, bias_attr=False)
custom = MyLayer(3, 10)
model = fluid.dygraph.Sequential(fc1, fc2, custom)
named_parameters = list(model.named_parameters())
expected_named_parameters = list()
for prefix, layer in model.named_sublayers(include_sublayers=True):
for name, param in layer.named_parameters(
include_sublayers=False):
full_name = prefix + ('.' if prefix else '') + name
expected_named_parameters.append((full_name, param))
self.assertListEqual(expected_named_parameters, named_parameters)
if __name__ == '__main__':
unittest.main()