From 57e097ac005e8078be43d73fa028b73f208b3adc Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 29 Nov 2022 12:25:27 +0800 Subject: [PATCH] [Fluid API] Move instance_norm, group_norm, data_norm from fluid to static (#48448) * move instance_norm from fluid to static * move group_norm, data_norm to static --- python/paddle/fluid/layers/nn.py | 456 ----------------- .../unittests/ipu/test_groupnorm_op_ipu.py | 4 +- .../unittests/ipu/test_instancenorm_op_ipu.py | 4 +- .../ir/inference/test_trt_instance_norm_op.py | 3 +- .../ir/inference/test_trt_subgraph_pass.py | 2 +- .../unittests/npu/test_group_norm_op_npu.py | 6 +- .../tests/unittests/test_data_norm_op.py | 3 +- .../tests/unittests/test_dist_fleet_ps2.py | 2 +- .../fluid/tests/unittests/test_fleet.py | 5 +- .../tests/unittests/test_group_norm_op.py | 11 +- .../unittests/test_imperative_double_grad.py | 2 +- .../test_imperative_load_static_param.py | 8 +- ...perative_star_gan_with_gradient_penalty.py | 2 +- .../tests/unittests/test_instance_norm_op.py | 4 +- .../fluid/tests/unittests/test_layers.py | 4 +- .../tests/unittests/test_norm_nn_grad.py | 4 +- python/paddle/static/nn/__init__.py | 6 +- python/paddle/static/nn/common.py | 457 +++++++++++++++++- 18 files changed, 496 insertions(+), 487 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e760b357e09..ca0d3cd721d 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -73,8 +73,6 @@ __all__ = [ 'pool2d', 'pool3d', 'batch_norm', - 'instance_norm', - 'data_norm', 'reduce_mean', 'reduce_all', 'reduce_any', @@ -88,7 +86,6 @@ __all__ = [ 'row_conv', 'multiplex', 'layer_norm', - 'group_norm', 'spectral_norm', 'smooth_l1', 'one_hot', @@ -2462,349 +2459,6 @@ def batch_norm( return helper.append_activation(batch_norm_out) -def instance_norm( - input, epsilon=1e-05, param_attr=None, bias_attr=None, name=None -): - r""" - :api_attr: Static Graph - - **Instance Normalization Layer** - - Can be used as a normalizer function for convolution or fully_connected operations. - The required data format for this layer is one of the following: - - DataLayout: NCHW `[batch, in_channels, in_height, in_width]` - - Refer to `Instance Normalization: The Missing Ingredient for - Fast Stylization `_ - for more details. - - :math:`input` is the input features over a mini-batch. - - .. math:: - - \\mu_{\\beta} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW} x_i \\qquad &//\\ - \\ mean\ of\ one\ feature\ map\ in\ mini-batch \\\\ - \\sigma_{\\beta}^{2} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW}(x_i - \\ - \\mu_{\\beta})^2 \\qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\\\ - \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ - \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ - y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift - - Note: - `H` means height of feature map, `W` means width of feature map. - - Args: - input(Tensor): The rank of input tensor can be 2, 3, 4, 5. - The data type is float32 or float64. - epsilon(float, Default 1e-05): A value added to the denominator for - numerical stability. Default is 1e-5. - param_attr(ParamAttr|None|bool, optional): The parameter attribute for Parameter `scale` - of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm - will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. 
- If the Initializer of the param_attr is not set, the parameter is initialized - with Xavier. If the param_attr is set to False, instance_norm will not create param_attr. - Default: None. - bias_attr(ParamAttr|None|bool, optional): The parameter attribute for the bias of instance_norm. - If it is set to None or one attribute of ParamAttr, instance_norm - will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. - If the Initializer of the bias_attr is not set, the bias is initialized zero. - If the bias_attr is set to False, instance_norm will not create bias_attr. - Default: None. - name(string, Default None): A name for this layer(optional). If set None, the layer - will be named automatically. - - Returns: - A Tensor which is the result after applying instance normalization on the input, - has same shape and data type with input. - - Examples: - - .. code-block:: python - - import paddle - paddle.enable_static() - x = paddle.static.data(name='x', shape=[3, 7, 3, 7], dtype='float32') - hidden1 = paddle.static.nn.fc(x, size=200) - hidden2 = paddle.static.nn.instance_norm(hidden1) - """ - check_variable_and_dtype( - input, 'input', ['float32', 'float64'], 'instance_norm' - ) - if param_attr is False: - assert ( - bias_attr is False - ), "param_attr and bias_attr must be set to False at the same time in instance_norm" - - helper = LayerHelper('instance_norm', **locals()) - dtype = helper.input_dtype() - - # use fp32 for in parameter - if dtype == core.VarDesc.VarType.FP16: - dtype = core.VarDesc.VarType.FP32 - - input_shape = input.shape - if len(input.shape) < 2 or len(input.shape) > 5: - raise ValueError( - 'expected 2D or 3D or 4D or 5D input (got {}D input, input shape is: {})'.format( - len(input.shape), input_shape - ) - ) - channel_num = input_shape[1] - - param_shape = [channel_num] - - if param_attr != False and bias_attr != False: - # create parameter - scale = helper.create_parameter( - attr=helper.param_attr, - shape=param_shape, - dtype=dtype, - default_initializer=Constant(1.0), - ) - bias = helper.create_parameter( - attr=helper.bias_attr, - shape=param_shape, - dtype=dtype, - is_bias=True, - default_initializer=Constant(0.0), - ) - - # create output - saved_mean = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True - ) - saved_variance = helper.create_variable_for_type_inference( - dtype=dtype, stop_gradient=True - ) - - instance_norm_out = helper.create_variable_for_type_inference(dtype) - - inputs = {"X": input} - if param_attr != False and bias_attr != False: - inputs["Scale"] = scale - inputs["Bias"] = bias - - helper.append_op( - type="instance_norm", - inputs=inputs, - outputs={ - "Y": instance_norm_out, - "SavedMean": saved_mean, - "SavedVariance": saved_variance, - }, - attrs={ - "epsilon": epsilon, - }, - ) - - return instance_norm_out - - -@static_only -def data_norm( - input, - act=None, - epsilon=1e-05, - param_attr=None, - data_layout='NCHW', - in_place=False, - name=None, - moving_mean_name=None, - moving_variance_name=None, - do_model_average_for_mean_and_var=True, - slot_dim=-1, - sync_stats=False, - summary_decay_rate=0.9999999, - enable_scale_and_shift=False, -): - r""" - :api_attr: Static Graph - - **Data Normalization Layer** - - This op can be used as a normalizer function for conv2d and fully_connected operations. - The required data format for this layer is one of the following: - - 1. NHWC `[batch, in_height, in_width, in_channels]` - - 2. 
NCHW `[batch, in_channels, in_height, in_width]` - - :math:`input` is the input features over a mini-batch. - - .. math:: - - \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ - \ mini-batch\ mean \\\\ - \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ - \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ - \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ - \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ - y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift - - Args: - input(Tensor): The input Tensor. - act(string, Default None): Activation type, linear|relu|prelu|... - epsilon(float, Default 1e-05): - param_attr(ParamAttr): The parameter attribute for Parameter `scale`. - data_layout (str, optional): Specify the data format of the input, and the data format of the output - will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`. - The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: - `[batch_size, input_channels, input_height, input_width]`. - in_place(bool, Default False): Make the input and output of batch norm reuse memory. - name(string, Default None): A name for this layer(optional). If set None, the layer - will be named automatically. - moving_mean_name(string, Default None): The name of moving_mean which store the global Mean. - moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance. - do_model_average_for_mean_and_var(bool, Default True): Whether parameter mean and variance - should do model average when model average is enabled. - slot_dim(int): The embedding dimension of one slot. Slot is a set of one specific feature. In pslib mode, we - distinguish feature ids by slot and pull their embeddings from parameter server (pslib). The first - place of the embedding is the historical show number (occurence time of this feature id with a label 0). - If the input of this op is concated by slot-wise embeddings, and the show number is zero when this slot - is new or empty, the normalization result may be impractical. To avoid this, we add slot_dim to locate - the show number and judge if the show number is zero. If so, we choose to skip normalization on this - embedding. - sync_stats(bool, Default False): When running with multiple GPU cards, using allreduce to sync the - summary messages. - summary_decay_rate(float, Default 0.9999999): The decay rate when updating summary. - enable_scale_and_shift(bool, Default False): do scale&shift after normalization. - - Returns: - Tensor: A tensor which is the result after applying data normalization on the input. - - Examples: - - .. 
code-block:: python - - import paddle - paddle.enable_static() - - x = paddle.randn(shape=[32,100]) - hidden2 = paddle.static.nn.data_norm(input=x) - """ - helper = LayerHelper('data_norm', **locals()) - dtype = helper.input_dtype() - - input_shape = input.shape - if data_layout == 'NCHW': - channel_num = input_shape[1] - else: - if data_layout == 'NHWC': - channel_num = input_shape[-1] - else: - raise ValueError("unsupported data layout:" + data_layout) - - param_shape = [channel_num] - - batch_size_default = 1e4 - batch_sum_default = 0.0 - batch_square_sum_default = 1e4 - scale_w_default = 1.0 - bias_default = 0.0 - - if param_attr and isinstance(param_attr, dict): - batch_size_default = param_attr.get("batch_size", 1e4) - batch_sum_default = param_attr.get("batch_sum", 0.0) - batch_square_sum_default = param_attr.get("batch_square", 1e4) - if enable_scale_and_shift: - scale_w_default = param_attr.get("scale_w", 1.0) - bias_default = param_attr.get("bias", 0.0) - - # create scale and shift(bias) when enable_scale_and_shift is True - if name is None: - name = "dn" - if enable_scale_and_shift: - scale_w = helper.create_parameter( - attr=ParamAttr( - name=name + '.scale_w', - initializer=Constant(value=float(scale_w_default)), - trainable=True, - ), - shape=param_shape, - dtype=input.dtype, - ) - bias = helper.create_parameter( - attr=ParamAttr( - name=name + '.bias', - initializer=Constant(value=float(bias_default)), - trainable=True, - ), - shape=param_shape, - dtype=input.dtype, - ) - # create parameter - batch_size = helper.create_parameter( - attr=ParamAttr( - name=name + '.batch_size', - initializer=Constant(value=float(batch_size_default)), - trainable=True, - ), - shape=param_shape, - dtype=input.dtype, - ) - - batch_sum = helper.create_parameter( - attr=ParamAttr( - name=name + '.batch_sum', - initializer=Constant(value=float(batch_sum_default)), - trainable=True, - ), - shape=param_shape, - dtype=input.dtype, - ) - - batch_square_sum = helper.create_parameter( - attr=ParamAttr( - name=name + '.batch_square_sum', - initializer=Constant(value=float(batch_square_sum_default)), - trainable=True, - ), - shape=param_shape, - dtype=input.dtype, - ) - - means = helper.create_variable(dtype=dtype, stop_gradient=True) - scales = helper.create_variable(dtype=dtype, stop_gradient=True) - - data_norm_out = input if in_place else helper.create_variable(dtype=dtype) - - inputs = { - "X": input, - "BatchSize": batch_size, - "BatchSum": batch_sum, - "BatchSquareSum": batch_square_sum, - } - attrs = { - "epsilon": epsilon, - "data_layout": data_layout, - "sync_stats": sync_stats, - "summary_decay_rate": summary_decay_rate, - } - if slot_dim > 0: - attrs["slot_dim"] = slot_dim - if enable_scale_and_shift: - attrs["enable_scale_and_shift"] = enable_scale_and_shift - if enable_scale_and_shift: - inputs["scale_w"] = scale_w - inputs["bias"] = bias - helper.append_op( - type="data_norm", - inputs=inputs, - outputs={ - "Y": data_norm_out, - "Means": means, - "Scales": scales, - "BatchSize": batch_size, - "BatchSum": batch_sum, - "BatchSquareSum": batch_square_sum, - }, - attrs=attrs, - ) - - return helper.append_activation(data_norm_out) - - @templatedoc() def layer_norm( input, @@ -2941,116 +2595,6 @@ def layer_norm( return helper.append_activation(layer_norm_out) -@templatedoc() -def group_norm( - input, - groups, - epsilon=1e-05, - param_attr=None, - bias_attr=None, - act=None, - data_layout='NCHW', - name=None, -): - """ - :api_attr: Static Graph - - **Group Normalization Layer** - - Refer to `Group 
Normalization `_ . - - Parameters: - input(Tensor): Tensor with dimension greater than 1, the data type is float32 or float64. - groups(int): The number of groups that divided from channels, the data type - is int32. - epsilon(float, optional): The small value added to the variance to prevent - division by zero, the data type is float32. Default: 1e-05. - param_attr(ParamAttr|bool, optional): ParamAttr object that specifies weight parameter - attribute. If a bool type, only False is supported, which means there is no weight parameter. - Default: None, the default weight parameter attribute is used. For more information, please - refer to :ref:`api_guide_ParamAttr` . - bias_attr(ParamAttr|bool, optional): ParamAttr object that specifies bias parameter - attribute. If a bool type, only False is supported, which means there is no bias parameter. - Default: None, the default bias parameter attribute is used. For more information, please - refer to :ref:`api_guide_ParamAttr` . - act(str, optional): Activation to be applied to the output of group normalization. - data_layout(str, optional): Specify the data format of the input, and the data format of the output - will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`. - The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: - `[batch_size, input_channels, *]`. - name (str, optional): The default value is None. Normally there is no need for user to set this - property. For more information, please refer to :ref:`api_guide_Name` . - - Returns: - Tensor: A Tensor has same data type and data format with `input`. - - Examples: - .. code-block:: python - - import paddle - paddle.enable_static() - - data = paddle.static.data(name='data', shape=[2, 8, 32, 32], dtype='float32') - x = paddle.static.nn.group_norm(input=data, groups=4) - print(x.shape) # [2, 8, 32, 32] - """ - helper = LayerHelper('group_norm', **locals()) - dtype = helper.input_dtype() - check_variable_and_dtype( - input, 'input', ['float32', 'float64'], 'group_norm' - ) - # create intput and parameters - inputs = {'X': input} - input_shape = input.shape - if len(input_shape) < 2: - raise ValueError( - f"The dimensions of Op(fluid.layers.group_norm)'s input should be more than 1. But received {len(input_shape)}" - ) - if data_layout != 'NCHW' and data_layout != 'NHWC': - raise ValueError( - "Param(data_layout) of Op(fluid.layers.group_norm) got wrong value: received " - + data_layout - + " but only NCHW or NHWC supported." 
- ) - channel_num = input_shape[1] if data_layout == 'NCHW' else input_shape[-1] - param_shape = [channel_num] - if param_attr: - scale = helper.create_parameter( - attr=helper.param_attr, - shape=param_shape, - dtype=dtype, - default_initializer=Constant(1.0), - ) - inputs['Scale'] = scale - if bias_attr: - bias = helper.create_parameter( - attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True - ) - inputs['Bias'] = bias - - # create output - mean_out = helper.create_variable(dtype=dtype, stop_gradient=True) - variance_out = helper.create_variable(dtype=dtype, stop_gradient=True) - group_norm_out = helper.create_variable(dtype=dtype) - - helper.append_op( - type="group_norm", - inputs=inputs, - outputs={ - "Y": group_norm_out, - "Mean": mean_out, - "Variance": variance_out, - }, - attrs={ - "epsilon": epsilon, - "groups": groups, - "data_layout": data_layout, - }, - ) - - return helper.append_activation(group_norm_out) - - @templatedoc() def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): r""" diff --git a/python/paddle/fluid/tests/unittests/ipu/test_groupnorm_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_groupnorm_op_ipu.py index 609b212e03a..457858ac08c 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_groupnorm_op_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_groupnorm_op_ipu.py @@ -63,7 +63,7 @@ class TestBase(IPUOpTest): ) scale = paddle.ParamAttr(trainable=True) bias = paddle.ParamAttr(trainable=True) - out = paddle.fluid.layers.nn.group_norm( + out = paddle.static.nn.group_norm( conv1, param_attr=scale, bias_attr=bias, **self.attrs ) loss = paddle.mean(out) @@ -71,7 +71,7 @@ class TestBase(IPUOpTest): adam.minimize(loss) self.fetch_list = [loss.name] else: - out = paddle.fluid.layers.nn.group_norm( + out = paddle.static.nn.group_norm( x, param_attr=True, bias_attr=True, **self.attrs ) self.fetch_list = [out.name] diff --git a/python/paddle/fluid/tests/unittests/ipu/test_instancenorm_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_instancenorm_op_ipu.py index b2fb872e369..14210c69e4f 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_instancenorm_op_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_instancenorm_op_ipu.py @@ -60,7 +60,7 @@ class TestBase(IPUOpTest): ) scale = paddle.ParamAttr(trainable=True) bias = paddle.ParamAttr(trainable=True) - out = paddle.fluid.layers.nn.instance_norm( + out = paddle.static.nn.instance_norm( conv1, param_attr=scale, bias_attr=bias, **self.attrs ) loss = paddle.mean(out) @@ -68,7 +68,7 @@ class TestBase(IPUOpTest): adam.minimize(loss) self.fetch_list = [loss.name] else: - out = paddle.fluid.layers.nn.instance_norm( + out = paddle.static.nn.instance_norm( x, param_attr=True, bias_attr=True, **self.attrs ) self.fetch_list = [out.name] diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_instance_norm_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_instance_norm_op.py index 695bf42b3db..2901238ffe4 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_instance_norm_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_instance_norm_op.py @@ -20,6 +20,7 @@ import unittest import numpy as np from inference_pass_test import InferencePassTest +import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.core import AnalysisConfig, PassVersionChecker @@ -43,7 +44,7 @@ class TRTInstanceNormTest(InferencePassTest): with fluid.program_guard(self.main_program, 
self.startup_program): shape = [-1, self.channel, self.height, self.width] data = fluid.data(name='in', shape=shape, dtype='float32') - instance_norm_out = fluid.layers.instance_norm(data) + instance_norm_out = paddle.static.nn.instance_norm(data) out = fluid.layers.batch_norm(instance_norm_out, is_test=True) shape[0] = self.bs diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py index f6136d77f2d..235f2446cb1 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py @@ -186,7 +186,7 @@ class TensorRTSubgraphPassInstanceNormTest(InferencePassTest): name='instance_norm_b', initializer=fluid.initializer.Constant(value=0.0), ) - out = fluid.layers.instance_norm( + out = paddle.static.nn.instance_norm( input=data, param_attr=param_attr, bias_attr=bias_attr ) self.feeds = { diff --git a/python/paddle/fluid/tests/unittests/npu/test_group_norm_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_group_norm_op_npu.py index 06d0f5dd1d0..7f95e2b55c6 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_group_norm_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_group_norm_op_npu.py @@ -52,7 +52,7 @@ class TestGroupNormOpError(unittest.TestCase): def test_x_type(): input = np.random.random(2, 100, 3, 5).astype('float32') groups = 2 - fluid.layers.group_norm(input, groups) + paddle.static.nn.group_norm(input, groups) self.assertRaises(TypeError, test_x_type) @@ -61,7 +61,7 @@ class TestGroupNormOpError(unittest.TestCase): name='x2', shape=[2, 100, 3, 5], dtype='int32' ) groups = 2 - fluid.layers.group_norm(x2, groups) + paddle.static.nn.group_norm(x2, groups) self.assertRaises(TypeError, test_x_dtype) @@ -219,7 +219,7 @@ class TestGroupNormException(unittest.TestCase): data = fluid.data(name='data', shape=[None, 3, 3, 4], dtype="float64") def attr_data_format(): - out = fluid.layers.group_norm( + out = paddle.static.nn.group_norm( input=data, groups=2, data_layout="NDHW" ) diff --git a/python/paddle/fluid/tests/unittests/test_data_norm_op.py b/python/paddle/fluid/tests/unittests/test_data_norm_op.py index 1f32feb3527..b5a2e76fe87 100644 --- a/python/paddle/fluid/tests/unittests/test_data_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_data_norm_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +import paddle import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid @@ -517,7 +518,7 @@ class TestDataNormOpErrorr(unittest.TestCase): with program_guard(Program(), Program()): x2 = fluid.layers.data(name='x2', shape=[3, 4], dtype="int32") # self.assertRaises(TypeError, fluid.data_norm, x2) - fluid.layers.data_norm( + paddle.static.nn.data_norm( input=x2, param_attr={}, enable_scale_and_shift=True ) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py index 3bc478a0085..30f3f813488 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps2.py @@ -89,7 +89,7 @@ class TestPSPassWithBow(unittest.TestCase): # vsum q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') q_ss = paddle.nn.functional.softsign(q_sum) - q_ss = fluid.layers.data_norm(input=q_ss) + q_ss = paddle.static.nn.data_norm(input=q_ss) # fc layer after conv q_fc = fluid.layers.fc( 
input=q_ss, diff --git a/python/paddle/fluid/tests/unittests/test_fleet.py b/python/paddle/fluid/tests/unittests/test_fleet.py index a9a75868ee3..75d6ab31754 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet.py +++ b/python/paddle/fluid/tests/unittests/test_fleet.py @@ -32,6 +32,7 @@ class TestFleet1(unittest.TestCase): def test_pslib_1(self): """Test cases for pslib.""" + import paddle import paddle.fluid as fluid from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker @@ -66,7 +67,9 @@ class TestFleet1(unittest.TestCase): param_attr=fluid.ParamAttr(name="embedding"), ) bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') - bow = fluid.layers.data_norm(input=bow, epsilon=1e-4, name="norm") + bow = paddle.static.nn.data_norm( + input=bow, epsilon=1e-4, name="norm" + ) fc = fluid.layers.fc(input=bow, size=1, act=None) label = fluid.layers.data( name="click", diff --git a/python/paddle/fluid/tests/unittests/test_group_norm_op.py b/python/paddle/fluid/tests/unittests/test_group_norm_op.py index df5c832e2f2..2b74636939e 100644 --- a/python/paddle/fluid/tests/unittests/test_group_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_group_norm_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +import paddle import paddle.fluid.core as core import paddle.fluid as fluid from op_test import OpTest, skip_check_grad_ci @@ -46,7 +47,7 @@ class TestGroupNormOpError(unittest.TestCase): def test_x_type(): input = np.random.random(2, 100, 3, 5).astype('float32') groups = 2 - fluid.layers.group_norm(input, groups) + paddle.static.nn.group_norm(input, groups) self.assertRaises(TypeError, test_x_type) @@ -55,7 +56,7 @@ class TestGroupNormOpError(unittest.TestCase): name='x2', shape=[2, 100, 3, 5], dtype='int32' ) groups = 2 - fluid.layers.group_norm(x2, groups) + paddle.static.nn.group_norm(x2, groups) self.assertRaises(TypeError, test_x_dtype) @@ -245,11 +246,11 @@ class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp): class TestGroupNormAPI_With_NHWC(unittest.TestCase): def test_case1(self): data1 = fluid.data(name='data1', shape=[None, 3, 3, 4], dtype='float64') - out1 = fluid.layers.group_norm( + out1 = paddle.static.nn.group_norm( input=data1, groups=2, data_layout="NHWC" ) data2 = fluid.data(name='data2', shape=[None, 4, 3, 3], dtype='float64') - out2 = fluid.layers.group_norm( + out2 = paddle.static.nn.group_norm( input=data2, groups=2, data_layout="NCHW" ) @@ -282,7 +283,7 @@ class TestGroupNormException(unittest.TestCase): data = fluid.data(name='data', shape=[None, 3, 3, 4], dtype="float64") def attr_data_format(): - out = fluid.layers.group_norm( + out = paddle.static.nn.group_norm( input=data, groups=2, data_layout="NDHW" ) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index c38caf69e08..4b5e008cb74 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -636,7 +636,7 @@ class TestRaiseNoDoubleGradOp(TestCase): with fluid.dygraph.guard(): x = fluid.layers.ones(shape=[2, 3, 2, 2], dtype='float32') x.stop_gradient = False - y = paddle.fluid.layers.group_norm(x, groups=1) + y = paddle.static.nn.group_norm(x, groups=1) dx = fluid.dygraph.grad( outputs=[y], inputs=[x], create_graph=True, retain_graph=True diff --git 
a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py index 3ee24ec9821..0fb5f40470a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_load_static_param.py @@ -123,8 +123,12 @@ class TestDygraphLoadStatic(unittest.TestCase): groupnorm_in = fluid.data( name='groupnorm_in', shape=[None, 8, 32, 32], dtype='float32' ) - groupnorm_out1 = fluid.layers.group_norm(input=groupnorm_in, groups=4) - groupnorm_out2 = fluid.layers.group_norm(input=groupnorm_in, groups=4) + groupnorm_out1 = paddle.static.nn.group_norm( + input=groupnorm_in, groups=4 + ) + groupnorm_out2 = paddle.static.nn.group_norm( + input=groupnorm_in, groups=4 + ) ''' spec_norm = fluid.data(name='spec_norm', shape=[2, 8, 32, 32], dtype='float32') spe_norm_out_1 = fluid.layers.spectral_norm(weight=spec_norm, dim=1, power_iters=2) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py index 92279d501e4..f9034aa45f6 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -120,7 +120,7 @@ class InstanceNorm(fluid.dygraph.Layer): ) return out else: - return fluid.layers.instance_norm( + return paddle.static.nn.instance_norm( input, epsilon=self.epsilon, param_attr=fluid.ParamAttr(self.scale.name), diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py index ed9e01259e6..c5cf210f340 100644 --- a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py @@ -239,11 +239,11 @@ class TestInstanceNormOpError(unittest.TestCase): x1 = fluid.create_lod_tensor( np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace() ) - self.assertRaises(TypeError, fluid.layers.instance_norm, x1) + self.assertRaises(TypeError, paddle.static.nn.instance_norm, x1) # the input dtype of instance_norm must be float32 or float64 x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32") - self.assertRaises(TypeError, fluid.layers.instance_norm, x2) + self.assertRaises(TypeError, paddle.static.nn.instance_norm, x2) class TestInstanceNormOpErrorCase1(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 30e50294448..1d5521f4bdc 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1880,7 +1880,7 @@ class TestLayer(LayerTest): lod_level=1, append_batch_size=False, ) - ret = layers.group_norm( + ret = paddle.static.nn.group_norm( input=X, groups=2, param_attr=fluid.initializer.Uniform(low=-0.5, high=0.5), @@ -1953,7 +1953,7 @@ class TestLayer(LayerTest): X = fluid.layers.data( name='X', shape=shape, dtype='float32', append_batch_size=False ) - ret = layers.instance_norm(input=X) + ret = paddle.static.nn.instance_norm(input=X) static_ret = self.get_static_graph_result( feed={'X': input}, fetch_list=[ret] )[0] diff --git a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py index 26dea91aecc..ed6b94432a4 100644 --- 
a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py @@ -35,7 +35,7 @@ class TestInstanceNormDoubleGradCheck(unittest.TestCase): eps = 0.005 atol = 1e-4 x = layers.create_parameter(dtype=dtype, shape=shape, name='x') - z = fluid.layers.instance_norm(input=x) + z = paddle.static.nn.instance_norm(input=x) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) gradient_checker.double_grad_check( [x], z, x_init=x_arr, atol=atol, place=place, eps=eps @@ -63,7 +63,7 @@ class TestInstanceNormDoubleGradCheckWithoutParamBias( eps = 0.005 atol = 1e-4 x = layers.create_parameter(dtype=dtype, shape=shape, name='x') - z = fluid.layers.instance_norm( + z = paddle.static.nn.instance_norm( input=x, param_attr=False, bias_attr=False ) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index 6f27289efc7..5dfae6c9809 100755 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -13,6 +13,9 @@ # limitations under the License. from .common import fc # noqa: F401 +from .common import instance_norm # noqa: F401 +from .common import data_norm # noqa: F401 +from .common import group_norm # noqa: F401 from .common import deform_conv2d # noqa: F401 from .common import conv3d # noqa: F401 from .common import conv2d_transpose # noqa: F401 @@ -25,9 +28,6 @@ from ...fluid.layers import cond # noqa: F401 from ...fluid.layers import conv2d # noqa: F401 from ...fluid.layers import create_parameter # noqa: F401 from ...fluid.layers import crf_decoding # noqa: F401 -from ...fluid.layers import data_norm # noqa: F401 -from ...fluid.layers import group_norm # noqa: F401 -from ...fluid.layers import instance_norm # noqa: F401 from ...fluid.layers import layer_norm # noqa: F401 from ...fluid.layers import multi_box_head # noqa: F401 from .loss import nce # noqa: F401 diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index a7470f2fb2e..da3b58bb182 100755 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -13,8 +13,10 @@ # limitations under the License. import paddle -from paddle.fluid.initializer import Normal +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Normal, Constant from paddle.fluid.framework import static_only, Variable, _non_static_mode +from paddle.fluid.layers.layer_function_generator import templatedoc from paddle.fluid.data_feeder import check_dtype @@ -177,6 +179,459 @@ def fc( ) +def instance_norm( + input, epsilon=1e-05, param_attr=None, bias_attr=None, name=None +): + r""" + :api_attr: Static Graph + + **Instance Normalization Layer** + + Can be used as a normalizer function for convolution or fully_connected operations. + The required data format for this layer is one of the following: + + DataLayout: NCHW `[batch, in_channels, in_height, in_width]` + + Refer to `Instance Normalization: The Missing Ingredient for + Fast Stylization `_ + for more details. + + :math:`input` is the input features over a mini-batch. + + .. 
math:: + + \\mu_{\\beta} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW} x_i \\qquad &//\\ + \\ mean\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + Note: + `H` means height of feature map, `W` means width of feature map. + + Args: + input(Tensor): The rank of input tensor can be 2, 3, 4, 5. + The data type is float32 or float64. + epsilon(float, Default 1e-05): A value added to the denominator for + numerical stability. Default is 1e-5. + param_attr(ParamAttr|None|bool, optional): The parameter attribute for Parameter `scale` + of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. + If the Initializer of the param_attr is not set, the parameter is initialized + with Xavier. If the param_attr is set to False, instance_norm will not create param_attr. + Default: None. + bias_attr(ParamAttr|None|bool, optional): The parameter attribute for the bias of instance_norm. + If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. + If the Initializer of the bias_attr is not set, the bias is initialized zero. + If the bias_attr is set to False, instance_norm will not create bias_attr. + Default: None. + name(string, Default None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + A Tensor which is the result after applying instance normalization on the input, + has same shape and data type with input. + + Examples: + + .. 
code-block:: python + + import paddle + paddle.enable_static() + x = paddle.static.data(name='x', shape=[3, 7, 3, 7], dtype='float32') + hidden1 = paddle.static.nn.fc(x, size=200) + hidden2 = paddle.static.nn.instance_norm(hidden1) + """ + check_variable_and_dtype( + input, 'input', ['float32', 'float64'], 'instance_norm' + ) + if param_attr is False: + assert ( + bias_attr is False + ), "param_attr and bias_attr must be set to False at the same time in instance_norm" + + helper = LayerHelper('instance_norm', **locals()) + dtype = helper.input_dtype() + + # use fp32 for in parameter + if dtype == paddle.framework.core.VarDesc.VarType.FP16: + dtype = paddle.framework.core.VarDesc.VarType.FP32 + + input_shape = input.shape + if len(input.shape) < 2 or len(input.shape) > 5: + raise ValueError( + 'expected 2D or 3D or 4D or 5D input (got {}D input, input shape is: {})'.format( + len(input.shape), input_shape + ) + ) + channel_num = input_shape[1] + + param_shape = [channel_num] + + if param_attr and bias_attr: + # create parameter + scale = helper.create_parameter( + attr=helper.param_attr, + shape=param_shape, + dtype=dtype, + default_initializer=Constant(1.0), + ) + bias = helper.create_parameter( + attr=helper.bias_attr, + shape=param_shape, + dtype=dtype, + is_bias=True, + default_initializer=Constant(0.0), + ) + + # create output + saved_mean = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True + ) + saved_variance = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True + ) + + instance_norm_out = helper.create_variable_for_type_inference(dtype) + + inputs = {"X": input} + if param_attr and bias_attr: + inputs["Scale"] = scale + inputs["Bias"] = bias + + helper.append_op( + type="instance_norm", + inputs=inputs, + outputs={ + "Y": instance_norm_out, + "SavedMean": saved_mean, + "SavedVariance": saved_variance, + }, + attrs={ + "epsilon": epsilon, + }, + ) + + return instance_norm_out + + +@static_only +def data_norm( + input, + act=None, + epsilon=1e-05, + param_attr=None, + data_layout='NCHW', + in_place=False, + name=None, + moving_mean_name=None, + moving_variance_name=None, + do_model_average_for_mean_and_var=True, + slot_dim=-1, + sync_stats=False, + summary_decay_rate=0.9999999, + enable_scale_and_shift=False, +): + r""" + :api_attr: Static Graph + + **Data Normalization Layer** + + This op can be used as a normalizer function for conv2d and fully_connected operations. + The required data format for this layer is one of the following: + + 1. NHWC `[batch, in_height, in_width, in_channels]` + + 2. NCHW `[batch, in_channels, in_height, in_width]` + + :math:`input` is the input features over a mini-batch. + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ + \ mini-batch\ mean \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + Args: + input(Tensor): The input Tensor. + act(string, Default None): Activation type, linear|relu|prelu|... + epsilon(float, Default 1e-05): + param_attr(ParamAttr): The parameter attribute for Parameter `scale`. + data_layout (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`. 
+ The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + in_place(bool, Default False): Make the input and output of batch norm reuse memory. + name(string, Default None): A name for this layer(optional). If set None, the layer + will be named automatically. + moving_mean_name(string, Default None): The name of moving_mean which store the global Mean. + moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance. + do_model_average_for_mean_and_var(bool, Default True): Whether parameter mean and variance + should do model average when model average is enabled. + slot_dim(int): The embedding dimension of one slot. Slot is a set of one specific feature. In pslib mode, we + distinguish feature ids by slot and pull their embeddings from parameter server (pslib). The first + place of the embedding is the historical show number (occurence time of this feature id with a label 0). + If the input of this op is concated by slot-wise embeddings, and the show number is zero when this slot + is new or empty, the normalization result may be impractical. To avoid this, we add slot_dim to locate + the show number and judge if the show number is zero. If so, we choose to skip normalization on this + embedding. + sync_stats(bool, Default False): When running with multiple GPU cards, using allreduce to sync the + summary messages. + summary_decay_rate(float, Default 0.9999999): The decay rate when updating summary. + enable_scale_and_shift(bool, Default False): do scale&shift after normalization. + + Returns: + Tensor: A tensor which is the result after applying data normalization on the input. + + Examples: + + .. code-block:: python + + import paddle + paddle.enable_static() + + x = paddle.randn(shape=[32,100]) + hidden2 = paddle.static.nn.data_norm(input=x) + """ + helper = LayerHelper('data_norm', **locals()) + dtype = helper.input_dtype() + + input_shape = input.shape + if data_layout == 'NCHW': + channel_num = input_shape[1] + else: + if data_layout == 'NHWC': + channel_num = input_shape[-1] + else: + raise ValueError("unsupported data layout:" + data_layout) + + param_shape = [channel_num] + + batch_size_default = 1e4 + batch_sum_default = 0.0 + batch_square_sum_default = 1e4 + scale_w_default = 1.0 + bias_default = 0.0 + + if param_attr and isinstance(param_attr, dict): + batch_size_default = param_attr.get("batch_size", 1e4) + batch_sum_default = param_attr.get("batch_sum", 0.0) + batch_square_sum_default = param_attr.get("batch_square", 1e4) + if enable_scale_and_shift: + scale_w_default = param_attr.get("scale_w", 1.0) + bias_default = param_attr.get("bias", 0.0) + + # create scale and shift(bias) when enable_scale_and_shift is True + if name is None: + name = "dn" + if enable_scale_and_shift: + scale_w = helper.create_parameter( + attr=ParamAttr( + name=name + '.scale_w', + initializer=Constant(value=float(scale_w_default)), + trainable=True, + ), + shape=param_shape, + dtype=input.dtype, + ) + bias = helper.create_parameter( + attr=ParamAttr( + name=name + '.bias', + initializer=Constant(value=float(bias_default)), + trainable=True, + ), + shape=param_shape, + dtype=input.dtype, + ) + # create parameter + batch_size = helper.create_parameter( + attr=ParamAttr( + name=name + '.batch_size', + initializer=Constant(value=float(batch_size_default)), + trainable=True, + ), + shape=param_shape, + dtype=input.dtype, + ) + + batch_sum = helper.create_parameter( + 
attr=ParamAttr( + name=name + '.batch_sum', + initializer=Constant(value=float(batch_sum_default)), + trainable=True, + ), + shape=param_shape, + dtype=input.dtype, + ) + + batch_square_sum = helper.create_parameter( + attr=ParamAttr( + name=name + '.batch_square_sum', + initializer=Constant(value=float(batch_square_sum_default)), + trainable=True, + ), + shape=param_shape, + dtype=input.dtype, + ) + + means = helper.create_variable(dtype=dtype, stop_gradient=True) + scales = helper.create_variable(dtype=dtype, stop_gradient=True) + + data_norm_out = input if in_place else helper.create_variable(dtype=dtype) + + inputs = { + "X": input, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum, + } + attrs = { + "epsilon": epsilon, + "data_layout": data_layout, + "sync_stats": sync_stats, + "summary_decay_rate": summary_decay_rate, + } + if slot_dim > 0: + attrs["slot_dim"] = slot_dim + if enable_scale_and_shift: + attrs["enable_scale_and_shift"] = enable_scale_and_shift + if enable_scale_and_shift: + inputs["scale_w"] = scale_w + inputs["bias"] = bias + helper.append_op( + type="data_norm", + inputs=inputs, + outputs={ + "Y": data_norm_out, + "Means": means, + "Scales": scales, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum, + }, + attrs=attrs, + ) + + return helper.append_activation(data_norm_out) + + +@templatedoc() +def group_norm( + input, + groups, + epsilon=1e-05, + param_attr=None, + bias_attr=None, + act=None, + data_layout='NCHW', + name=None, +): + """ + :api_attr: Static Graph + + **Group Normalization Layer** + + Refer to `Group Normalization `_ . + + Parameters: + input(Tensor): Tensor with dimension greater than 1, the data type is float32 or float64. + groups(int): The number of groups that divided from channels, the data type + is int32. + epsilon(float, optional): The small value added to the variance to prevent + division by zero, the data type is float32. Default: 1e-05. + param_attr(ParamAttr|bool, optional): ParamAttr object that specifies weight parameter + attribute. If a bool type, only False is supported, which means there is no weight parameter. + Default: None, the default weight parameter attribute is used. For more information, please + refer to :ref:`api_guide_ParamAttr` . + bias_attr(ParamAttr|bool, optional): ParamAttr object that specifies bias parameter + attribute. If a bool type, only False is supported, which means there is no bias parameter. + Default: None, the default bias parameter attribute is used. For more information, please + refer to :ref:`api_guide_ParamAttr` . + act(str, optional): Activation to be applied to the output of group normalization. + data_layout(str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, *]`. + name (str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Tensor: A Tensor has same data type and data format with `input`. + + Examples: + .. 
code-block:: python + + import paddle + paddle.enable_static() + + data = paddle.static.data(name='data', shape=[2, 8, 32, 32], dtype='float32') + x = paddle.static.nn.group_norm(input=data, groups=4) + print(x.shape) # [2, 8, 32, 32] + """ + helper = LayerHelper('group_norm', **locals()) + dtype = helper.input_dtype() + check_variable_and_dtype( + input, 'input', ['float32', 'float64'], 'group_norm' + ) + # create intput and parameters + inputs = {'X': input} + input_shape = input.shape + if len(input_shape) < 2: + raise ValueError( + f"The dimensions of Op(static.nn.group_norm)'s input should be more than 1. But received {len(input_shape)}" + ) + if data_layout != 'NCHW' and data_layout != 'NHWC': + raise ValueError( + "Param(data_layout) of Op(static.nn.group_norm) got wrong value: received " + + data_layout + + " but only NCHW or NHWC supported." + ) + channel_num = input_shape[1] if data_layout == 'NCHW' else input_shape[-1] + param_shape = [channel_num] + if param_attr: + scale = helper.create_parameter( + attr=helper.param_attr, + shape=param_shape, + dtype=dtype, + default_initializer=Constant(1.0), + ) + inputs['Scale'] = scale + if bias_attr: + bias = helper.create_parameter( + attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True + ) + inputs['Bias'] = bias + + # create output + mean_out = helper.create_variable(dtype=dtype, stop_gradient=True) + variance_out = helper.create_variable(dtype=dtype, stop_gradient=True) + group_norm_out = helper.create_variable(dtype=dtype) + + helper.append_op( + type="group_norm", + inputs=inputs, + outputs={ + "Y": group_norm_out, + "Mean": mean_out, + "Variance": variance_out, + }, + attrs={ + "epsilon": epsilon, + "groups": groups, + "data_layout": data_layout, + }, + ) + + return helper.append_activation(group_norm_out) + + def conv3d( input, num_filters, -- GitLab
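
A minimal usage sketch of the three relocated entry points, assembled from the Examples blocks in the moved docstrings. It assumes a Paddle build that already includes this change; the tensor shapes and the `groups`/`size` values are illustrative only.

.. code-block:: python

    import paddle

    paddle.enable_static()

    # group_norm, instance_norm and data_norm are now exposed under
    # paddle.static.nn (previously paddle.fluid.layers.*); the call
    # signatures themselves are unchanged by this patch.

    # group_norm: normalize over 4 groups of channels, NCHW layout.
    gn_in = paddle.static.data(name='gn_in', shape=[2, 8, 32, 32], dtype='float32')
    gn_out = paddle.static.nn.group_norm(input=gn_in, groups=4)  # shape [2, 8, 32, 32]

    # instance_norm: applied to the output of a fully connected layer.
    in_in = paddle.static.data(name='in_in', shape=[3, 7, 3, 7], dtype='float32')
    hidden = paddle.static.nn.fc(in_in, size=200)
    in_out = paddle.static.nn.instance_norm(hidden)

    # data_norm: per-feature normalization of a 2-D input.
    dn_in = paddle.randn(shape=[32, 100])
    dn_out = paddle.static.nn.data_norm(input=dn_in)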