From 52be62c5ae9c060bd8457d425e41c871bb1a0800 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 8 Jul 2020 10:10:43 +0800 Subject: [PATCH] fix instance norm in dy (#24717) * fix bn & in in dy, test=develop * update instance_norm,test=develop * fix bugs,test=develop * add more case in unittest,test=develop * fix,test=develop * fix,test=develop --- paddle/fluid/operators/instance_norm_op.cc | 154 +++++++++++------- paddle/fluid/operators/instance_norm_op.cu | 76 ++++++--- paddle/fluid/pybind/op_function_generator.cc | 1 + python/paddle/fluid/dygraph/nn.py | 44 +++-- python/paddle/fluid/layers/nn.py | 48 +++--- .../tests/unittests/test_instance_norm_op.py | 59 +++++++ .../tests/unittests/test_norm_nn_grad.py | 19 +++ 7 files changed, 278 insertions(+), 123 deletions(-) diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index d2b59a239a2..f72f7e8b85b 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -24,8 +24,6 @@ namespace operators { void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNorm"); - OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InstanceNorm"); - OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "InstanceNorm"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "InstanceNorm"); OP_INOUT_CHECK(ctx->HasOutput("SavedMean"), "Output", "SavedMean", "InstanceNorm"); @@ -51,37 +49,45 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const { auto C = x_dims[1]; auto NxC = N * C; - auto scale_dim = ctx->GetInputDim("Scale"); - auto bias_dim = ctx->GetInputDim("Bias"); - - PADDLE_ENFORCE_EQ( - scale_dim.size(), 1UL, - platform::errors::InvalidArgument( - "ShapeError: the dimension of scale must equal to 1." - "But received: the shape of scale is [%s], the dimension " - "of scale is [%d]", - scale_dim, scale_dim.size())); - PADDLE_ENFORCE_EQ(bias_dim.size(), 1UL, - platform::errors::InvalidArgument( - "ShapeError: the dimension of bias must equal to 1." - "But received: the shape of bias is [%s],the dimension " - "of bias is [%d]", - bias_dim, bias_dim.size())); - - bool check = !((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 || - framework::product(bias_dim) <= 0)); - - if (check) { - PADDLE_ENFORCE_EQ(scale_dim[0], C, - platform::errors::InvalidArgument( - "ShapeError: the shape of scale must equal to [%d]" - "But received: the shape of scale is [%d]", - C, scale_dim[0])); - PADDLE_ENFORCE_EQ(bias_dim[0], C, - platform::errors::InvalidArgument( - "ShapeError: the shape of bias must equal to [%d]" - "But received: the shape of bias is [%d]", - C, bias_dim[0])); + if (ctx->HasInput("Scale")) { + auto scale_dim = ctx->GetInputDim("Scale"); + + PADDLE_ENFORCE_EQ( + scale_dim.size(), 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of scale must equal to 1." 
+ "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale_dim, scale_dim.size())); + + bool check = !((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0)); + + if (check) { + PADDLE_ENFORCE_EQ(scale_dim[0], C, + platform::errors::InvalidArgument( + "ShapeError: the shape of scale must equal to [%d]" + "But received: the shape of scale is [%d]", + C, scale_dim[0])); + } + } + if (ctx->HasInput("Bias")) { + auto bias_dim = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dim.size(), 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of bias must equal to 1." + "But received: the shape of bias is [%s],the dimension " + "of bias is [%d]", + bias_dim, bias_dim.size())); + + bool check = !((!ctx->IsRuntime()) && (framework::product(bias_dim) <= 0)); + if (check) { + PADDLE_ENFORCE_EQ(bias_dim[0], C, + platform::errors::InvalidArgument( + "ShapeError: the shape of bias must equal to [%d]" + "But received: the shape of bias is [%d]", + C, bias_dim[0])); + } } ctx->SetOutputDim("Y", x_dims); @@ -100,12 +106,16 @@ framework::OpKernelType InstanceNormOp::GetExpectedKernelType( if (input_data_type == framework::proto::VarType::FP64) { in_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ( - in_param_type, ctx.Input("Scale")->type(), - platform::errors::InvalidArgument("Scale input should be of float type")); - PADDLE_ENFORCE_EQ( - in_param_type, ctx.Input("Bias")->type(), - platform::errors::InvalidArgument("Bias input should be of float type")); + if (ctx.HasInput("Scale")) { + PADDLE_ENFORCE_EQ(in_param_type, ctx.Input("Scale")->type(), + platform::errors::InvalidArgument( + "Scale input should be of float type")); + } + if (ctx.HasInput("Bias")) { + PADDLE_ENFORCE_EQ(in_param_type, ctx.Input("Bias")->type(), + platform::errors::InvalidArgument( + "Bias input should be of float type")); + } return framework::OpKernelType(input_data_type, ctx.GetPlace()); } @@ -121,10 +131,12 @@ void InstanceNormOpMaker::Make() { AddInput("X", "The input tensor"); AddInput("Scale", "Scale is a 1-dimensional tensor of size C " - "that is applied to the output"); + "that is applied to the output") + .AsDispensable(); AddInput("Bias", "Bias is a 1-dimensional tensor of size C " - "that is applied to the output"); + "that is applied to the output") + .AsDispensable(); AddOutput("Y", "result after normalization"); AddOutput("SavedMean", "Mean of the current mini batch, " @@ -199,9 +211,26 @@ class InstanceNormKernel const auto *scale = ctx.Input("Scale"); const auto *bias = ctx.Input("Bias"); - auto scale_e = framework::EigenVector::Flatten(*scale); + + Tensor scale_data; + Tensor bias_data; + if (!scale) { + scale_data.mutable_data({C}, ctx.GetPlace()); + set_constant(dev_ctx, &scale_data, static_cast(1)); + } + + if (!bias) { + bias_data.mutable_data({C}, ctx.GetPlace()); + set_constant(dev_ctx, &bias_data, static_cast(0)); + } + auto scale_e = scale + ? framework::EigenVector::Flatten(*scale) + : framework::EigenVector::Flatten( + const_cast(scale_data)); auto scale_arr = scale_e.reshape(C_shape); - auto bias_e = framework::EigenVector::Flatten(*bias); + auto bias_e = bias ? 
framework::EigenVector::Flatten(*bias) + : framework::EigenVector::Flatten( + const_cast(bias_data)); auto bias_arr = bias_e.reshape(C_shape); y->mutable_data(ctx.GetPlace()); @@ -219,7 +248,6 @@ class InstanceNormKernel void InstanceNormGradOp::InferShape(framework::InferShapeContext *ctx) const { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormGrad"); - OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InstanceNormGrad"); OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input", framework::GradVarName("Y"), "InstanceNormGrad"); OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean", @@ -230,15 +258,13 @@ void InstanceNormGradOp::InferShape(framework::InferShapeContext *ctx) const { // check output OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", framework::GradVarName("X"), "InstanceNormGrad"); - if (ctx->HasOutput(framework::GradVarName("Scale"))) { - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Bias")), "Output", - framework::GradVarName("Bias"), "InstanceNormGrad"); - } const auto x_dims = ctx->GetInputDim("X"); const int C = x_dims[1]; ctx->SetOutputDim(framework::GradVarName("X"), x_dims); if (ctx->HasOutput(framework::GradVarName("Scale"))) { ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); + } + if (ctx->HasOutput(framework::GradVarName("Bias"))) { ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } } @@ -299,7 +325,18 @@ class InstanceNormGradKernel Eigen::DSizes param_shape(N, C); Eigen::DSizes shape(NxC, sample_size); - auto scale_e = framework::EigenVector::Flatten(*scale); + math::SetConstant set_constant; + + Tensor scale_data; + if (!scale) { + scale_data.mutable_data({C}, ctx.GetPlace()); + set_constant(dev_ctx, &scale_data, static_cast(1)); + } + + auto scale_e = scale + ? 
framework::EigenVector::Flatten(*scale) + : framework::EigenVector::Flatten( + const_cast(scale_data)); auto mean_e = framework::EigenVector::Flatten(*saved_mean); auto inv_var_e = framework::EigenVector::Flatten(*saved_inv_variance); auto dy_e = framework::EigenVector::Flatten(*d_y); @@ -314,7 +351,6 @@ class InstanceNormGradKernel auto tmp = (x_arr - mean_arr.eval().broadcast(bcast)) * inv_var_arr.eval().broadcast(bcast); - math::SetConstant set_constant; // math: d_bias = np.sum(d_y, axis=(n,h,w)) // math: d_scale = np.sum((X-mean) / inv_std * dy, axis=(n, h,w)) if (d_scale && d_bias) { @@ -324,8 +360,8 @@ class InstanceNormGradKernel set_constant(dev_ctx, d_bias, static_cast(0)); auto d_scale_e = framework::EigenVector::Flatten(*d_scale); - auto d_bias_e = framework::EigenVector::Flatten(*d_bias); auto d_scale_data = d_scale_e.reshape(C_shape); + auto d_bias_e = framework::EigenVector::Flatten(*d_bias); auto d_bias_data = d_bias_e.reshape(C_shape); d_bias_data.device(*place) = dy_arr.sum(mean_rdims).reshape(param_shape).sum(rdims); @@ -360,8 +396,6 @@ class InstanceNormGradKernel void InstanceNormDoubleGradOp::InferShape( framework::InferShapeContext *ctx) const { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormDoubleGrad"); - OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", - "InstanceNormDoubleGrad"); OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean", "InstanceNormDoubleGrad"); OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance", @@ -426,6 +460,9 @@ class InstanceNormDoubleGradKernel auto *dScale = ctx.Output("DScale"); auto *ddY = ctx.Output("DDY"); + auto &dev_ctx = ctx.template device_context(); + math::SetConstant set_constant; + const auto &x_dims = X->dims(); int N, C, H, W, D; ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D); @@ -455,7 +492,13 @@ class InstanceNormDoubleGradKernel mean_tile_data = mean_arr.transpose().replicate(sample_size, 1); inv_var_tile_data = inv_var_arr.transpose().replicate(sample_size, 1); - ConstEigenVectorArrayMap scale_arr(Scale->data(), C); + Tensor Scale_data; + if (!Scale) { + Scale_data.mutable_data({C}, ctx.GetPlace()); + set_constant(dev_ctx, &Scale_data, static_cast(1)); + } + ConstEigenVectorArrayMap scale_arr( + Scale ? 
Scale->data() : Scale_data.data(), C); Tensor scale_tile; scale_tile.Resize({sample_size, NxC}); @@ -483,9 +526,6 @@ class InstanceNormDoubleGradKernel // inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean), // axis=(h,w)))) - auto &dev_ctx = ctx.template device_context(); - math::SetConstant set_constant; - Tensor x_sub_mean_mul_invstd; x_sub_mean_mul_invstd.Resize({sample_size, NxC}); x_sub_mean_mul_invstd.mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/instance_norm_op.cu b/paddle/fluid/operators/instance_norm_op.cu index 1567c229cdc..83236712098 100644 --- a/paddle/fluid/operators/instance_norm_op.cu +++ b/paddle/fluid/operators/instance_norm_op.cu @@ -146,10 +146,19 @@ class InstanceNormKernel const int max_blocks = std::max(max_threads / block, 1); const int grid = std::min((NxC + block - 1) / block, max_blocks); - repeat_param<<>>( - scale->data(), scale_tmp.data(), N, C); - repeat_param<<>>( - bias->data(), bias_tmp.data(), N, C); + math::SetConstant set_constant; + if (scale) { + repeat_param<<>>( + scale->data(), scale_tmp.data(), N, C); + } else { + set_constant(dev_ctx, &scale_tmp, static_cast(1)); + } + if (bias) { + repeat_param<<>>( + bias->data(), bias_tmp.data(), N, C); + } else { + set_constant(dev_ctx, &bias_tmp, static_cast(0)); + } auto handle = dev_ctx.cudnn_handle(); @@ -267,24 +276,27 @@ class InstanceNormGradKernel d_scale->mutable_data(ctx.GetPlace()); d_bias->mutable_data(ctx.GetPlace()); } - PADDLE_ENFORCE_EQ( - scale->dims().size(), 1UL, - platform::errors::InvalidArgument( - "The `shape` in InstanceNormOp is invalid: " - "the size of scale's dimensions must be equal to 1. But " - "received: the size of scale's dimensions" - "is [%d]", - scale->dims().size())); - PADDLE_ENFORCE_EQ(scale->dims()[0], C, - platform::errors::InvalidArgument( - "The `shape` in InstanceNormOp is invalid: " - "the first dimension of scale must be equal to " - "Channels([%d]). But received: " - "the first dimension of scale is [%d]," - "the dimensions of scale is [%s], ", - C, scale->dims()[0], scale->dims())); + if (scale) { + PADDLE_ENFORCE_EQ( + scale->dims().size(), 1UL, + platform::errors::InvalidArgument( + "The `shape` in InstanceNormOp is invalid: " + "the size of scale's dimensions must be equal to 1. But " + "received: the size of scale's dimensions" + "is [%d]", + scale->dims().size())); + PADDLE_ENFORCE_EQ(scale->dims()[0], C, + platform::errors::InvalidArgument( + "The `shape` in InstanceNormOp is invalid: " + "the first dimension of scale must be equal to " + "Channels([%d]). 
But received: " + "the first dimension of scale is [%d]," + "the dimensions of scale is [%s], ", + C, scale->dims()[0], scale->dims())); + } auto &dev_ctx = ctx.template device_context(); + math::SetConstant set_constant; const int n = x->numel(); const int block = 512; @@ -300,8 +312,12 @@ class InstanceNormGradKernel ctx.AllocateTmpTensor({NxC}, dev_ctx); Tensor d_bias_tmp = ctx.AllocateTmpTensor({NxC}, dev_ctx); - repeat_param<<>>( - scale->data(), scale_tmp.data(), N, C); + if (scale) { + repeat_param<<>>( + scale->data(), scale_tmp.data(), N, C); + } else { + set_constant(dev_ctx, &scale_tmp, static_cast(1)); + } std::vector dims; std::vector strides; @@ -361,7 +377,7 @@ class InstanceNormGradKernel } else { if (d_x) { GradComputeDX<<>>( - d_y->data(), scale->data>(), + d_y->data(), scale_tmp.data>(), saved_mean_data, x->data(), saved_var_data, C, H * W * D, d_x->data()); } @@ -610,7 +626,6 @@ class InstanceNormDoubleGradKernel auto *ddY = ctx.Output("DDY"); const T *x_data = X->data(); - const T *scale_data = Scale->data(); const T *dy_data = dY->data(); const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data()); @@ -620,6 +635,9 @@ class InstanceNormDoubleGradKernel const T *mean_data = Saved_mean->data(); const T *variance_data = Saved_variance->data(); + auto &dev_ctx = ctx.template device_context(); + math::SetConstant set_zero; + auto &x_dims = X->dims(); int N, C, H, W, D; ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D); @@ -627,15 +645,19 @@ class InstanceNormDoubleGradKernel const int n = X->numel(); int sample_size = n / N / C; - auto &dev_ctx = ctx.template device_context(); + Tensor scale_tmp; + if (!Scale) { + scale_tmp.mutable_data({C}, ctx.GetPlace()); + set_zero(dev_ctx, &scale_tmp, static_cast(1)); + } + const T *scale_data = Scale ? Scale->data() : scale_tmp.data(); + const int block = 512; int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); const int grid = NxC; const int grid1 = (C + block - 1) / block; - math::SetConstant set_zero; - if (dX) { T *dx_data = dX->mutable_data(ctx.GetPlace()); set_zero(dev_ctx, dX, static_cast(0)); diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 89ebb925363..ee9fa26b2fb 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -34,6 +34,7 @@ // need to manually specify them in this map. std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, + {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, {"assign", {"X"}}, diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index c4ea1fabfcb..cc2b746b0c1 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1028,16 +1028,16 @@ class InstanceNorm(layers.Layer): num_channels(int): Indicate the number of channels of the input ``Tensor``. epsilon(float, optional): A value added to the denominator for numerical stability. Default is 1e-5. - param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale` + param_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. If the Initializer of the param_attr is not set, the parameter is initialized - one. 
Default: None.
- bias_attr(ParamAttr, optional): The parameter attribute for the bias of instance_norm.
+ one. If it is set to False, InstanceNorm will not create param_attr. Default: None.
+ bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of instance_norm.
 If it is set to None or one attribute of ParamAttr, instance_norm
 will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
 If the Initializer of the bias_attr is not set, the bias is initialized zero.
- Default: None.
+ If it is set to False, InstanceNorm will not create bias_attr. Default: None.
 dtype(str, optional): Indicate the data type of the input ``Tensor``,
 which can be float32 or float64. Default: float32.
@@ -1071,25 +1071,30 @@ class InstanceNorm(layers.Layer):
 bias_attr=None,
 dtype='float32'):
 super(InstanceNorm, self).__init__()
- assert bias_attr is not False, "bias_attr should not be False in InstanceNorm."
+ if param_attr == False or bias_attr == False:
+ assert bias_attr == param_attr, "param_attr and bias_attr must be set to False at the same time in InstanceNorm"
 self._epsilon = epsilon
 self._param_attr = param_attr
 self._bias_attr = bias_attr
 self._dtype = dtype
- self.scale = self.create_parameter(
- attr=self._param_attr,
- shape=[num_channels],
- dtype=self._dtype,
- default_initializer=Constant(1.0),
- is_bias=False)
- self.bias = self.create_parameter(
- attr=self._bias_attr,
- shape=[num_channels],
- dtype=self._dtype,
- default_initializer=Constant(0.0),
- is_bias=True)
+ if param_attr != False and bias_attr != False:
+ self.scale = self.create_parameter(
+ attr=self._param_attr,
+ shape=[num_channels],
+ dtype=self._dtype,
+ default_initializer=Constant(1.0),
+ is_bias=False)
+ self.bias = self.create_parameter(
+ attr=self._bias_attr,
+ shape=[num_channels],
+ dtype=self._dtype,
+ default_initializer=Constant(0.0),
+ is_bias=True)
+ else:
+ self.scale = None
+ self.bias = None

 def forward(self, input):
 if in_dygraph_mode():
@@ -1102,7 +1107,10 @@ class InstanceNorm(layers.Layer):

 attrs = {"epsilon": self._epsilon}

- inputs = {"X": [input], "Scale": [self.scale], "Bias": [self.bias]}
+ if self.scale and self.bias:
+ inputs = {"X": [input], "Scale": [self.scale], "Bias": [self.bias]}
+ else:
+ inputs = {"X": [input]}

 saved_mean = self._helper.create_variable_for_type_inference(
 dtype=self._dtype, stop_gradient=True)
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index fbe42c0aadd..7769bf643d8 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -3114,15 +3114,17 @@ def instance_norm(input,
 The data type is float32 or float64.
 epsilon(float, Default 1e-05): A value added to the denominator for
 numerical stability. Default is 1e-5.
- param_attr(ParamAttr|None): The parameter attribute for Parameter `scale`
+ param_attr(ParamAttr|None|bool, optional): The parameter attribute for Parameter `scale`
 of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm
 will create ParamAttr as param_attr, the name of scale can be set in ParamAttr.
 If the Initializer of the param_attr is not set, the parameter is initialized
- with Xavier. Default: None.
- bias_attr(ParamAttr|None): The parameter attribute for the bias of instance_norm.
+ with Xavier. If the param_attr is set to False, instance_norm will not create param_attr.
+ Default: None.
+ bias_attr(ParamAttr|None|bool, optional): The parameter attribute for the bias of instance_norm.
 If it is set to None or one attribute of ParamAttr, instance_norm
 will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
 If the Initializer of the bias_attr is not set, the bias is initialized zero.
+ If the bias_attr is set to False, instance_norm will not create bias_attr.
 Default: None.
 name(string, Default None): A name for this layer(optional). If set None, the layer
 will be named automatically.
@@ -3142,7 +3144,9 @@ def instance_norm(input,
 """
 check_variable_and_dtype(input, 'input', ['float32', 'float64'],
 'instance_norm')
- assert bias_attr is not False, "bias_attr should not be False in instance_norm."
+ if param_attr is False:
+ assert bias_attr is False, "param_attr and bias_attr must be set to False at the same time in instance_norm"
+
 helper = LayerHelper('instance_norm', **locals())
 dtype = helper.input_dtype()

@@ -3155,18 +3159,19 @@ def instance_norm(input,

 param_shape = [channel_num]

- # create parameter
- scale = helper.create_parameter(
- attr=helper.param_attr,
- shape=param_shape,
- dtype=dtype,
- default_initializer=Constant(1.0))
- bias = helper.create_parameter(
- attr=helper.bias_attr,
- shape=param_shape,
- dtype=dtype,
- is_bias=True,
- default_initializer=Constant(0.0))
+ if param_attr and bias_attr:
+ # create parameter
+ scale = helper.create_parameter(
+ attr=helper.param_attr,
+ shape=param_shape,
+ dtype=dtype,
+ default_initializer=Constant(1.0))
+ bias = helper.create_parameter(
+ attr=helper.bias_attr,
+ shape=param_shape,
+ dtype=dtype,
+ is_bias=True,
+ default_initializer=Constant(0.0))

 # create output
 saved_mean = helper.create_variable_for_type_inference(
@@ -3176,13 +3181,14 @@ def instance_norm(input,

 instance_norm_out = helper.create_variable_for_type_inference(dtype)

+ inputs = {"X": input}
+ if param_attr and bias_attr:
+ inputs["Scale"] = scale
+ inputs["Bias"] = bias
+
 helper.append_op(
 type="instance_norm",
- inputs={
- "X": input,
- "Scale": scale,
- "Bias": bias,
- },
+ inputs=inputs,
 outputs={
 "Y": instance_norm_out,
 "SavedMean": saved_mean,
diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py
index 39e994873dc..b7fcc63ca59 100644
--- a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py
@@ -20,6 +20,7 @@ import paddle.fluid as fluid
 from paddle.fluid.op import Operator
 from op_test import OpTest
 from paddle.fluid import Program, program_guard
+from paddle.fluid.dygraph import to_variable


 def _reference_instance_norm_naive(x, scale, bias, epsilon, mean, var):
@@ -214,5 +215,63 @@ class TestInstanceNormOpError(unittest.TestCase):
 self.assertRaises(TypeError, fluid.layers.instance_norm, x2)


+class TestElasticNormOp(unittest.TestCase):
+ def init_test_case(self):
+ self.epsilon = 1e-5
+ self.places = [core.CPUPlace()]
+ if core.is_compiled_with_cuda() and core.op_support_gpu(
+ "instance_norm"):
+ self.places.append(core.CUDAPlace(0))
+
+ def test_norm(self):
+ self.init_test_case()
+ inputs = np.random.random((2, 3, 5, 5)).astype(np.float32)
+ shape = inputs.shape
+ n, c, h, w = shape[0], shape[1], shape[2], shape[3]
+ scale_shape = [c]
+ mean_shape = [n * c]
+ scale = np.ones(scale_shape).astype(np.float32)
+ bias = np.zeros(scale_shape).astype(np.float32)
+ mean, variance = _cal_mean_variance(inputs, self.epsilon, mean_shape)
+ out_np, _, _ = _reference_instance_norm_naive(
+ inputs, scale, bias, self.epsilon, mean, variance)
+
+ for place in self.places:
+ 
with fluid.dygraph.guard(place): + instance_norm = fluid.dygraph.InstanceNorm( + 5, param_attr=False, bias_attr=False) + outputs = instance_norm(to_variable(inputs)) + self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6)) + + +class TestElasticNormOpCase2(unittest.TestCase): + def init_test_case(self): + self.epsilon = 1e-5 + self.places = [core.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu( + "instance_norm"): + self.places.append(core.CUDAPlace(0)) + + def test_norm(self): + self.init_test_case() + inputs = np.random.random((2, 3, 5, 5)).astype(np.float32) + shape = inputs.shape + n, c, h, w = shape[0], shape[1], shape[2], shape[3] + scale_shape = [c] + mean_shape = [n * c] + scale = np.ones(scale_shape).astype(np.float32) + bias = np.zeros(scale_shape).astype(np.float32) + mean, variance = _cal_mean_variance(inputs, self.epsilon, mean_shape) + out_np, _, _ = _reference_instance_norm_naive( + inputs, scale, bias, self.epsilon, mean, variance) + + for place in self.places: + with fluid.dygraph.guard(place): + instance_norm = fluid.dygraph.InstanceNorm( + 3, param_attr=True, bias_attr=True) + outputs = instance_norm(to_variable(inputs)) + self.assertTrue(np.allclose(outputs.numpy(), out_np, atol=1e-6)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py index 4f29467a3c5..c44ea454271 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py @@ -49,5 +49,24 @@ class TestInstanceNormDoubleGradCheck(unittest.TestCase): self.func(p) +class TestInstanceNormDoubleGradCheckWithoutParamBias( + TestInstanceNormDoubleGradCheck): + @prog_scope() + def func(self, place): + prog = fluid.Program() + with fluid.program_guard(prog): + np.random.seed() + shape = [2, 3, 4, 5] + dtype = "float32" + eps = 0.005 + atol = 1e-4 + x = layers.create_parameter(dtype=dtype, shape=shape, name='x') + z = fluid.layers.instance_norm( + input=x, param_attr=False, bias_attr=False) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + gradient_checker.double_grad_check( + [x], z, x_init=x_arr, atol=atol, place=place, eps=eps) + + if __name__ == "__main__": unittest.main() -- GitLab
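
Example (not part of the patch): a minimal dygraph usage sketch of the behavior this change enables — constructing InstanceNorm with param_attr=False and bias_attr=False so the instance_norm op runs without Scale/Bias inputs (scale falls back to 1, bias to 0). It mirrors the new TestElasticNormOp unit test above and assumes a Paddle build that already contains this patch; instance_norm_ref is a hypothetical NumPy reference written here only for comparison.

# Minimal sketch, assuming a Paddle (1.8-era fluid API) build that includes this patch.
# instance_norm_ref is a hypothetical helper, not part of the patch.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable


def instance_norm_ref(x, epsilon=1e-5):
    # Normalize each (sample, channel) slice over its spatial axes;
    # with param_attr=False / bias_attr=False the op effectively uses scale=1, bias=0.
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


x = np.random.random((2, 3, 5, 5)).astype(np.float32)
with fluid.dygraph.guard():
    # No scale/bias parameters are created, and the op receives only X.
    layer = fluid.dygraph.InstanceNorm(3, param_attr=False, bias_attr=False)
    y = layer(to_variable(x))

# Expected to print True, matching the assertion in TestElasticNormOp above.
print(np.allclose(y.numpy(), instance_norm_ref(x), atol=1e-6))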