Unverified commit dc694f1e, authored by Ghost Screaming and committed by GitHub

Clean Fluid APIs in paddle.fluid.layers.nn (#48908)

* Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, its result
was wrong.

* Remove climits.

* Clean Fluid APIs in python/paddle/fluid/layers/nn.py,
migrate the spectral_norm and row_conv APIs, and remove one_hot.
Including the following files:
1. python/paddle/fluid/layers/nn.py
2. python/paddle/fluid/tests/unittests/collective/fleet/parallel_dygraph_transformer.py
3. python/paddle/fluid/tests/unittests/dist_transformer.py
4. python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py
5. python/paddle/fluid/tests/unittests/ipu/test_one_hot_op_ipu.py
6. python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py
7. python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py
8. python/paddle/fluid/tests/unittests/test_layers.py
9. python/paddle/fluid/tests/unittests/test_one_hot_op.py
10. python/paddle/fluid/tests/unittests/test_row_conv_op.py
11. python/paddle/fluid/tests/unittests/test_runtime_and_compiletime_exception.py
12. python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
13. python/paddle/static/nn/__init__.py
14. python/paddle/static/nn/common.py

* Polish code.

* Fix some bugs.

* Remove useless unittest.

* Fix some bugs.

* Polish example code.

* Fix some bugs.

* Fix some bugs.
Parent aa0098f6
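For reference, the migration pattern applied throughout the diff below maps each removed fluid.layers entry point onto its new home. A minimal static-graph sketch with illustrative shapes, not itself part of the commit:

import paddle

paddle.enable_static()

# fluid.layers.one_hot(input=label, depth=4) becomes paddle.nn.functional.one_hot
label = paddle.static.data(name='label', shape=[4, 1], dtype='int64')
one_hot_label = paddle.nn.functional.one_hot(label, 4)

# fluid.layers.spectral_norm moves to paddle.static.nn.spectral_norm
weight = paddle.static.data(name='weight', shape=[2, 8, 32, 32], dtype='float32')
normed = paddle.static.nn.spectral_norm(weight=weight, dim=1, power_iters=2)

# fluid.layers.row_conv moves to paddle.static.nn.row_conv
x = paddle.static.data(name='x', shape=[9, 4, 16], dtype='float32')
conv = paddle.static.nn.row_conv(input=x, future_context_size=2)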
@@ -65,9 +65,6 @@ from collections.abc import Iterable
__all__ = [
'fc',
'embedding',
-'row_conv',
-'spectral_norm',
-'one_hot',
'autoincreased_step_counter',
'clip',
'clip_by_norm',
@@ -741,130 +738,6 @@ def _pull_box_sparse(
return outs
@templatedoc()
def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
r"""
:api_attr: Static Graph
**Spectral Normalization Layer**
This operation calculates the spectral normalization value of the weight parameter
of fc, conv1d, conv2d, or conv3d layers, which should be a 2-D, 3-D, 4-D, or 5-D
parameter. The output tensor has the same shape as the input tensor.
The calculation proceeds as follows.
Step 1:
Generate vectors U of shape [H] and V of shape [W],
where H is the size of the :attr:`dim`-th dimension of the input weight,
and W is the product of the remaining dimensions.
Step 2:
:attr:`power_iters` should be a positive integer; repeat the following
calculation with U and V for :attr:`power_iters` rounds:
.. math::
\mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
\mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}
Step 3:
Calculate :math:`\sigma(\mathbf{W})` and normalize the weight values.
.. math::
\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
\mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Args:
weight(Tensor): ${weight_comment}
dim(int): ${dim_comment}
power_iters(int): ${power_iters_comment}
eps(float): ${eps_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually the name does not need to be set and
is None by default.
Returns:
Tensor: A tensor of weight parameters after spectral normalization.
The data type and shape are the same as the input tensor.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
weight = paddle.static.data(name='weight', shape=[2, 8, 32, 32], dtype='float32')
x = paddle.static.nn.spectral_norm(weight=weight, dim=1, power_iters=2)
print(x.shape) # [2, 8, 32, 32]
"""
helper = LayerHelper('spectral_norm', **locals())
check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'], 'spectral_norm'
)
check_type(dim, 'dim', int, 'spectral_norm')
check_type(power_iters, 'power_iters', int, 'spectral_norm')
check_type(eps, 'eps', float, 'spectral_norm')
dtype = weight.dtype
# create input and parameters
input_shape = weight.shape
assert weight.numel() > 0, "All dimensions of the input must be greater than 0."
assert dim < len(input_shape), (
"The input `dim` should be less than the "
"rank of `weight`, but received dim="
"{}".format(dim)
)
h = input_shape[dim]
w = np.prod(input_shape) // h
u = helper.create_parameter(
attr=ParamAttr(),
shape=[h],
dtype=dtype,
default_initializer=Normal(0.0, 1.0),
)
u.stop_gradient = True
v = helper.create_parameter(
attr=ParamAttr(),
shape=[w],
dtype=dtype,
default_initializer=Normal(0.0, 1.0),
)
v.stop_gradient = True
if in_dygraph_mode():
return _C_ops.spectral_norm(weight, u, v, dim, power_iters, eps)
inputs = {'Weight': weight}
inputs['U'] = u
inputs['V'] = v
# create output
out = helper.create_variable(dtype=dtype)
helper.append_op(
type="spectral_norm",
inputs=inputs,
outputs={
"Out": out,
},
attrs={
"dim": dim,
"power_iters": power_iters,
"eps": eps,
},
)
return out
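As a cross-check of the three documented steps, the same computation can be written as a short NumPy reference. This is a sketch only: the function name is illustrative, and it samples fresh U and V instead of the persistent parameters the layer creates.

import numpy as np

def spectral_norm_ref(weight, dim=0, power_iters=1, eps=1e-12):
    # Step 1: flatten the weight into a 2-D matrix [H, W] with H = shape[dim]
    w_mat = np.moveaxis(weight, dim, 0).reshape(weight.shape[dim], -1)
    h, w = w_mat.shape
    u = np.random.normal(size=h)
    v = np.random.normal(size=w)
    # Step 2: power iteration for power_iters rounds
    for _ in range(power_iters):
        v = w_mat.T.dot(u)
        v /= np.linalg.norm(v) + eps
        u = w_mat.dot(v)
        u /= np.linalg.norm(u) + eps
    # Step 3: sigma(W) = u^T W v, then normalize the weight by it
    sigma = u.dot(w_mat.dot(v))
    return weight / sigma

w = np.random.randn(2, 8, 32, 32).astype('float32')
print(spectral_norm_ref(w, dim=1, power_iters=2).shape)  # (2, 8, 32, 32)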
def reduce_sum(input, dim=None, keep_dim=False, name=None):
"""
@@ -943,171 +816,6 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
return out
@templatedoc()
def row_conv(input, future_context_size, param_attr=None, act=None):
"""
:api_attr: Static Graph
${comment}
Args:
input (${x_type}): ${x_comment}.
future_context_size (int): Future context size. Note that the shape
of the convolution kernel is [future_context_size + 1, D].
param_attr (ParamAttr): Attributes of the parameters, including
name, initializer, etc.
act (str): Non-linear activation to be applied to the output variable.
Returns:
${out_comment}.
Examples:
.. code-block:: python
# for LoDTensor inputs
import paddle
paddle.enable_static()
x = paddle.static.data(name='x', shape=[9, 16],
dtype='float32', lod_level=1)
out_x = paddle.static.nn.row_conv(input=x, future_context_size=2)
# for Tensor inputs
y = paddle.static.data(name='y', shape=[9, 4, 16], dtype='float32')
out_y = paddle.static.nn.row_conv(input=y, future_context_size=2)
"""
helper = LayerHelper('row_conv', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'row_conv')
dtype = helper.input_dtype()
filter_shape = [future_context_size + 1, input.shape[-1]]
filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype
)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='row_conv',
inputs={'X': [input], 'Filter': [filter_param]},
outputs={'Out': [out]},
)
return helper.append_activation(out)
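The underlying op is the lookahead ("row") convolution used in DeepSpeech2-style models: each output row is a per-channel weighted sum of the current row and the following future_context_size rows. A minimal NumPy sketch for a dense [T, D] input (the function name is illustrative):

import numpy as np

def row_conv_ref(x, wt):
    # x: [T, D] time-major features; wt: [future_context_size + 1, D] kernel
    T = x.shape[0]
    k = wt.shape[0]
    out = np.zeros_like(x)
    for t in range(T):
        for i in range(k):
            if t + i < T:
                # each feature channel is weighted independently
                out[t] += x[t + i] * wt[i]
    return out

x = np.random.randn(9, 16).astype('float32')
wt = np.random.randn(3, 16).astype('float32')  # future_context_size = 2
print(row_conv_ref(x, wt).shape)  # (9, 16)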
@deprecated(since='2.0.0', update_to='paddle.nn.functional.one_hot')
def one_hot(input, depth, allow_out_of_range=False):
"""
**WARNING:** This OP requires that the last dimension of the input Tensor be 1.
It will be deprecated in a future release; it is recommended to use :ref:`api_fluid_one_hot` instead.
The operator converts each id in the input to a one-hot vector of length
:attr:`depth`. The value in the vector dimension corresponding to the id
is 1, and the values in the remaining dimensions are 0.
The output Tensor or LoDTensor has the same shape as the input, except that the
last dimension (which must be 1) is replaced by a dimension of size :attr:`depth`.
.. code-block:: text
Example 1 (allow_out_of_range=False):
input:
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
depth = 4
output:
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]]
Example 2 (allow_out_of_range=True):
input:
X.shape = [4, 1]
X.data = [[1], [1], [5], [0]]
depth = 4
allow_out_of_range = True
output:
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 0.], # This id is 5, which exceeds depth, so its row is all zeros.
[1., 0., 0., 0.]]
Example 3 (allow_out_of_range=False):
input:
X.shape = [4, 1]
X.data = [[1], [1], [5], [0]]
depth = 4
allow_out_of_range = False
output: An exception is thrown for the illegal value.
X contains the id 5, which is outside the range [0, depth).
allow_out_of_range=False means ids must not exceed depth - 1,
so an exception is thrown.
Args:
input(Variable): Tensor or LoDTensor with shape :math:`[N_1, N_2, ..., N_k, 1]` ,
which contains at least one dimension and the last dimension must be 1.
The data type is int32 or int64.
depth(scalar): An integer defining the :attr:`depth` of the one hot dimension. If input
is word id, depth is generally the dictionary size.
allow_out_of_range(bool): A bool value indicating whether the input
indices may be out of the range :math:`[0, depth)` . When input indices are
out of range, an :code:`Illegal value` exception is raised if :attr:`allow_out_of_range`
is False, or a zero-filled representation is created if it is set to True.
Default: False.
Returns:
Variable: The one-hot representations of input. A Tensor or LoDTensor with type float32.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
paddle.enable_static()
# Correspond to the first example above, where label.shape is [4, 1] and one_hot_label.shape is [4, 4].
label = fluid.data(name="label", shape=[4, 1], dtype="int64")
one_hot_label = fluid.layers.one_hot(input=label, depth=4)
"""
if _non_static_mode():
if isinstance(depth, Variable):
depth = depth.numpy()
assert depth.shape == (
1,
), "depth of type Variable should have shape [1]"
depth = depth.item(0)
out = _legacy_C_ops.one_hot(
input, 'depth', depth, 'allow_out_of_range', allow_out_of_range
)
out.stop_gradient = True
return out
helper = LayerHelper("one_hot", **locals())
check_variable_and_dtype(input, 'input', ['int32', 'int64'], 'one_hot')
check_type(depth, 'depth', (int, Variable), 'one_hot')
one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
if not isinstance(depth, Variable):
# user attribute
inputs = {'X': input}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op(
type="one_hot", inputs=inputs, attrs=attrs, outputs={'Out': one_hot_out}
)
one_hot_out.stop_gradient = True
return one_hot_out
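The semantics spelled out above (the trailing size-1 dimension is replaced by depth, with allow_out_of_range governing out-of-range ids) fit in a few lines of NumPy; a sketch under those assumptions, with an illustrative name:

import numpy as np

def one_hot_ref(x, depth, allow_out_of_range=False):
    # x: integer ids of shape [..., 1]; the trailing 1 becomes depth
    assert x.shape[-1] == 1
    out = np.zeros(x.shape[:-1] + (depth,), dtype='float32')
    for idx, v in np.ndenumerate(x):
        if 0 <= v < depth:
            out[idx[:-1] + (int(v),)] = 1.0
        elif not allow_out_of_range:
            raise ValueError('Illegal value: id %d out of [0, %d)' % (v, depth))
        # out-of-range ids are left as all-zero rows when allowed
    return out

print(one_hot_ref(np.array([[1], [1], [3], [0]]), 4))  # matches Example 1 above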
def autoincreased_step_counter(counter_name=None, begin=1, step=1):
"""
:api_attr: Static Graph
@@ -932,8 +932,8 @@ class TransFormer(Layer):
predict = self._wrap_decoder_layer(dec_inputs, enc_output)
if self._label_smooth_eps:
label_out = F.label_smooth(
-label=fluid.layers.one_hot(
-input=label, depth=self._trg_vocab_size
+label=paddle.squeeze(
+paddle.nn.functional.one_hot(label, self._trg_vocab_size)
),
epsilon=self._label_smooth_eps,
)
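A note on the paddle.squeeze wrapped around the new call: the removed fluid.layers.one_hot consumed labels of shape [..., 1] and produced [..., depth], whereas paddle.nn.functional.one_hot appends a new depth dimension and yields [..., 1, depth], so the squeeze drops the leftover size-1 axis and keeps downstream shapes unchanged. A quick dygraph shape check (the vocab size is illustrative):

import paddle

label = paddle.randint(0, 100, shape=[4, 1])
oh = paddle.nn.functional.one_hot(label, 100)  # shape [4, 1, 100]
print(paddle.squeeze(oh).shape)                # [4, 100]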
@@ -1575,7 +1575,7 @@ def transformer(
label, weights = make_all_inputs(label_data_input_fields)
if label_smooth_eps:
label = F.label_smooth(
-label=layers.one_hot(input=label, depth=trg_vocab_size),
+label=paddle.nn.functional.one_hot(label, trg_vocab_size),
epsilon=label_smooth_eps,
)
@@ -593,7 +593,9 @@ class CrossEntropyCriterion:
def __call__(self, predict, label, weights):
if self.label_smooth_eps:
label_out = F.label_smooth(
-label=layers.one_hot(input=label, depth=predict.shape[-1]),
+label=paddle.squeeze(
+paddle.nn.functional.one_hot(label, predict.shape[-1])
+),
epsilon=self.label_smooth_eps,
)
@@ -46,7 +46,7 @@ class TestBase(IPUOpTest):
x = paddle.static.data(
name=self.feed_list[0], shape=self.feed_shape[0], dtype='int32'
)
-out = paddle.fluid.layers.one_hot(x, **self.attrs)
+out = paddle.nn.functional.one_hot(x, **self.attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
@@ -418,7 +418,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
label = linear(label)
label = fluid.layers.cast(label, dtype="float32")
label = fluid.layers.cast(label, dtype='int64')
-out = fluid.layers.one_hot(input=label, depth=100)
+out = paddle.nn.functional.one_hot(label, 100)
loss = paddle.mean(out)
loss.backward()
self.assertIsNone(linear.weight._grad_ivar())
@@ -23,6 +23,8 @@ import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.nn import BatchNorm, Linear
+paddle.enable_static()
class TestDygraphLoadStatic(unittest.TestCase):
def testLoadStaticModel(self):
@@ -128,8 +130,8 @@ class TestDygraphLoadStatic(unittest.TestCase):
)
'''
spec_norm = fluid.data(name='spec_norm', shape=[2, 8, 32, 32], dtype='float32')
-spe_norm_out_1 = fluid.layers.spectral_norm(weight=spec_norm, dim=1, power_iters=2)
-spe_norm_out_2 = fluid.layers.spectral_norm(weight=spec_norm, dim=1, power_iters=2)
+spe_norm_out_1 = paddle.static.nn.spectral_norm(weight=spec_norm, dim=1, power_iters=2)
+spe_norm_out_2 = paddle.static.nn.spectral_norm(weight=spec_norm, dim=1, power_iters=2)
'''
nodes_vector = fluid.data(
@@ -1085,8 +1085,8 @@ class TransFormer(Layer):
predict = self._wrap_decoder_layer(dec_inputs, enc_output)
if self._label_smooth_eps:
label_out = F.label_smooth(
-label=fluid.layers.one_hot(
-input=label, depth=self._trg_vocab_size
+label=paddle.squeeze(
+paddle.nn.functional.one_hot(label, self._trg_vocab_size)
),
epsilon=self._label_smooth_eps,
)
......
@@ -657,9 +657,9 @@ class TestLayer(LayerTest):
def test_one_hot(self):
with self.dynamic_graph():
label = fluid.dygraph.to_variable(np.array([[1], [1], [3], [0]]))
-one_hot_label1 = fluid.layers.one_hot(input=label, depth=4)
-one_hot_label2 = fluid.layers.one_hot(
-input=label, depth=fluid.dygraph.to_variable(np.array([4]))
+one_hot_label1 = paddle.nn.functional.one_hot(label, 4)
+one_hot_label2 = paddle.nn.functional.one_hot(
+label, fluid.dygraph.to_variable(np.array([4]))
)
np.testing.assert_array_equal(
one_hot_label1.numpy(), one_hot_label2.numpy()
@@ -921,7 +921,9 @@ class TestLayer(LayerTest):
lod_level=1,
append_batch_size=False,
)
-ret = layers.spectral_norm(weight=Weight, dim=1, power_iters=2)
+ret = paddle.static.nn.spectral_norm(
+weight=Weight, dim=1, power_iters=2
+)
static_ret = self.get_static_graph_result(
feed={
'Weight': fluid.create_lod_tensor(
@@ -1791,7 +1793,7 @@ class TestBook(LayerTest):
def make_one_hot(self):
with fluid.framework._dygraph_place_guard(place=fluid.CPUPlace()):
label = self._get_data(name="label", shape=[1], dtype="int32")
-one_hot_label = layers.one_hot(input=label, depth=10)
+one_hot_label = paddle.nn.functional.one_hot(label, 10)
return one_hot_label
def make_label_smooth(self):
@@ -1799,7 +1801,7 @@ class TestBook(LayerTest):
self._force_to_use_cpu = True
with fluid.framework._dygraph_place_guard(place=fluid.CPUPlace()):
label = self._get_data(name="label", shape=[1], dtype="int32")
-one_hot_label = layers.one_hot(input=label, depth=10)
+one_hot_label = paddle.nn.functional.one_hot(label, 10)
smooth_label = F.label_smooth(label=one_hot_label, epsilon=0.1)
return smooth_label
@@ -1992,7 +1994,7 @@ class TestBook(LayerTest):
dtype="float32",
append_batch_size=False,
)
-out = layers.spectral_norm(weight, dim=1, power_iters=1)
+out = paddle.static.nn.spectral_norm(weight, dim=1, power_iters=1)
return out
def make_kldiv_loss(self):
@@ -2226,7 +2228,7 @@ class TestBook(LayerTest):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
x = layers.data(name='x', shape=[16], dtype='float32', lod_level=1)
-out = layers.row_conv(input=x, future_context_size=2)
+out = paddle.static.nn.row_conv(input=x, future_context_size=2)
return out
def test_simple_conv2d(self):
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import Program, program_guard
class TestOneHotOp(OpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
depth_np = np.array(10).astype('int32')
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), depth)).astype(
'float32'
)
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
class TestOneHotOp_attr(OpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), depth)).astype(
'float32'
)
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
class TestOneHotOp_default_dtype(OpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
depth_np = np.array(10).astype('int32')
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), depth)).astype(
'float32'
)
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
self.attrs = {}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
class TestOneHotOp_default_dtype_attr(OpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), depth)).astype(
'float32'
)
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
class TestOneHotOp_out_of_range(OpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), depth)).astype(
'float32'
)
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'allow_out_of_range': True}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
class TestOneHotOp_exception(unittest.TestCase):
def setUp(self):
self.op_type = 'one_hot'
self.depth = 10
self.place = core.CPUPlace()
self.dimension = 12
self.x = core.LoDTensor()
x_lod = [[4, 1, 3, 3]]
data = [np.random.randint(11, 20) for i in range(sum(x_lod[0]))]
data = np.array(data).astype('int').reshape([sum(x_lod[0]), 1])
self.x.set(data, self.place)
self.x.set_recursive_sequence_lengths(x_lod)
def test_check_output(self):
program = Program()
with program_guard(program):
x = fluid.layers.data(
name='x', shape=[self.dimension], dtype='float32', lod_level=1
)
block = program.current_block()
one_hot_out = block.create_var(
name="one_hot_out",
type=core.VarDesc.VarType.LOD_TENSOR,
dtype='float32',
)
block.append_op(
type='one_hot',
inputs={'X': x},
attrs={'depth': self.depth},
outputs={'Out': one_hot_out},
)
exe = fluid.Executor(self.place)
def run():
exe.run(
feed={'x': self.x},
fetch_list=[one_hot_out],
return_numpy=False,
)
self.assertRaises(ValueError, run)
class TestOneHotOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input must be Variable
in_w = np.random.random((4, 1)).astype("int32")
self.assertRaises(TypeError, fluid.layers.one_hot, in_w)
# the input must be int32 or int64
in_w2 = fluid.layers.data(
name="in_w2",
shape=[4, 1],
append_batch_size=False,
dtype="float32",
)
self.assertRaises(TypeError, fluid.layers.one_hot, in_w2)
# the depth must be int, long or Variable
in_r = fluid.layers.data(
name="in_r",
shape=[4, 1],
append_batch_size=False,
dtype="int32",
)
depth_w = np.array([4])
self.assertRaises(TypeError, fluid.layers.one_hot, in_r, 4.1)
self.assertRaises(TypeError, fluid.layers.one_hot, in_r, depth_w)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
@@ -17,8 +17,11 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle
from paddle import fluid
paddle.enable_static()
def row_conv_forward(x, lod, wt):
out = np.zeros_like(x)
@@ -191,7 +194,7 @@ class TestRowConvLayer(unittest.TestCase):
with fluid.unique_name.guard():
with fluid.program_guard(main, start):
x = fluid.data("x", (-1, -1, self.C), "float32")
-out = fluid.layers.row_conv(
+out = paddle.static.nn.row_conv(
x,
self.context_length,
param_attr=fluid.initializer.NumpyArrayInitializer(self.w),
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
class TestRunTimeException(unittest.TestCase):
def test_run_time_exception(self):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
fluid.layers.one_hot(input=label, depth=100)
def _run_program():
x = np.random.random(size=(10)).astype('int64')
exe.run(train_program, feed={"label": x})
self.assertRaises(ValueError, _run_program)
class TestCompileTimeException(unittest.TestCase):
def test_compile_time_exception(self):
self.assertRaises(ValueError, self.build_model)
def build_model(self):
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
label = fluid.layers.data(
name="label", shape=[1], dtype="int64", append_batch_size=False
)
fluid.layers.one_hot(input=label, depth=100)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
@@ -18,9 +18,10 @@ import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
+paddle.enable_static()
def spectral_norm(weight, u, v, dim, power_iters, eps):
shape = weight.shape
@@ -136,14 +137,14 @@ class TestSpectralNormOpError(unittest.TestCase):
def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32")
-fluid.layers.spectral_norm(weight_1, dim=1, power_iters=2)
+paddle.static.nn.spectral_norm(weight_1, dim=1, power_iters=2)
# the weight type of spectral_norm must be Variable
self.assertRaises(TypeError, test_Variable)
def test_weight_dtype():
weight_2 = np.random.random((2, 4)).astype("int32")
-fluid.layers.spectral_norm(weight_2, dim=1, power_iters=2)
+paddle.static.nn.spectral_norm(weight_2, dim=1, power_iters=2)
# the data type of the weight must be float32 or float64
self.assertRaises(TypeError, test_weight_dtype)
@@ -30,12 +30,13 @@ from .control_flow import (
)
from .common import bilinear_tensor_product # noqa: F401
from .common import py_func # noqa: F401
+from .common import row_conv # noqa: F401
+from .common import spectral_norm # noqa: F401
from ...tensor.creation import create_parameter # noqa: F401
from .loss import nce # noqa: F401
from .common import prelu # noqa: F401
from .common import layer_norm # noqa: F401
-from ...fluid.layers import row_conv # noqa: F401
-from ...fluid.layers import spectral_norm # noqa: F401
from ...fluid.input import embedding # noqa: F401
from ...fluid.contrib.layers import sparse_embedding # noqa: F401
@@ -3233,6 +3233,178 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
return out
@templatedoc()
def row_conv(input, future_context_size, param_attr=None, act=None):
"""
:api_attr: Static Graph
${comment}
Args:
input (${x_type}): ${x_comment}.
future_context_size (int): Future context size. Note that the shape
of the convolution kernel is [future_context_size + 1, D].
param_attr (ParamAttr): Attributes of the parameters, including
name, initializer, etc.
act (str): Non-linear activation to be applied to the output variable.
Returns:
${out_comment}.
Examples:
.. code-block:: python
# for LoDTensor inputs
import paddle
paddle.enable_static()
x = paddle.static.data(name='x', shape=[9, 16],
dtype='float32', lod_level=1)
out_x = paddle.static.nn.row_conv(input=x, future_context_size=2)
# for Tensor inputs
y = paddle.static.data(name='y', shape=[9, 4, 16], dtype='float32')
out_y = paddle.static.nn.row_conv(input=y, future_context_size=2)
"""
helper = LayerHelper('row_conv', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'row_conv')
dtype = helper.input_dtype()
filter_shape = [future_context_size + 1, input.shape[-1]]
filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype
)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='row_conv',
inputs={'X': [input], 'Filter': [filter_param]},
outputs={'Out': [out]},
)
return helper.append_activation(out)
@templatedoc()
def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
r"""
:api_attr: Static Graph
**Spectral Normalization Layer**
This operation calculates the spectral normalization value of the weight parameter
of fc, conv1d, conv2d, or conv3d layers, which should be a 2-D, 3-D, 4-D, or 5-D
parameter. The output tensor has the same shape as the input tensor.
The calculation proceeds as follows.
Step 1:
Generate vectors U of shape [H] and V of shape [W],
where H is the size of the :attr:`dim`-th dimension of the input weight,
and W is the product of the remaining dimensions.
Step 2:
:attr:`power_iters` should be a positive integer; repeat the following
calculation with U and V for :attr:`power_iters` rounds:
.. math::
\mathbf{v} := \frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
\mathbf{u} := \frac{\mathbf{W} \mathbf{v}}{\|\mathbf{W} \mathbf{v}\|_2}
Step 3:
Calculate :math:`\sigma(\mathbf{W})` and normalize the weight values.
.. math::
\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
\mathbf{W} = \frac{\mathbf{W}}{\sigma(\mathbf{W})}
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Args:
weight(Tensor): ${weight_comment}
dim(int): ${dim_comment}
power_iters(int): ${power_iters_comment}
eps(float): ${eps_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually the name does not need to be set and
is None by default.
Returns:
Tensor: A tensor of weight parameters after spectral normalization.
The data type and shape are the same as the input tensor.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
weight = paddle.static.data(name='weight', shape=[2, 8, 32, 32], dtype='float32')
x = paddle.static.nn.spectral_norm(weight=weight, dim=1, power_iters=2)
print(x.shape) # [2, 8, 32, 32]
"""
helper = LayerHelper('spectral_norm', **locals())
check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'], 'spectral_norm'
)
check_type(dim, 'dim', int, 'spectral_norm')
check_type(power_iters, 'power_iters', int, 'spectral_norm')
check_type(eps, 'eps', float, 'spectral_norm')
dtype = weight.dtype
# create input and parameters
input_shape = weight.shape
assert weight.numel() > 0, "All dimensions of the input must be greater than 0."
assert dim < len(input_shape), (
"The input `dim` should be less than the "
"rank of `weight`, but received dim="
"{}".format(dim)
)
h = input_shape[dim]
w = np.prod(input_shape) // h
u = helper.create_parameter(
attr=ParamAttr(),
shape=[h],
dtype=dtype,
default_initializer=Normal(0.0, 1.0),
)
u.stop_gradient = True
v = helper.create_parameter(
attr=ParamAttr(),
shape=[w],
dtype=dtype,
default_initializer=Normal(0.0, 1.0),
)
v.stop_gradient = True
if paddle.framework.in_dygraph_mode():
return paddle._C_ops.spectral_norm(weight, u, v, dim, power_iters, eps)
inputs = {'Weight': weight}
inputs['U'] = u
inputs['V'] = v
# create output
out = helper.create_variable(dtype=dtype)
helper.append_op(
type="spectral_norm",
inputs=inputs,
outputs={
"Out": out,
},
attrs={
"dim": dim,
"power_iters": power_iters,
"eps": eps,
},
)
return out
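The dygraph fast path kept above (the paddle._C_ops.spectral_norm call) backs the public dynamic-graph layer as well, so in imperative code the usual entry point is paddle.nn.SpectralNorm; a brief sketch, assuming the default dtype:

import paddle

weight = paddle.randn([2, 8, 32, 32])
sn = paddle.nn.SpectralNorm(weight.shape, dim=1, power_iters=2)
print(sn(weight).shape)  # [2, 8, 32, 32]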
# For debug usage
py_func.registered_func = PyFuncRegistry.registered_func
py_func.registered_func_num = PyFuncRegistry.registered_func_num