Unverified commit 379216ae, authored by R Roc, committed by GitHub

[Clean Fluid] Rm and mv some fluid dygraph apis (#48576)

Remove fluid dygraph APIs:
GroupNorm
TreeConv
Move fluid dygraph APIs:
Flatten
SpectralNorm
Parent 592ed40b
......@@ -55,10 +55,6 @@ __all__ = [
'BatchNorm',
'Embedding',
'Conv3DTranspose',
'GroupNorm',
'SpectralNorm',
'TreeConv',
'Flatten',
]
......@@ -1203,421 +1199,3 @@ class RowConv(layers.Layer):
outputs={'Out': [out]},
)
return self._helper.append_activation(out, act=self._act)
class GroupNorm(layers.Layer):
"""
:alias_main: paddle.nn.GroupNorm
:alias: paddle.nn.GroupNorm,paddle.nn.layer.GroupNorm,paddle.nn.layer.norm.GroupNorm
:old_api: paddle.fluid.dygraph.GroupNorm
This interface is used to construct a callable object of the ``GroupNorm`` class.
For more details, refer to code examples.
It implements the function of the Group Normalization Layer.
Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_ .
Parameters:
channels(int): The number of channels of input.
groups(int): The number of groups into which the channels are divided.
epsilon(float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05.
param_attr(ParamAttr, optional): The parameter attribute for the learnable
scale :math:`g`. If it is set to False, no scale will be added to the output units.
If it is set to None, the scale is initialized to one. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the learnable
bias :math:`b`. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized to zero. Default: None.
act(str, optional): Activation to be applied to the output of group normalization. Default: None.
data_layout(str, optional): Specify the input data format. Only NCHW is supported. Default: NCHW.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
with fluid.dygraph.guard():
x = np.random.random((8, 32, 32)).astype('float32')
groupNorm = fluid.dygraph.nn.GroupNorm(channels=32, groups=4)
ret = groupNorm(fluid.dygraph.base.to_variable(x))
"""
def __init__(
self,
channels,
groups,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None,
data_layout='NCHW',
dtype='float32',
):
super().__init__()
self._param_attr = param_attr
self._bias_attr = bias_attr
self._epsilon = epsilon
self._channels = channels
self._groups = groups
self._act = act
self._dtype = dtype
if data_layout != 'NCHW':
raise ValueError("unsupported data layout:" + data_layout)
param_shape = [self._channels]
self.weight = self.create_parameter(
attr=self._param_attr or False,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0),
)
self.bias = self.create_parameter(
attr=self._bias_attr or False,
shape=param_shape,
dtype=self._dtype,
is_bias=True,
)
def forward(self, input):
mean_out = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True
)
variance_out = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True
)
if in_dygraph_mode():
out = _C_ops.group_norm(
input,
self.weight,
self.bias,
self._epsilon,
self._groups,
"NCHW",
)
return dygraph_utils._append_activation_in_dygraph(out, self._act)
elif _in_legacy_dygraph():
attrs = ('epsilon', self._epsilon, 'groups', self._groups)
out, _, _ = _legacy_C_ops.group_norm(
input, self.weight, self.bias, mean_out, variance_out, *attrs
)
return dygraph_utils._append_activation_in_dygraph(out, self._act)
else:
inputs = {'X': input}
if self.bias is not None:
inputs['Bias'] = self.bias
if self.weight is not None:
inputs['Scale'] = self.weight
# create output
group_norm_out = self._helper.create_variable_for_type_inference(
dtype=self._dtype
)
self._helper.append_op(
type="group_norm",
inputs=inputs,
outputs={
"Y": group_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={"epsilon": self._epsilon, "groups": self._groups},
)
return self._helper.append_activation(group_norm_out, self._act)
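Since this commit removes ``fluid.dygraph.GroupNorm``, the test updates further down migrate to ``paddle.nn.GroupNorm``, which renames ``channels``/``groups`` to ``num_channels``/``num_groups`` and takes ``num_groups`` as its first positional argument. A minimal migration sketch, mirroring the shapes used in the updated tests (illustrative only):

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.random((8, 32, 32)).astype('float32'))
# old: fluid.dygraph.nn.GroupNorm(channels=32, groups=4)
group_norm = paddle.nn.GroupNorm(num_channels=32, num_groups=4)
out = group_norm(x)
print(out.shape)  # [8, 32, 32]
```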
class SpectralNorm(layers.Layer):
r"""
This interface is used to construct a callable object of the ``SpectralNorm`` class.
For more details, refer to code examples. It implements the function of the Spectral Normalization Layer.
This layer calculates the spectral normalization value of the weight parameters of
fc, conv1d, conv2d, and conv3d layers, which should be 2-D, 3-D, 4-D, or 5-D
Parameters. Calculations are shown as follows.
Step 1:
Generate vector U in shape of [H], and V in shape of [W],
where H is the size of the :attr:`dim`-th dimension of the input weights,
and W is the product of the remaining dimensions.
Step 2:
:attr:`power_iters` should be a positive integer; do the following
calculations with U and V for :attr:`power_iters` rounds.
.. math::
\mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
\mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}
Step 3:
Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.
.. math::
\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
\mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Parameters:
weight_shape(list or tuple): The shape of weight parameter.
dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Returns:
None
Examples:
.. code-block:: python
import paddle
x = paddle.rand((2,8,32,32))
spectral_norm = paddle.nn.SpectralNorm(x.shape, dim=1, power_iters=2)
spectral_norm_out = spectral_norm(x)
print(spectral_norm_out.shape) # [2, 8, 32, 32]
"""
def __init__(
self, weight_shape, dim=0, power_iters=1, eps=1e-12, dtype='float32'
):
super().__init__()
self._power_iters = power_iters
self._eps = eps
self._dim = dim
self._dtype = dtype
self._weight_shape = list(weight_shape)
assert (
np.prod(self._weight_shape) > 0
), "Any dimension of `weight_shape` cannot be equal to 0."
assert dim < len(self._weight_shape), (
"The input `dim` should be less than the "
"length of `weight_shape`, but received dim="
"{}".format(dim)
)
h = self._weight_shape[self._dim]
w = np.prod(self._weight_shape) // h
self.weight_u = self.create_parameter(
attr=ParamAttr(),
shape=[h],
dtype=self._dtype,
default_initializer=Normal(0.0, 1.0),
)
self.weight_u.stop_gradient = True
self.weight_v = self.create_parameter(
attr=ParamAttr(),
shape=[w],
dtype=self._dtype,
default_initializer=Normal(0.0, 1.0),
)
self.weight_v.stop_gradient = True
def forward(self, weight):
if in_dygraph_mode():
return _C_ops.spectral_norm(
weight,
self.weight_u,
self.weight_v,
self._dim,
self._power_iters,
self._eps,
)
check_variable_and_dtype(
weight, "weight", ['float32', 'float64'], 'SpectralNorm'
)
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="spectral_norm",
inputs=inputs,
outputs={
"Out": out,
},
attrs={
"dim": self._dim,
"power_iters": self._power_iters,
"eps": self._eps,
},
)
return out
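The power-iteration steps described in the docstring above can be illustrated with a short NumPy sketch. The helper ``spectral_normalize`` below is hypothetical and for exposition only; it is not part of Paddle and not the kernel the layer calls:

```python
import numpy as np

def spectral_normalize(w, dim=0, power_iters=1, eps=1e-12):
    # Permute `dim` to the front and flatten the rest, giving an H x W matrix.
    perm = [dim] + [i for i in range(w.ndim) if i != dim]
    mat = np.transpose(w, perm).reshape(w.shape[dim], -1)
    u = np.random.normal(size=mat.shape[0])
    v = np.random.normal(size=mat.shape[1])
    for _ in range(power_iters):
        v = mat.T @ u
        v /= np.linalg.norm(v) + eps
        u = mat @ v
        u /= np.linalg.norm(u) + eps
    sigma = u @ mat @ v  # estimate of the largest singular value
    return w / sigma

w = np.random.randn(2, 8, 32, 32).astype('float32')
w_sn = spectral_normalize(w, dim=1, power_iters=2)
print(w_sn.shape)  # (2, 8, 32, 32)
```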
class TreeConv(layers.Layer):
"""
This interface is used to construct a callable object of the ``TreeConv`` class.
For more details, refer to code examples.
Tree-Based Convolution is a kind of convolution based on tree structure.
Tree-Based Convolution is a part of the Tree-Based Convolution Neural Network (TBCNN),
which is used to classify tree structures, such as Abstract Syntax Trees.
Tree-Based Convolution introduces a data structure called continuous binary tree,
which regards a multiway tree as a binary tree.
The paper on the Tree-Based Convolution operator is here: `tree-based convolution <https://arxiv.org/abs/1409.5718v1/>`_ .
Parameters:
feature_size(int): last dimension of nodes_vector.
output_size(int): output feature width.
num_filters(int, optional): number of filters, Default: 1.
max_depth(int, optional): max depth of filters, Default: 2.
act(str, optional): activation function, Default: tanh.
param_attr(ParamAttr, optional): the parameter attribute for the filters, Default: None.
bias_attr(ParamAttr, optional): the parameter attribute for the bias of this layer, Default: None.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (Parameter): the learnable weights of filters of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy
with fluid.dygraph.guard():
nodes_vector = numpy.random.random((1, 10, 5)).astype('float32')
edge_set = numpy.random.random((1, 9, 2)).astype('int32')
treeConv = fluid.dygraph.nn.TreeConv(
feature_size=5, output_size=6, num_filters=1, max_depth=2)
ret = treeConv(fluid.dygraph.base.to_variable(nodes_vector), fluid.dygraph.base.to_variable(edge_set))
"""
def __init__(
self,
feature_size,
output_size,
num_filters=1,
max_depth=2,
act='tanh',
param_attr=None,
bias_attr=None,
name=None,
dtype='float32',
):
super().__init__()
self._name = name
self._feature_size = feature_size
self._output_size = output_size
self._act = act
self._max_depth = max_depth
self._num_filters = num_filters
self._bias_attr = bias_attr
self._param_attr = param_attr
self._dtype = dtype
w_shape = [self._feature_size, 3, self._output_size, self._num_filters]
if self._bias_attr:
self.bias = self.create_parameter(
attr=self._bias_attr,
shape=[self._num_filters],
dtype=self._dtype,
is_bias=True,
)
self.weight = self.create_parameter(
attr=self._param_attr,
shape=w_shape,
dtype=self._dtype,
is_bias=False,
)
def forward(self, nodes_vector, edge_set):
check_type(nodes_vector, 'nodes_vector', (Variable), 'TreeConv')
check_type(edge_set, 'edge_set', (Variable), 'TreeConv')
if self._name:
out = self.create_variable(
name=self._name, dtype=self._dtype, persistable=False
)
else:
out = self._helper.create_variable_for_type_inference(
dtype=self._dtype
)
self._helper.append_op(
type='tree_conv',
inputs={
'NodesVector': nodes_vector,
'EdgeSet': edge_set,
'Filter': self.weight,
},
outputs={
'Out': out,
},
attrs={'max_depth': self._max_depth},
)
if self._bias_attr:
pre_activation = self._helper.create_variable_for_type_inference(
dtype=self._dtype
)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [out], 'Y': [self.bias]},
outputs={'Out': [pre_activation]},
attrs={'axis': 1},
)
else:
pre_activation = out
return self._helper.append_activation(pre_activation, act=self._act)
class Flatten(layers.Layer):
"""
This interface is used to construct a callable object of the ``Flatten`` class.
For more details, refer to code examples.
It flattens a contiguous range of dims of the input into a single dimension.
Parameters:
start_axis(int): first dim to flatten (default = 1)
stop_axis(int): last dim to flatten (default = -1).
Returns:
None
Examples:
.. code-block:: python
import paddle
import numpy as np
inp_np = np.ones([5, 2, 3, 4]).astype('float32')
inp_np = paddle.to_tensor(inp_np)
flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
flatten_res = flatten(inp_np)
"""
def __init__(self, start_axis=1, stop_axis=-1):
super().__init__()
self.start_axis = start_axis
self.stop_axis = stop_axis
def forward(self, input):
out = paddle.tensor.manipulation.flatten(
input, start_axis=self.start_axis, stop_axis=self.stop_axis
)
return out
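``Flatten`` is moved rather than removed; as in the test changes below, existing callers can switch to ``paddle.nn.Flatten`` directly. A minimal sketch using the shapes from the docstring example:

```python
import paddle

x = paddle.ones([5, 2, 3, 4], dtype='float32')
flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
y = flatten(x)
print(y.shape)  # [5, 6, 4]
```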
......@@ -1351,9 +1351,9 @@ class GeneratorLoader(DataLoaderBase):
self._use_double_buffer = use_double_buffer
self._capacity = capacity
if not self._iterable:
# Because layers.io.double_buffer is not supported anymore, and it is only used when iterable
# and use_double_buffer are both True, here if iterable is False, use_double_buffer will be
# forcibly set to False to avoid using layers.io.double_buffer.
# Because layers.io.double_buffer is not supported anymore, and iterable=False combined with
# use_double_buffer=True is not supported either, here if iterable is False, use_double_buffer will be
# forcibly set to False to avoid unexpected errors.
# TODO: keep use_double_buffer
self._use_double_buffer = False
self._init_non_iterable()
......
......@@ -293,21 +293,25 @@ class TestGroupNormException(unittest.TestCase):
class TestGroupNormEager(unittest.TestCase):
def test_dygraph_api(self):
self.dtype = np.float64
# float64 is not supported
# only float32 is supported
self.dtype = np.float32
self.shape = (8, 32, 32)
input = np.random.random(self.shape).astype(self.dtype)
with fluid.dygraph.guard():
tensor_1 = fluid.dygraph.to_variable(input)
tensor_1.stop_gradient = False
groupNorm = fluid.dygraph.nn.GroupNorm(channels=32, groups=4)
groupNorm = paddle.nn.GroupNorm(num_channels=32, num_groups=4)
ret1 = groupNorm(tensor_1)
ret1.backward()
with _test_eager_guard():
tensor_eager_1 = fluid.dygraph.to_variable(input)
tensor_eager_1.stop_gradient = False
groupNorm_eager = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4
groupNorm_eager = paddle.nn.GroupNorm(
num_channels=32, num_groups=4
)
ret2 = groupNorm_eager(tensor_eager_1)
ret2.backward()
......@@ -328,16 +332,14 @@ class TestGroupNormEager_fp32(unittest.TestCase):
with fluid.dygraph.guard():
tensor_1 = fluid.dygraph.to_variable(input)
tensor_1.stop_gradient = False
groupNorm = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4, dtype='float32'
)
groupNorm = paddle.nn.GroupNorm(num_channels=32, num_groups=4)
ret1 = groupNorm(tensor_1)
ret1.backward()
with _test_eager_guard():
tensor_eager_1 = fluid.dygraph.to_variable(input)
tensor_eager_1.stop_gradient = False
groupNorm_eager = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4
groupNorm_eager = paddle.nn.GroupNorm(
num_channels=32, num_groups=4
)
ret2 = groupNorm_eager(tensor_eager_1)
ret2.backward()
......@@ -351,23 +353,25 @@ class TestGroupNormEager_fp32(unittest.TestCase):
class TestGroupNormEager_fp16(unittest.TestCase):
def test_dygraph_api(self):
# float16 is not supported
# only float32 is supported
self.dtype = np.float32
self.shape = (8, 32, 32)
input = np.random.random(self.shape).astype(self.dtype)
with fluid.dygraph.guard():
tensor_1 = fluid.dygraph.to_variable(input)
tensor_1.stop_gradient = False
groupNorm = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4, dtype='float16'
)
groupNorm = paddle.nn.GroupNorm(num_channels=32, num_groups=4)
ret1 = groupNorm(tensor_1)
ret1.backward()
with _test_eager_guard():
tensor_eager_1 = fluid.dygraph.to_variable(input)
tensor_eager_1.stop_gradient = False
groupNorm_eager = fluid.dygraph.nn.GroupNorm(
channels=32, groups=4
groupNorm_eager = paddle.nn.GroupNorm(
num_channels=32, num_groups=4
)
ret2 = groupNorm_eager(tensor_eager_1)
ret2.backward()
......
......@@ -19,7 +19,6 @@ import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
......@@ -39,106 +38,6 @@ def group_norm_naive_for_general_dimension(x, scale, bias, epsilon, groups):
return output
class TestDygraphGroupNormv2(unittest.TestCase):
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
shapes = [
[2, 2, 2, 2],
[2, 2, 4],
[4, 2],
[4, 2, 6, 6, 2],
[2, 2, 2, 2, 2, 2],
]
for p in places:
def compute_v1(x):
with fluid.dygraph.guard(p):
gn = fluid.dygraph.GroupNorm(channels=2, groups=2)
y = gn(fluid.dygraph.to_variable(x))
return y.numpy()
def compute_v2(x):
with fluid.dygraph.guard(p):
gn = paddle.nn.GroupNorm(num_channels=2, num_groups=2)
y = gn(fluid.dygraph.to_variable(x))
return y.numpy()
def test_weight_bias_false():
with fluid.dygraph.guard(p):
gn = paddle.nn.GroupNorm(
num_channels=2,
num_groups=2,
weight_attr=False,
bias_attr=False,
)
def test_nn_exception():
with fluid.dygraph.guard(p):
def attr_data_format():
out = paddle.nn.GroupNorm(
num_groups=2, num_channels=2, data_format="CNHW"
)
self.assertRaises(ValueError, attr_data_format)
for shape in shapes:
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
result = np.allclose(y1, y2, atol=1e-5)
if not result:
print("y1:", y1, "\ty2:", y2)
self.assertTrue(result)
test_weight_bias_false()
test_nn_exception()
def test_static(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
shapes = [
[2, 6, 2, 2],
[2, 6, 4],
[4, 6],
[4, 6, 6, 6, 2],
[4, 6, 2, 2, 2, 2],
]
for p in places:
exe = fluid.Executor(p)
def compute_v1(x_np):
with program_guard(Program(), Program()):
gn = fluid.dygraph.GroupNorm(channels=6, groups=2)
x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
y = gn(x)
exe.run(fluid.default_startup_program())
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
def compute_v2(x_np):
with program_guard(Program(), Program()):
gn = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
y = gn(x)
exe.run(fluid.default_startup_program())
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
for shape in shapes:
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
np.testing.assert_allclose(y1, y2, rtol=1e-05, atol=1e-05)
def test_eager_api(self):
with _test_eager_guard():
self.test_dygraph()
class TestGroupNormAPIV2_With_General_Dimensions(unittest.TestCase):
def test_numerical_accuracy(self):
paddle.disable_static()
......
......@@ -21,7 +21,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.dygraph.nn import BatchNorm, Embedding, GroupNorm
from paddle.fluid.dygraph.nn import BatchNorm, Embedding
from paddle.nn import Linear
......@@ -122,10 +122,10 @@ class TestDygraphLoadStatic(unittest.TestCase):
name='groupnorm_in', shape=[None, 8, 32, 32], dtype='float32'
)
groupnorm_out1 = paddle.static.nn.group_norm(
input=groupnorm_in, groups=4
input=groupnorm_in, groups=4, param_attr=True, bias_attr=True
)
groupnorm_out2 = paddle.static.nn.group_norm(
input=groupnorm_in, groups=4
input=groupnorm_in, groups=4, param_attr=True, bias_attr=True
)
'''
spec_norm = fluid.data(name='spec_norm', shape=[2, 8, 32, 32], dtype='float32')
......@@ -212,8 +212,8 @@ class TestDygraphLoadStatic(unittest.TestCase):
self.layer_norm_1 = paddle.nn.LayerNorm([10])
self.layer_norm_2 = paddle.nn.LayerNorm(10)
self.group_norm1 = GroupNorm(8, 4)
self.gourp_norm2 = GroupNorm(8, 4)
self.group_norm1 = paddle.nn.GroupNorm(4, 8)
self.gourp_norm2 = paddle.nn.GroupNorm(4, 8)
self.w_1 = self.create_parameter(
[100, 100], dtype='float32', attr="weight_test_1"
......
......@@ -191,7 +191,7 @@ class TestLayer(LayerTest):
dtype='float32',
append_batch_size=False,
)
flatten = nn.Flatten()
flatten = paddle.nn.Flatten()
ret = flatten(t)
static_ret = self.get_static_graph_result(
feed={'data': inp}, fetch_list=[ret]
......@@ -199,12 +199,12 @@ class TestLayer(LayerTest):
with self.dynamic_graph():
with _test_eager_guard():
t = base.to_variable(inp)
flatten = nn.Flatten()
flatten = paddle.nn.Flatten()
dy_eager_ret = flatten(t)
dy_eager_ret_value = dy_eager_ret.numpy()
t = base.to_variable(inp)
flatten = nn.Flatten()
flatten = paddle.nn.Flatten()
dy_ret = flatten(t)
dy_ret_value = dy_ret.numpy()
......@@ -1066,10 +1066,10 @@ class TestLayer(LayerTest):
lod_level=1,
append_batch_size=False,
)
groupNorm = nn.GroupNorm(
channels=shape[1],
groups=2,
param_attr=fluid.initializer.Uniform(low=-0.5, high=0.5),
groupNorm = paddle.nn.GroupNorm(
num_channels=shape[1],
num_groups=2,
weight_attr=fluid.initializer.Uniform(low=-0.5, high=0.5),
bias_attr=fluid.initializer.ConstantInitializer(value=1),
)
ret = groupNorm(X)
......@@ -1084,10 +1084,10 @@ class TestLayer(LayerTest):
)[0]
with self.dynamic_graph():
groupNorm = nn.GroupNorm(
channels=shape[1],
groups=2,
param_attr=fluid.initializer.Uniform(low=-0.5, high=0.5),
groupNorm = paddle.nn.GroupNorm(
num_channels=shape[1],
num_groups=2,
weight_attr=fluid.initializer.Uniform(low=-0.5, high=0.5),
bias_attr=fluid.initializer.ConstantInitializer(value=1),
)
dy_ret = groupNorm(base.to_variable(input))
......@@ -1209,7 +1209,7 @@ class TestLayer(LayerTest):
lod_level=1,
append_batch_size=False,
)
spectralNorm = nn.SpectralNorm(shape, dim=1, power_iters=2)
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
ret = spectralNorm(Weight)
static_ret2 = self.get_static_graph_result(
feed={
......@@ -1223,11 +1223,13 @@ class TestLayer(LayerTest):
with self.dynamic_graph():
with _test_eager_guard():
spectralNorm = nn.SpectralNorm(shape, dim=1, power_iters=2)
spectralNorm = paddle.nn.SpectralNorm(
shape, axis=1, power_iters=2
)
dy_eager_ret = spectralNorm(base.to_variable(input))
dy_eager_rlt_value = dy_eager_ret.numpy()
spectralNorm = nn.SpectralNorm(shape, dim=1, power_iters=2)
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
dy_ret = spectralNorm(base.to_variable(input))
dy_rlt_value = dy_ret.numpy()
......@@ -1235,200 +1237,6 @@ class TestLayer(LayerTest):
np.testing.assert_allclose(static_ret, dy_eager_rlt_value, rtol=1e-05)
np.testing.assert_allclose(static_ret, static_ret2, rtol=1e-05)
def test_tree_conv(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
adj_array = [1, 2, 1, 3, 1, 4, 1, 5, 2, 6, 2, 7, 2, 8, 4, 9, 4, 10]
adj = np.array(adj_array).reshape((1, 9, 2)).astype('int32')
adj = np.tile(adj, (1, 1, 1))
vectors = np.random.random((1, 10, 5)).astype('float32')
with self.static_graph():
NodesVector = fluid.layers.data(
name='NodesVector',
shape=(1, 10, 5),
dtype='float32',
lod_level=1,
append_batch_size=False,
)
EdgeSet = fluid.layers.data(
name='EdgeSet',
shape=(1, 9, 2),
dtype='int32',
lod_level=1,
append_batch_size=False,
)
ret = fluid.contrib.layers.tree_conv(
nodes_vector=NodesVector,
edge_set=EdgeSet,
output_size=6,
num_filters=1,
max_depth=2,
)
static_ret = self.get_static_graph_result(
feed={
'NodesVector': fluid.create_lod_tensor(
data=vectors, recursive_seq_lens=[[1]], place=place
),
'EdgeSet': fluid.create_lod_tensor(
data=adj, recursive_seq_lens=[[1]], place=place
),
},
fetch_list=[ret],
with_lod=False,
)[0]
with self.static_graph():
NodesVector = fluid.layers.data(
name='NodesVector',
shape=(1, 10, 5),
dtype='float32',
lod_level=1,
append_batch_size=False,
)
EdgeSet = fluid.layers.data(
name='EdgeSet',
shape=(1, 9, 2),
dtype='int32',
lod_level=1,
append_batch_size=False,
)
treeConv = nn.TreeConv(
feature_size=5, output_size=6, num_filters=1, max_depth=2
)
ret = treeConv(NodesVector, EdgeSet)
static_ret2 = self.get_static_graph_result(
feed={
'NodesVector': fluid.create_lod_tensor(
data=vectors, recursive_seq_lens=[[1]], place=place
),
'EdgeSet': fluid.create_lod_tensor(
data=adj, recursive_seq_lens=[[1]], place=place
),
},
fetch_list=[ret],
with_lod=False,
)[0]
with self.dynamic_graph():
with _test_eager_guard():
treeConv = nn.TreeConv(
feature_size=5, output_size=6, num_filters=1, max_depth=2
)
dy_eager_ret = treeConv(
base.to_variable(vectors), base.to_variable(adj)
)
dy_eager_rlt_value = dy_eager_ret.numpy()
treeConv = nn.TreeConv(
feature_size=5, output_size=6, num_filters=1, max_depth=2
)
dy_ret = treeConv(base.to_variable(vectors), base.to_variable(adj))
dy_rlt_value = dy_ret.numpy()
np.testing.assert_allclose(static_ret, static_ret2, rtol=1e-05)
np.testing.assert_allclose(static_ret, dy_rlt_value, rtol=1e-05)
np.testing.assert_allclose(static_ret, dy_eager_rlt_value, rtol=1e-05)
with self.dynamic_graph():
with _test_eager_guard():
custom_weight = np.random.randn(5, 3, 6, 1).astype("float32")
weight_attr = fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
custom_weight
)
)
treeConv1 = nn.TreeConv(
feature_size=5,
output_size=6,
num_filters=1,
max_depth=2,
bias_attr='eager_tc1_b',
)
treeConv2 = nn.TreeConv(
feature_size=5,
output_size=6,
num_filters=1,
max_depth=2,
param_attr=weight_attr,
bias_attr='eager_tc2_b',
)
dy_ret1 = treeConv1(
base.to_variable(vectors), base.to_variable(adj)
)
dy_ret2 = treeConv2(
base.to_variable(vectors), base.to_variable(adj)
)
self.assertFalse(
np.array_equal(dy_ret1.numpy(), dy_ret2.numpy())
)
treeConv2.weight.set_value(treeConv1.weight.numpy())
treeConv2.bias.set_value(treeConv1.bias)
dy_ret1 = treeConv1(
base.to_variable(vectors), base.to_variable(adj)
)
dy_ret2 = treeConv2(
base.to_variable(vectors), base.to_variable(adj)
)
np.testing.assert_array_equal(dy_ret1.numpy(), dy_ret2.numpy())
treeConv2.weight = treeConv1.weight
treeConv2.bias = treeConv1.bias
np.testing.assert_array_equal(
treeConv1.weight.numpy(), treeConv2.weight.numpy()
)
np.testing.assert_array_equal(
treeConv1.bias.numpy(), treeConv2.bias.numpy()
)
custom_weight = np.random.randn(5, 3, 6, 1).astype("float32")
weight_attr = fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
custom_weight
)
)
treeConv1 = nn.TreeConv(
feature_size=5,
output_size=6,
num_filters=1,
max_depth=2,
bias_attr='tc1_b',
)
treeConv2 = nn.TreeConv(
feature_size=5,
output_size=6,
num_filters=1,
max_depth=2,
param_attr=weight_attr,
bias_attr='tc2_b',
)
dy_ret1 = treeConv1(
base.to_variable(vectors), base.to_variable(adj)
)
dy_ret2 = treeConv2(
base.to_variable(vectors), base.to_variable(adj)
)
self.assertFalse(np.array_equal(dy_ret1.numpy(), dy_ret2.numpy()))
treeConv2.weight.set_value(treeConv1.weight.numpy())
treeConv2.bias.set_value(treeConv1.bias)
dy_ret1 = treeConv1(
base.to_variable(vectors), base.to_variable(adj)
)
dy_ret2 = treeConv2(
base.to_variable(vectors), base.to_variable(adj)
)
np.testing.assert_array_equal(dy_ret1.numpy(), dy_ret2.numpy())
treeConv2.weight = treeConv1.weight
treeConv2.bias = treeConv1.bias
np.testing.assert_array_equal(
treeConv1.weight.numpy(), treeConv2.weight.numpy()
)
np.testing.assert_array_equal(
treeConv1.bias.numpy(), treeConv2.bias.numpy()
)
def test_conv3d_transpose(self):
input_array = (
np.arange(0, 48).reshape([2, 3, 2, 2, 2]).astype('float32')
......
......@@ -17,6 +17,7 @@ import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
......@@ -152,9 +153,7 @@ class TestDygraphSpectralNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
shape = (2, 4, 3, 3)
spectralNorm = fluid.dygraph.nn.SpectralNorm(
shape, dim=1, power_iters=2
)
spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32")
......
......@@ -196,30 +196,5 @@ class TestTreeConv_OpError(unittest.TestCase):
)
class TestDygraphTreeConv_OpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
TreeConv = fluid.dygraph.nn.TreeConv(
feature_size=5, output_size=6, num_filters=1, max_depth=2
)
nodes_vector_1 = np.random.random((10, 5)).astype("float32")
edge_set_1 = fluid.layers.data(
name='edge_set_1', shape=[10, 2], dtype='float32'
)
# the nodes_vector of TreeConv must be Variable.
self.assertRaises(
TypeError, TreeConv, nodes_vector_1, edge_set_1, 3
)
nodes_vector_2 = fluid.layers.data(
name='vectors2', shape=[10, 5], dtype='float32'
)
edge_set_2 = np.random.random((10, 2)).astype("float32")
# the edge_set of TreeConv must be Variable.
self.assertRaises(
TypeError, TreeConv, nodes_vector_2, edge_set_2, 3
)
if __name__ == "__main__":
unittest.main()
......@@ -17,7 +17,6 @@ import paddle
from paddle import in_dynamic_mode
from paddle.nn import Layer
from ...fluid.dygraph import Flatten # noqa: F401
from .. import functional as F
__all__ = []
......@@ -1705,3 +1704,41 @@ class Fold(Layer):
self.strides,
name_str,
)
class Flatten(Layer):
"""
This interface is used to construct a callable object of the ``Flatten`` class.
For more details, refer to code examples.
It flattens a contiguous range of dims of the input into a single dimension.
Parameters:
start_axis(int): first dim to flatten (default = 1)
stop_axis(int): last dim to flatten (default = -1).
Returns:
None
Examples:
.. code-block:: python
import paddle
inp = paddle.ones([5, 2, 3, 4]).astype('float32')
flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
y = flatten(inp)
# y.shape = [5, 6, 4]
"""
def __init__(self, start_axis=1, stop_axis=-1):
super().__init__()
self.start_axis = start_axis
self.stop_axis = stop_axis
def forward(self, x):
out = paddle.flatten(
x, start_axis=self.start_axis, stop_axis=self.stop_axis
)
return out
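Since ``forward`` simply delegates to ``paddle.flatten``, the layer and the functional form should agree. A short sanity check (illustrative only):

```python
import paddle

x = paddle.rand([5, 2, 3, 4])
layer_out = paddle.nn.Flatten(start_axis=1, stop_axis=2)(x)
func_out = paddle.flatten(x, start_axis=1, stop_axis=2)
print(bool(paddle.allclose(layer_out, func_out)))  # True
```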
......@@ -39,12 +39,11 @@ from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid.dygraph import BatchNorm # noqa: F401
from ...fluid.dygraph import SpectralNorm # noqa: F401
from ...framework import ParamAttr, get_default_dtype, no_grad
from .. import Layer
from .. import functional as F
from ..functional import batch_norm, instance_norm, layer_norm
from ..initializer import Constant
from ..initializer import Constant, Normal
__all__ = []
......@@ -388,8 +387,8 @@ class GroupNorm(Layer):
shape=param_shape,
default_initializer=Constant(1.0),
)
self.weight.stop_gradient = (
self._weight_attr is not None
self.weight.stop_gradient = self._weight_attr is not None and (
hasattr(self._weight_attr, "learning_rate")
and self._weight_attr.learning_rate == 0.0
)
......@@ -405,8 +404,8 @@ class GroupNorm(Layer):
self.bias = self.create_parameter(
attr=self._bias_attr, shape=param_shape, is_bias=True
)
self.bias.stop_gradient = (
self._bias_attr is not None
self.bias.stop_gradient = self._bias_attr is not None and (
hasattr(self._bias_attr, "learning_rate")
and self._bias_attr.learning_rate == 0.0
)
......@@ -1431,3 +1430,137 @@ class LocalResponseNorm(Layer):
if self.name is not None:
main_str += ', name={}'.format(self.name)
return main_str
class SpectralNorm(Layer):
r"""
This interface is used to construct a callable object of the ``SpectralNorm`` class.
For more details, refer to code examples. It implements the function of the Spectral Normalization Layer.
This layer calculates the spectral normalization value of the weight parameters of
fc, conv1d, conv2d, and conv3d layers, which should be 2-D, 3-D, 4-D, or 5-D
Parameters. Calculations are shown as follows.
Step 1:
Generate vector U in shape of [H], and V in shape of [W],
where H is the size of the :attr:`axis`-th dimension of the input weights,
and W is the product of the remaining dimensions.
Step 2:
:attr:`power_iters` should be a positive integer; do the following
calculations with U and V for :attr:`power_iters` rounds.
.. math::
\mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
\mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}
Step 3:
Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.
.. math::
\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
\mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Parameters:
weight_shape(list or tuple): The shape of weight parameter.
axis(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
epsilon(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Returns:
None
Examples:
.. code-block:: python
import paddle
x = paddle.rand((2,8,32,32))
spectral_norm = paddle.nn.SpectralNorm(x.shape, axis=1, power_iters=2)
spectral_norm_out = spectral_norm(x)
print(spectral_norm_out.shape) # [2, 8, 32, 32]
"""
def __init__(
self,
weight_shape,
axis=0,
power_iters=1,
epsilon=1e-12,
dtype='float32',
):
super().__init__()
self._power_iters = power_iters
self._epsilon = epsilon
self._dim = axis
self._dtype = dtype
self._weight_shape = list(weight_shape)
assert (
np.prod(self._weight_shape) > 0
), "Any dimension of `weight_shape` cannot be equal to 0."
assert axis < len(self._weight_shape), (
"The input `axis` should be less than the "
"length of `weight_shape`, but received axis="
"{}".format(axis)
)
h = self._weight_shape[self._dim]
w = np.prod(self._weight_shape) // h
self.weight_u = self.create_parameter(
attr=ParamAttr(),
shape=[h],
dtype=self._dtype,
default_initializer=Normal(0.0, 1.0),
)
self.weight_u.stop_gradient = True
self.weight_v = self.create_parameter(
attr=ParamAttr(),
shape=[w],
dtype=self._dtype,
default_initializer=Normal(0.0, 1.0),
)
self.weight_v.stop_gradient = True
def forward(self, x):
weight = x
if in_dygraph_mode():
return _C_ops.spectral_norm(
weight,
self.weight_u,
self.weight_v,
self._dim,
self._power_iters,
self._epsilon,
)
check_variable_and_dtype(
weight, "weight", ['float32', 'float64'], 'SpectralNorm'
)
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="spectral_norm",
inputs=inputs,
outputs={
"Out": out,
},
attrs={
"dim": self._dim,
"power_iters": self._power_iters,
"eps": self._epsilon,
},
)
return out
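``SpectralNorm`` is moved into ``paddle.nn`` by this commit, and the ``dim``/``eps`` arguments are renamed to ``axis``/``epsilon``, as the test updates show. A minimal usage sketch mirroring the docstring example:

```python
import paddle

weight = paddle.rand((2, 8, 32, 32))
# old: fluid.dygraph.SpectralNorm(weight.shape, dim=1, power_iters=2)
spectral_norm = paddle.nn.SpectralNorm(weight.shape, axis=1, power_iters=2)
out = spectral_norm(weight)
print(out.shape)  # [2, 8, 32, 32]
```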