From 5b69242fab4be034f6f6487ab3c56e3a31c2f3a6 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 14 Apr 2020 12:49:51 +0800 Subject: [PATCH] modify datanorm op test=develop (#23030) --- paddle/fluid/framework/unused_var_check.cc | 4 +- paddle/fluid/operators/data_norm_op.cc | 247 +++++++++++++++++- python/paddle/fluid/layers/nn.py | 48 +++- .../tests/unittests/test_data_norm_op.py | 214 +++++++++++++-- 4 files changed, 470 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc index 5eb8011385..7a81bc15b8 100644 --- a/paddle/fluid/framework/unused_var_check.cc +++ b/paddle/fluid/framework/unused_var_check.cc @@ -53,7 +53,9 @@ const std::unordered_set op_has_unsed_vars_white_list = { "precision_recall", // 1 "fusion_seqpool_cvm_concat", // 2 "fused_batch_norm_act", // 2 - "fused_batch_norm_act_grad" // 2 + "fused_batch_norm_act_grad", // 2 + "data_norm", // 0 + "data_norm_grad", // 0 }; namespace paddle { diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 88438c9e97..394feba78e 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -51,6 +51,17 @@ class DataNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Means"), ""); PADDLE_ENFORCE(ctx->HasOutput("Scales"), ""); PADDLE_ENFORCE(ctx->HasOutput("Y"), ""); + bool enable_scale_and_shift = + ctx->Attrs().Get("enable_scale_and_shift"); + if (enable_scale_and_shift) { + PADDLE_ENFORCE_EQ( + ctx->HasInput("scale_w"), true, + platform::errors::InvalidArgument( + "Input(scale_w) of DataNormOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("bias"), true, + platform::errors::InvalidArgument( + "Input(bias) of DataNormOp should not be null.")); + } const auto x_dims = ctx->GetInputDim("X"); const DataLayout data_layout = framework::StringToDataLayout( @@ -72,6 +83,45 @@ class DataNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C); } + if (enable_scale_and_shift) { + auto scale_dim = ctx->GetInputDim("scale_w"); + auto bias_dim = ctx->GetInputDim("bias"); + + PADDLE_ENFORCE_EQ( + scale_dim.size(), 1UL, + platform::errors::InvalidArgument("the dimensionof scale" + "must equal to 1. But received: " + "the shape of scale is [%s], " + "the dimensionof scale is [%d]", + scale_dim, scale_dim.size())); + PADDLE_ENFORCE_EQ( + bias_dim.size(), 1UL, + platform::errors::InvalidArgument("the dimension of bias" + "must equal to 1. 
But received: " + "the shape of bias is [%s]," + "the dimension of bias is [%d]", + bias_dim, bias_dim.size())); + + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 || + framework::product(bias_dim) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(scale_dim[0], C, + platform::errors::InvalidArgument( + "the shape of scale must equal to [%d]" + "But received: the shape of scale is [%d]", + C, scale_dim[0])); + PADDLE_ENFORCE_EQ(bias_dim[0], C, + platform::errors::InvalidArgument( + "the shape of bias must equal to [%d]" + "But received: the shape of bias is [%d]", + C, bias_dim[0])); + } + } + ctx->SetOutputDim("Y", x_dims); ctx->SetOutputDim("Means", {C}); ctx->SetOutputDim("Scales", {C}); @@ -99,6 +149,17 @@ class DataNormOp : public framework::OperatorWithKernel { ctx, "BatchSquareSum"), "BatchSquareSum input should be of float type"); + bool enable_scale_and_shift = ctx.Attr("enable_scale_and_shift"); + if (enable_scale_and_shift) { + PADDLE_ENFORCE_EQ(dn_param_type, + OperatorWithKernel::IndicateVarDataType(ctx, "scale_w"), + platform::errors::InvalidArgument( + "scale_w input should be of float type")); + PADDLE_ENFORCE_EQ(dn_param_type, + OperatorWithKernel::IndicateVarDataType(ctx, "bias"), + platform::errors::InvalidArgument( + "bias input should be of float type")); + } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; @@ -133,6 +194,19 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker { "summary_decay_rate", "(float, default 0.9999999) The decay rate when update the summary") .SetDefault(0.9999999); + AddAttr( + "enable_scale_and_shift", + "(bool, default false) Set to true to enable scale and shift such as " + "batch_norm op") + .SetDefault(false); + AddInput("scale_w", + "scale_w is a 1-dimensional tensor of size C " + "that is applied to the output") + .AsDispensable(); + AddInput("bias", + "bias is a 1-dimensional tensor of size C " + "that is applied to the output") + .AsDispensable(); AddAttr("data_layout", "").SetDefault("NCHW"); AddAttr("sync_stats", "(bool, default false) only used in multi-GPU") .SetDefault(false); @@ -194,7 +268,6 @@ class DataNormKernel // alloc memory T *y_data = y->mutable_data(ctx.GetPlace()); - Eigen::Array inv_std(C); ConstEigenVectorArrayMap b_size_arr( ctx.Input("BatchSize")->data(), C); ConstEigenVectorArrayMap b_sum_arr( @@ -210,6 +283,7 @@ class DataNormKernel const T *means_data = mean_out->data(); const T *x_data = x->data(); + const T *scales_data = scales->data(); const int slot_dim = ctx.Attr("slot_dim"); T min_precision = 1e-7f; @@ -218,7 +292,8 @@ class DataNormKernel case DataLayout::kNHWC: { // if slot_dim is set and batch size is larger than zero, we choose // to check if show number is zero, if so, skip normalization. 
- if (slot_dim > 0 && N > 0) { + if (slot_dim > 0 && N > 0 && + (!ctx.Attr("enable_scale_and_shift"))) { const int item_size = x->numel() / N; // location of show number in one embedding int offset = 0; @@ -239,10 +314,56 @@ class DataNormKernel offset += item_size; } } else { - EigenArrayMap(y_data, C, N) = - (ConstEigenArrayMap(x->data(), C, N).colwise() - means_arr) - .colwise() * - scales_arr; + if (!ctx.Attr("enable_scale_and_shift") && slot_dim <= 0) { + EigenArrayMap(y_data, C, N) = + (ConstEigenArrayMap(x->data(), C, N).colwise() - + means_arr) + .colwise() * + scales_arr; + } else if (ctx.Attr("enable_scale_and_shift") && + slot_dim <= 0) { + const auto *scale_w = ctx.Input("scale_w"); + const auto *bias = ctx.Input("bias"); + ConstEigenVectorArrayMap scale_w_arr(scale_w->data(), C); + ConstEigenVectorArrayMap bias_arr(bias->data(), C); + + Eigen::Array new_scale = + scales_arr * scale_w_arr; + Eigen::Array new_bias = + bias_arr - means_arr * scales_arr * scale_w_arr; + EigenArrayMap(y_data, C, N) = + (ConstEigenArrayMap(x->data(), C, N).colwise() * + new_scale) + .colwise() + + new_bias; + + } else { + const int item_size = x->numel() / N; + const auto *scale_w = ctx.Input("scale_w"); + const auto *bias = ctx.Input("bias"); + const T *scale_w_data = scale_w->data(); + const T *bias_data = bias->data(); + // location of show number in one embedding + int offset = 0; + for (int k = 0; k < N; ++k) { + for (int i = 0; i < item_size; i += slot_dim) { + if (x_data[offset + i] > -min_precision && + x_data[offset + i] < min_precision) { + // show = 0 + memset(y_data + offset + i, 0, sizeof(T) * slot_dim); + } else { + for (int j = i; j < i + slot_dim; ++j) { + y_data[offset + j] = ((x_data[offset + j] - means_data[j]) * + scales_data[j]) * + scale_w_data[j] + + bias_data[j]; + } + } + } // end for i + + offset += item_size; + } // end for k + } } break; } @@ -274,7 +395,8 @@ class DataNormGradOp : public framework::OperatorWithKernel { "Output(BatchSquareSum) of DataNormGradOp should not be null.")); PADDLE_ENFORCE(ctx->HasInput("Means"), ""); PADDLE_ENFORCE(ctx->HasInput("Scales"), ""); - + bool enable_scale_and_shift = + ctx->Attrs().Get("enable_scale_and_shift"); // check output PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), ""); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), ""); @@ -294,6 +416,22 @@ class DataNormGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("BatchSize"), {C}); ctx->SetOutputDim(framework::GradVarName("BatchSum"), {C}); ctx->SetOutputDim(framework::GradVarName("BatchSquareSum"), {C}); + if (enable_scale_and_shift) { + const bool has_scale_grad = + ctx->HasOutput(framework::GradVarName("scale_w")); + const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("bias")); + + PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true, + platform::errors::InvalidArgument( + "Output(Scale@GRAD) and Output(Bias@GRAD)" + "must be null or not be null at same time. " + "But now, has Scale@Grad=[%d], has Bias@GRAD=[%d]", + has_scale_grad, has_bias_grad)); + if (has_scale_grad) { + ctx->SetOutputDim(framework::GradVarName("scale_w"), {C}); + ctx->SetOutputDim(framework::GradVarName("bias"), {C}); + } + } } protected: @@ -353,18 +491,23 @@ class DataNormGradKernel const int C = (data_layout == DataLayout::kNCHW ? 
x_dims[1] : x_dims[x_dims.size() - 1]); - // init output Tensor *d_x = nullptr; if (ctx.HasOutput(framework::GradVarName("X"))) { d_x = ctx.Output(framework::GradVarName("X")); } + auto *d_batch_size = ctx.Output(framework::GradVarName("BatchSize")); auto *d_batch_sum = ctx.Output(framework::GradVarName("BatchSum")); auto *d_batch_square_sum = ctx.Output(framework::GradVarName("BatchSquareSum")); + const T *mean_data = means->data(); + const T *inv_var_data = scales->data(); + ConstEigenVectorArrayMap mean_arr(mean_data, C); + ConstEigenVectorArrayMap inv_var_arr(inv_var_data, C); + T *d_batch_size_data = d_batch_size->mutable_data(ctx.GetPlace()); T *d_batch_sum_data = d_batch_sum->mutable_data(ctx.GetPlace()); T *d_batch_square_sum_data = @@ -372,7 +515,6 @@ class DataNormGradKernel EigenVectorArrayMap d_batch_size_arr(d_batch_size_data, C); EigenVectorArrayMap d_batch_sum_arr(d_batch_sum_data, C); EigenVectorArrayMap d_batch_square_sum_arr(d_batch_square_sum_data, C); - d_batch_size_arr.setZero(); d_batch_sum_arr.setZero(); d_batch_square_sum_arr.setZero(); @@ -392,8 +534,86 @@ class DataNormGradKernel if (d_x != nullptr) { EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), C, N); d_x_arr.setZero(); - for (int nc = 0; nc < N; ++nc) { - d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr; + if (!ctx.Attr("enable_scale_and_shift")) { + for (int nc = 0; nc < N; ++nc) { + d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr; + } + } else { + const auto *scale_w = ctx.Input("scale_w"); + auto *d_scale = + ctx.Output(framework::GradVarName("scale_w")); + auto *d_bias = ctx.Output(framework::GradVarName("bias")); + ConstEigenVectorArrayMap scale_arr(scale_w->data(), C); + T *d_bias_data = nullptr; + T *d_scale_data = nullptr; + + d_scale->mutable_data(ctx.GetPlace()); + d_bias->mutable_data(ctx.GetPlace()); + d_bias_data = d_bias->mutable_data(ctx.GetPlace()); + d_scale_data = d_scale->mutable_data(ctx.GetPlace()); + + EigenVectorArrayMap d_bias_arr(d_bias_data, C); + EigenVectorArrayMap d_scale_arr(d_scale_data, C); + Tensor dy_sum; + dy_sum.Resize({C}); + dy_sum.mutable_data(ctx.GetPlace()); + EigenVectorArrayMap dy_sum_arr( + dy_sum.mutable_data(ctx.GetPlace()), C); + Tensor dy_mul_x_sub_mean_mul_invstd_sum; + dy_mul_x_sub_mean_mul_invstd_sum.Resize({C}); + dy_mul_x_sub_mean_mul_invstd_sum.mutable_data(ctx.GetPlace()); + EigenVectorArrayMap dy_mul_x_sub_mean_mul_invstd_sum_arr( + dy_mul_x_sub_mean_mul_invstd_sum.mutable_data( + ctx.GetPlace()), + C); + + dy_sum_arr.setZero(); + dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero(); + + if (slot_dim <= 0) { + for (int n = 0; n < N; ++n) { + dy_sum_arr += d_y_arr.col(n); + dy_mul_x_sub_mean_mul_invstd_sum_arr += + ((x_arr.col(n) - mean_arr) * inv_var_arr * d_y_arr.col(n)); + } + if (d_scale && d_bias) { + d_bias_arr = dy_sum_arr; + d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr; + } + for (int nc = 0; nc < N; ++nc) { + d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr * scale_arr; + } + } else { + int offset = 0; + const int item_size = x->numel() / N; + T *d_x_data = d_x->mutable_data(ctx.GetPlace()); + T *d_scale_data = d_scale->mutable_data(ctx.GetPlace()); + T *d_bias_data = d_bias->mutable_data(ctx.GetPlace()); + const T *dy_data = d_y->data(); + const T *scales_data = scales->data(); + const T *scale_w_data = scale_w->data(); + const T *x_data = x->data(); + for (int i = 0; i < item_size; i++) { + d_bias_data[i] = 0; + d_scale_data[i] = 0; + } + for (int k = 0; k < N; ++k) { + for (int i = 0; i < item_size; i += slot_dim) { + if (!(x_data[offset + 
i] > -min_precision && + x_data[offset + i] < min_precision)) { + // show != 0 + for (int j = i; j < i + slot_dim; ++j) { + d_x_data[offset + j] = dy_data[offset + j] * + scales_data[j] * scale_w_data[j]; + d_bias_data[j] += dy_data[offset + j]; + d_scale_data[j] += (x_data[offset + j] - mean_data[j]) * + inv_var_data[j] * dy_data[offset + j]; + } + } + } + offset += item_size; + } + } } } @@ -466,6 +686,8 @@ class DataNormGradMaker : public framework::SingleGradOpMaker { op->SetInput("X", this->Input("X")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + op->SetInput("scale_w", this->Input("scale_w")); + op->SetInput("bias", this->Input("bias")); op->SetOutput("BatchSize", this->Input("BatchSize")); op->SetOutput("BatchSum", this->Input("BatchSum")); op->SetOutput("BatchSquareSum", this->Input("BatchSquareSum")); @@ -481,6 +703,9 @@ class DataNormGradMaker : public framework::SingleGradOpMaker { this->InputGrad("BatchSum")); op->SetOutput(framework::GradVarName("BatchSquareSum"), this->InputGrad("BatchSquareSum")); + op->SetOutput(framework::GradVarName("scale_w"), + this->InputGrad("scale_w")); + op->SetOutput(framework::GradVarName("bias"), this->InputGrad("bias")); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e4094bb517..c9c19e903f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3157,7 +3157,8 @@ def data_norm(input, do_model_average_for_mean_and_var=True, slot_dim=-1, sync_stats=False, - summary_decay_rate=0.9999999): + summary_decay_rate=0.9999999, + enable_scale_and_shift=False): """ **Data Normalization Layer** @@ -3206,6 +3207,7 @@ def data_norm(input, sync_stats(bool, Default False): When running with multiple GPU cards, using allreduce to sync the summary messages. summary_decay_rate(float, Default 0.9999999): The decay rate when updating summary. + enable_scale_and_shift(bool, Default False): do scale&shift after normalization. Returns: Variable: A tensor variable which is the result after applying data normalization on the input. 
@@ -3236,12 +3238,35 @@ def data_norm(input, batch_size_default = 1e4 batch_sum_default = 0.0 batch_square_sum_default = 1e4 + scale_w_default = 1.0 + bias_default = 0.0 if param_attr and isinstance(param_attr, dict): batch_size_default = param_attr.get("batch_size", 1e4) batch_sum_default = param_attr.get("batch_sum", 0.0) batch_square_sum_default = param_attr.get("batch_square", 1e4) - + if enable_scale_and_shift: + scale_w_default = param_attr.get("scale_w", 1.0) + bias_default = param_attr.get("bias", 0.0) + + # create scale and shift(bias) when enable_scale_and_shift is True + if name == None: + name = "dn" + if enable_scale_and_shift: + scale_w = helper.create_parameter( + attr=ParamAttr( + name=name + '.scale_w', + initializer=Constant(value=float(scale_w_default)), + trainable=True), + shape=param_shape, + dtype=input.dtype) + bias = helper.create_parameter( + attr=ParamAttr( + name=name + '.bias', + initializer=Constant(value=float(bias_default)), + trainable=True), + shape=param_shape, + dtype=input.dtype) # create parameter batch_size = helper.create_parameter( attr=ParamAttr( @@ -3272,14 +3297,18 @@ def data_norm(input, data_norm_out = input if in_place else helper.create_variable(dtype=dtype) + inputs = { + "X": input, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum + } + if enable_scale_and_shift: + inputs["scale_w"] = scale_w + inputs["bias"] = bias helper.append_op( type="data_norm", - inputs={ - "X": input, - "BatchSize": batch_size, - "BatchSum": batch_sum, - "BatchSquareSum": batch_square_sum - }, + inputs=inputs, outputs={ "Y": data_norm_out, "Means": means, @@ -3292,7 +3321,8 @@ def data_norm(input, "epsilon": epsilon, "slot_dim": slot_dim, "sync_stats": sync_stats, - "summary_decay_rate": summary_decay_rate + "summary_decay_rate": summary_decay_rate, + "enable_scale_and_shift": enable_scale_and_shift }) return helper.append_activation(data_norm_out) diff --git a/python/paddle/fluid/tests/unittests/test_data_norm_op.py b/python/paddle/fluid/tests/unittests/test_data_norm_op.py index e2bbf8a077..0b7ed20f4b 100644 --- a/python/paddle/fluid/tests/unittests/test_data_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_data_norm_op.py @@ -24,6 +24,7 @@ import paddle.fluid.layers as layers import os from op_test import OpTest from paddle.fluid.framework import grad_var_name +from paddle.fluid import Program, program_guard def _reference_testing(x, batch_size, batch_sum, batch_square_sum, slot_dim=-1): @@ -72,7 +73,13 @@ class TestDataNormOpInference(unittest.TestCase): def __assert_close(self, tensor, np_array, msg, atol=1e-4): self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - def check_with_place(self, place, data_layout, dtype, shape, slot_dim=-1): + def check_with_place(self, + place, + data_layout, + dtype, + shape, + slot_dim=-1, + enable_scale_and_shift=False): """ do forward and check @@ -82,7 +89,7 @@ class TestDataNormOpInference(unittest.TestCase): dtype(dtype): np.float32 shape(list): input shape slot_dim(int): dimension of one slot. Refer to data_norm api. - + enable_scale_and_shift(bool): if enable scale and shift after normalization. 
""" epsilon = 0.00001 @@ -127,21 +134,49 @@ class TestDataNormOpInference(unittest.TestCase): mean_tensor = create_or_get_tensor(scope, "mean", None, place) scales_tensor = create_or_get_tensor(scope, "scales", None, place) - data_norm_op = Operator( - "data_norm", - # inputs - X="x_val", - BatchSize="batch_size", - BatchSum="batch_sum", - BatchSquareSum="batch_square_sum", - # outputs - Y="y_out", - Means="mean", - Scales="scales", - # attrs - epsilon=epsilon, - use_mkldnn=self.use_mkldnn, - slot_dim=slot_dim) + if not enable_scale_and_shift: + data_norm_op = Operator( + "data_norm", + # inputs + X="x_val", + BatchSize="batch_size", + BatchSum="batch_sum", + BatchSquareSum="batch_square_sum", + # outputs + Y="y_out", + Means="mean", + Scales="scales", + # attrs + epsilon=epsilon, + use_mkldnn=self.use_mkldnn, + slot_dim=slot_dim, + enable_scale_and_shift=False) + else: + scale_w = np.ones(scale_shape).astype(np.float32) + bias = np.zeros(scale_shape).astype(np.float32) + scale_w_tensor = create_or_get_tensor( + scope, "scale_w", + OpTest.np_dtype_to_fluid_dtype(scale_w), place) + bias_tensor = create_or_get_tensor( + scope, "bias", OpTest.np_dtype_to_fluid_dtype(bias), place) + data_norm_op = Operator( + "data_norm", + # inputs + X="x_val", + BatchSize="batch_size", + BatchSum="batch_sum", + BatchSquareSum="batch_square_sum", + scale_w="scale_w", + bias="bias", + # outputs + Y="y_out", + Means="mean", + Scales="scales", + # attrs + epsilon=epsilon, + use_mkldnn=self.use_mkldnn, + slot_dim=slot_dim, + enable_scale_and_shift=True) data_norm_op.run(scope, place) @@ -162,11 +197,13 @@ class TestDataNormOpInference(unittest.TestCase): for place in places: for data_format in ["NCHW", "NHWC"]: for slot_dim in [-1, 1]: - self.check_with_place( - place, - data_format, - self.dtype, [2, 3], - slot_dim=slot_dim) + for enable_scale_and_shift in [False, True]: + self.check_with_place( + place, + data_format, + self.dtype, [2, 3], + slot_dim=slot_dim, + enable_scale_and_shift=enable_scale_and_shift) class TestDataNormOp(OpTest): @@ -220,6 +257,130 @@ class TestDataNormOp(OpTest): self.check_grad(['X'], 'Y', no_grad_set=set([])) +class TestDataNormOpWithEnableScaleAndShift(OpTest): + """ + test class for data norm op + test forward and backward + """ + + def setUp(self): + """ + init data norm op test env + """ + self.op_type = 'data_norm' + self.use_mkldnn = False + epsilon = 0.00001 + slot_dim = -1 + enable_scale_and_shitf = True + x_shape = [2, 50] + scale_shape = [50] + tp = np.float32 + + x_val = np.random.uniform(-1, 1, x_shape).astype(tp) + batch_size = np.ones(scale_shape).astype(tp) + batch_size *= 1e4 + batch_sum = np.zeros(scale_shape).astype(tp) + batch_square_sum = np.ones(scale_shape).astype(tp) + batch_square_sum *= 1e4 + scale_w = np.ones(scale_shape).astype(tp) + bias = np.zeros(scale_shape).astype(tp) + + y = np.array(x_val) + + mean = np.zeros(x_shape).astype(tp) + scale = np.ones(x_shape).astype(tp) + + self.inputs = { + "X": x_val, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum, + "scale_w": scale_w, + "bias": bias + } + self.outputs = {"Y": y, "Means": mean, "Scales": scale} + self.attrs = { + "epsilon": epsilon, + "use_mkldnn": self.use_mkldnn, + "slot_dim": slot_dim, + "enable_scale_and_shift": True + } + + def test_check_output(self): + """ + test check forward, check output + """ + self.check_output() + + def test_check_grad(self): + """ + test check backward, check grad + """ + self.check_grad(['X'], 'Y', no_grad_set=set([])) + + 
+class TestDataNormOpWithEnableScaleAndShift_1(OpTest): + """ + test class for data norm op + test forward and backward + """ + + def setUp(self): + """ + init data norm op test env + """ + self.op_type = 'data_norm' + self.use_mkldnn = False + epsilon = 0.00001 + slot_dim = 1 + enable_scale_and_shitf = True + x_shape = [2, 50] + scale_shape = [50] + tp = np.float32 + + x_val = np.random.uniform(-1, 1, x_shape).astype(tp) + batch_size = np.ones(scale_shape).astype(tp) + batch_size *= 1e4 + batch_sum = np.zeros(scale_shape).astype(tp) + batch_square_sum = np.ones(scale_shape).astype(tp) + batch_square_sum *= 1e4 + scale_w = np.ones(scale_shape).astype(tp) + bias = np.zeros(scale_shape).astype(tp) + + y = np.array(x_val) + + mean = np.zeros(x_shape).astype(tp) + scale = np.ones(x_shape).astype(tp) + + self.inputs = { + "X": x_val, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum, + "scale_w": scale_w, + "bias": bias + } + self.outputs = {"Y": y, "Means": mean, "Scales": scale} + self.attrs = { + "epsilon": epsilon, + "use_mkldnn": self.use_mkldnn, + "slot_dim": slot_dim, + "enable_scale_and_shift": True + } + + def test_check_output(self): + """ + test check forward, check output + """ + self.check_output() + + def test_check_grad(self): + """ + test check backward, check grad + """ + self.check_grad(['X'], 'Y', no_grad_set=set([])) + + class TestDataNormOpWithSlotDim(OpTest): """ test class for data norm op @@ -399,5 +560,14 @@ class TestDataNormOpWithSyncStats(unittest.TestCase): os.remove(f) +class TestDataNormOpErrorr(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + x2 = fluid.layers.data(name='x2', shape=[3, 4], dtype="int32") + #self.assertRaises(TypeError, fluid.data_norm, x2) + fluid.layers.data_norm( + input=x2, param_attr={}, enable_scale_and_shift=True) + + if __name__ == '__main__': unittest.main() -- GitLab
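
Usage note (editorial, not part of the patch): the sketch below shows how the new `enable_scale_and_shift` argument added to `fluid.layers.data_norm` in this change could be invoked from Python. The input shape, parameter initial values, and variable names are illustrative assumptions based on the signature and defaults introduced by the patch, not code taken from the commit.

    import paddle.fluid as fluid

    # Build a small graph that normalizes a [batch, 50] float input and then
    # applies the learnable scale ("scale_w") and shift ("bias") parameters
    # that data_norm creates when enable_scale_and_shift=True.
    x = fluid.layers.data(name='x', shape=[50], dtype='float32')

    y = fluid.layers.data_norm(
        input=x,
        param_attr={
            "batch_size": 1e4,    # initial BatchSize value (patch default)
            "batch_sum": 0.0,     # initial BatchSum value (patch default)
            "batch_square": 1e4,  # initial BatchSquareSum value (patch default)
            "scale_w": 1.0,       # initial value of the learnable scale
            "bias": 0.0,          # initial value of the learnable shift
        },
        slot_dim=-1,
        enable_scale_and_shift=True)

With `enable_scale_and_shift=True`, two extra 1-D parameters of size C are created and the output becomes the normalized value multiplied by `scale_w` plus `bias`, mirroring the scale/shift step of batch_norm as described in the operator's attribute comment.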