Unverified commit 485de16a, authored by 傅剑寒, committed by GitHub

(fluid cleanup) move prelu from fluid.layers to static.nn (#47894)

Parent 9ae6c854
@@ -98,7 +98,6 @@ __all__ = [
'resize_nearest',
'relu',
'log',
'prelu',
'unique',
'unique_with_counts',
'elementwise_add',
@@ -5333,112 +5332,6 @@ def relu(x, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.static.nn.prelu")
def prelu(x, mode, param_attr=None, data_format="NCHW", name=None):
r"""
prelu activation.
.. math::
prelu(x) = max(0, x) + \alpha * min(0, x)
There are three modes for the activation:
.. code-block:: text
all: All elements share same alpha.
channel: Elements in same channel share same alpha.
element: All elements do not share alpha. Each element has its own alpha.
Parameters:
x (Tensor): The input Tensor or LoDTensor with data type float16, float32 or float64.
mode (str): The mode for weight sharing.
param_attr (ParamAttr|None, optional): The parameter attribute for the learnable
weight (alpha); it can be created by ParamAttr. None by default.
For detailed information, please refer to :ref:`api_fluid_ParamAttr`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, A tensor with the same shape and data type as x.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-1., 2., 3.])
param = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0.2))
out = paddle.static.nn.prelu(x, 'all', param)
# [-0.2, 2., 3.]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu')
helper = LayerHelper('prelu', **locals())
if mode not in ['all', 'channel', 'element']:
raise ValueError('mode should be one of all, channel, element.')
alpha_shape = [1]
if mode == 'channel':
true_data_format = [
'NC',
'NCL',
'NCHW',
'NCDHW',
'NLC',
'NHWC',
'NDHWC',
]
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
"'NLC', 'NHWC', 'NDHWC' but received {}".format(data_format)
)
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
assert (
len(x.shape) >= 2
), "The rank of input x should be equal to or greater than 2 in prelu() when mode is 'channel'"
# NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
# To be consistent with Prelu, it is simplified.
# NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
# NOTE(GuoxiaWang): support NHWC data format
if data_format == 'NHWC':
alpha_shape = [1, 1, 1, x.shape[-1]]
else:
alpha_shape = [1, x.shape[1], 1, 1]
elif mode == 'element':
assert (
len(x.shape) >= 1
), "The rank of input x should be equal to or greater than 1 in prelu() when mode is 'element'"
alpha_shape = [1] + list(x.shape)[1:]
dtype = helper.input_dtype(input_param_name='x')
alpha = helper.create_parameter(
attr=helper.param_attr,
shape=alpha_shape,
dtype=dtype,
is_bias=False,
default_initializer=Constant(0.25),
)
if in_dygraph_mode():
return _C_ops.prelu(x, alpha, data_format, mode)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prelu",
inputs={"X": x, 'Alpha': alpha},
attrs={"mode": mode, "data_format": data_format},
outputs={"Out": out},
)
return out
from paddle.fluid.framework import convert_np_dtype_to_dtype_
@@ -204,17 +204,17 @@ class TensorRTSubgraphPassDynamicMishFp16SerializeTest(
class TensorRTSubgraphPassPreluAllTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
return paddle.static.nn.prelu(x, mode='all')
class TensorRTSubgraphPassPreluChannelTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.prelu(x, mode='channel')
return paddle.static.nn.prelu(x, mode='channel')
class TensorRTSubgraphPassPreluElementTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.prelu(x, mode='element')
return paddle.static.nn.prelu(x, mode='element')
class TensorRTSubgraphPassPreluDynamicTest(TensorRTSubgraphPassActivationTest):
@@ -233,7 +233,7 @@ class TensorRTSubgraphPassPreluDynamicTest(TensorRTSubgraphPassActivationTest):
)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
return paddle.static.nn.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16Test(TensorRTSubgraphPassActivationTest):
@@ -244,7 +244,7 @@ class TensorRTSubgraphPassPreluFp16Test(TensorRTSubgraphPassActivationTest):
)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
return paddle.static.nn.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16SerializeTest(
@@ -257,7 +257,7 @@ class TensorRTSubgraphPassPreluFp16SerializeTest(
)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
return paddle.static.nn.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16DynamicTest(
@@ -278,7 +278,7 @@ class TensorRTSubgraphPassPreluFp16DynamicTest(
)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
return paddle.static.nn.prelu(x, mode='all')
class TensorRTSubgraphPassPreluFp16DynamicSerializeTest(
@@ -299,7 +299,7 @@ class TensorRTSubgraphPassPreluFp16DynamicSerializeTest(
)
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
return paddle.static.nn.prelu(x, mode='all')
class TensorRTSubgraphPassGeluTest(TensorRTSubgraphPassActivationTest):
@@ -82,8 +82,8 @@ class TestDygraphLoadStatic(unittest.TestCase):
prelu_in = fluid.data(
name="prelu_in", shape=[None, 5, 10, 10], dtype='float32'
)
prelu_out_1 = fluid.layers.prelu(prelu_in, "channel")
prelu_out_2 = fluid.layers.prelu(prelu_in, "channel")
prelu_out_1 = paddle.static.nn.prelu(prelu_in, "channel")
prelu_out_2 = paddle.static.nn.prelu(prelu_in, "channel")
bilinear_tensor_pro_x = fluid.data(
"t1", shape=[None, 5], dtype="float32"
@@ -60,7 +60,7 @@ def create_program(data_format="NCHW"):
x.stop_gradient = False
if data_format == "NHWC":
x = paddle.transpose(x, [0, 2, 3, 1])
x = fluid.layers.prelu(x, mode="channel")
x = paddle.static.nn.prelu(x, mode="channel")
conv = ConvBNLayer(
num_channels=3,
num_filters=3,
@@ -1064,7 +1064,7 @@ class TestLayer(LayerTest):
dtype="float32",
append_batch_size=False,
)
out = layers.prelu(
out = paddle.static.nn.prelu(
data_t, mode, param_attr=ParamAttr(initializer=Constant(1.0))
)
static_rlt = self.get_static_graph_result(
@@ -2916,7 +2916,6 @@ class TestBook(LayerTest):
{
"make_gaussian_random",
"make_kldiv_loss",
"make_prelu",
"make_sampling_id",
"make_uniform_random_batch_size_like",
}
@@ -3482,22 +3481,6 @@ class TestBook(LayerTest):
out = tmp_pad(input)
return out
def make_prelu(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
):
input = self._get_data(
name="input", shape=[5, 200, 100, 100], dtype="float32"
)
mode = 'channel'
out = layers.prelu(
input,
mode,
param_attr=ParamAttr(initializer=Constant(1.0)),
name='prelu',
)
return out
def make_mish(self):
with program_guard(
fluid.default_main_program(), fluid.default_startup_program()
@@ -31,7 +31,7 @@ from ...fluid.layers import crf_decoding # noqa: F401
from ...fluid.layers import layer_norm # noqa: F401
from ...fluid.layers import multi_box_head # noqa: F401
from .loss import nce # noqa: F401
from ...fluid.layers import prelu # noqa: F401
from .common import prelu # noqa: F401
from ...fluid.layers import py_func # noqa: F401
from ...fluid.layers import row_conv # noqa: F401
from ...fluid.layers import spectral_norm # noqa: F401
@@ -78,7 +78,6 @@ __all__ = [ # noqa
'layer_norm',
'multi_box_head',
'nce',
'prelu',
'py_func',
'row_conv',
'spectral_norm',
@@ -101,4 +100,5 @@ __all__ = [ # noqa
'sequence_enumerate',
'sequence_reverse',
'StaticRNN',
'prelu',
]
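With these re-exports in place, paddle.static.nn.prelu resolves to the implementation added to common.py below instead of the removed fluid layer. A quick, hedged sanity check (the exact module string is an assumption inferred from the `from .common import prelu` line above):

import paddle.static.nn as static_nn

# Expected to point at the relocated implementation (assumption: 'paddle.static.nn.common').
print(static_nn.prelu.__module__)
print('prelu' in static_nn.__all__)  # True: the symbol stays in the public namespace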
@@ -2083,3 +2083,109 @@ def deform_conv2d(
modulated=True,
name=name,
)
@static_only
def prelu(x, mode, param_attr=None, data_format="NCHW", name=None):
r"""
prelu activation.
.. math::
prelu(x) = max(0, x) + \alpha * min(0, x)
There are three modes for the activation:
.. code-block:: text
all: All elements share same alpha.
channel: Elements in same channel share same alpha.
element: All elements do not share alpha. Each element has its own alpha.
Parameters:
x (Tensor): The input Tensor or LoDTensor with data type float16, float32 or float64.
mode (str): The mode for weight sharing.
param_attr (ParamAttr|None, optional): The parameter attribute for the learnable \
weight (alpha); it can be created by ParamAttr. None by default. \
For detailed information, please refer to :ref:`api_paddle_ParamAttr`.
data_format(str, optional): Data format that specifies the layout of input.
It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
name (str, optional): Name for the operation (optional, default is None). \
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A tensor with the same shape and data type as x.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.static.data(name="x", shape=[None, 5, 10, 10], dtype="float32")
mode = 'channel'
output = paddle.static.nn.prelu(
x, mode, param_attr=paddle.ParamAttr(name='alpha'))
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu')
helper = LayerHelper('prelu', **locals())
if mode not in ['all', 'channel', 'element']:
raise ValueError('mode should be one of all, channel, element.')
alpha_shape = [1]
if mode == 'channel':
true_data_format = [
'NC',
'NCL',
'NCHW',
'NCDHW',
'NLC',
'NHWC',
'NDHWC',
]
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
"'NLC', 'NHWC', 'NDHWC' but received {}".format(data_format)
)
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
assert (
len(x.shape) >= 2
), "The rank of input x should be equal to or greater than 2 in prelu() when mode is 'channel'"
# NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
# To be consistent with Prelu, it is simplified.
# NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version.
# NOTE(GuoxiaWang): support NHWC data format
if data_format == 'NHWC':
alpha_shape = [1, 1, 1, x.shape[-1]]
else:
alpha_shape = [1, x.shape[1], 1, 1]
elif mode == 'element':
assert (
len(x.shape) >= 1
), "The rank of input x should be equal to or greater than 1 in prelu() when mode is 'element'"
alpha_shape = [1] + list(x.shape)[1:]
dtype = helper.input_dtype(input_param_name='x')
alpha = helper.create_parameter(
attr=helper.param_attr,
shape=alpha_shape,
dtype=dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0.25),
)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prelu",
inputs={"X": x, 'Alpha': alpha},
attrs={"mode": mode, "data_format": data_format},
outputs={"Out": out},
)
return out
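For reference, a minimal static-graph sketch that exercises the relocated API end to end; the program/executor plumbing, variable names, and shapes below are illustrative assumptions rather than code from this commit. With an NCHW input of channel size 5 and mode='channel', the learnable alpha should be created with shape [1, 5, 1, 1], and the output keeps the input shape:

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # NCHW input with C = 5; 'channel' mode creates one alpha per channel.
    x = paddle.static.data(name="x", shape=[None, 5, 10, 10], dtype="float32")
    out = paddle.static.nn.prelu(
        x, mode="channel", param_attr=paddle.ParamAttr(name="alpha")
    )

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)

# The parameter lives in the global scope after startup; its shape follows alpha_shape above.
alpha = paddle.static.global_scope().find_var("alpha").get_tensor()
print(np.array(alpha).shape)  # expected: (1, 5, 1, 1)

res = exe.run(
    main_prog,
    feed={"x": np.random.randn(2, 5, 10, 10).astype("float32")},
    fetch_list=[out],
)[0]
print(res.shape)  # (2, 5, 10, 10): same shape and dtype as the input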