Unverified commit 5b69242f authored by Y yaoxuefeng, committed by GitHub

modify datanorm op test=develop (#23030)

Parent 3e1676fa
......@@ -53,7 +53,9 @@ const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
"precision_recall", // 1
"fusion_seqpool_cvm_concat", // 2
"fused_batch_norm_act", // 2
"fused_batch_norm_act_grad" // 2
"fused_batch_norm_act_grad", // 2
"data_norm", // 0
"data_norm_grad", // 0
};
namespace paddle {
......
......@@ -51,6 +51,17 @@ class DataNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Means"), "");
PADDLE_ENFORCE(ctx->HasOutput("Scales"), "");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
bool enable_scale_and_shift =
ctx->Attrs().Get<bool>("enable_scale_and_shift");
if (enable_scale_and_shift) {
PADDLE_ENFORCE_EQ(
ctx->HasInput("scale_w"), true,
platform::errors::InvalidArgument(
"Input(scale_w) of DataNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("bias"), true,
platform::errors::InvalidArgument(
"Input(bias) of DataNormOp should not be null."));
}
const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout(
......@@ -72,6 +83,45 @@ class DataNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
}
if (enable_scale_and_shift) {
auto scale_dim = ctx->GetInputDim("scale_w");
auto bias_dim = ctx->GetInputDim("bias");
PADDLE_ENFORCE_EQ(
scale_dim.size(), 1UL,
platform::errors::InvalidArgument("the dimensionof scale"
"must equal to 1. But received: "
"the shape of scale is [%s], "
"the dimensionof scale is [%d]",
scale_dim, scale_dim.size()));
PADDLE_ENFORCE_EQ(
bias_dim.size(), 1UL,
platform::errors::InvalidArgument("the dimension of bias"
"must equal to 1. But received: "
"the shape of bias is [%s],"
"the dimension of bias is [%d]",
bias_dim, bias_dim.size()));
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
framework::product(bias_dim) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(scale_dim[0], C,
platform::errors::InvalidArgument(
"the shape of scale must equal to [%d]"
"But received: the shape of scale is [%d]",
C, scale_dim[0]));
PADDLE_ENFORCE_EQ(bias_dim[0], C,
platform::errors::InvalidArgument(
"the shape of bias must equal to [%d]"
"But received: the shape of bias is [%d]",
C, bias_dim[0]));
}
}
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("Means", {C});
ctx->SetOutputDim("Scales", {C});
......@@ -99,6 +149,17 @@ class DataNormOp : public framework::OperatorWithKernel {
ctx, "BatchSquareSum"),
"BatchSquareSum input should be of float type");
bool enable_scale_and_shift = ctx.Attr<bool>("enable_scale_and_shift");
if (enable_scale_and_shift) {
PADDLE_ENFORCE_EQ(dn_param_type,
OperatorWithKernel::IndicateVarDataType(ctx, "scale_w"),
platform::errors::InvalidArgument(
"scale_w input should be of float type"));
PADDLE_ENFORCE_EQ(dn_param_type,
OperatorWithKernel::IndicateVarDataType(ctx, "bias"),
platform::errors::InvalidArgument(
"bias input should be of float type"));
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
......@@ -133,6 +194,19 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
"summary_decay_rate",
"(float, default 0.9999999) The decay rate when update the summary")
.SetDefault(0.9999999);
AddAttr<bool>(
"enable_scale_and_shift",
"(bool, default false) Set to true to enable scale and shift such as "
"batch_norm op")
.SetDefault(false);
AddInput("scale_w",
"scale_w is a 1-dimensional tensor of size C "
"that is applied to the output")
.AsDispensable();
AddInput("bias",
"bias is a 1-dimensional tensor of size C "
"that is applied to the output")
.AsDispensable();
AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
.SetDefault(false);
......@@ -194,7 +268,6 @@ class DataNormKernel<platform::CPUDeviceContext, T>
// alloc memory
T *y_data = y->mutable_data<T>(ctx.GetPlace());
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
ConstEigenVectorArrayMap<T> b_size_arr(
ctx.Input<Tensor>("BatchSize")->data<T>(), C);
ConstEigenVectorArrayMap<T> b_sum_arr(
......@@ -210,6 +283,7 @@ class DataNormKernel<platform::CPUDeviceContext, T>
const T *means_data = mean_out->data<T>();
const T *x_data = x->data<T>();
const T *scales_data = scales->data<T>();
const int slot_dim = ctx.Attr<int>("slot_dim");
T min_precision = 1e-7f;
......@@ -218,7 +292,8 @@ class DataNormKernel<platform::CPUDeviceContext, T>
case DataLayout::kNHWC: {
// if slot_dim is set and batch size is larger than zero, we choose
// to check if show number is zero, if so, skip normalization.
if (slot_dim > 0 && N > 0) {
if (slot_dim > 0 && N > 0 &&
(!ctx.Attr<bool>("enable_scale_and_shift"))) {
const int item_size = x->numel() / N;
// location of show number in one embedding
int offset = 0;
......@@ -239,10 +314,56 @@ class DataNormKernel<platform::CPUDeviceContext, T>
offset += item_size;
}
} else {
EigenArrayMap<T>(y_data, C, N) =
(ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() - means_arr)
.colwise() *
scales_arr;
if (!ctx.Attr<bool>("enable_scale_and_shift") && slot_dim <= 0) {
EigenArrayMap<T>(y_data, C, N) =
(ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() -
means_arr)
.colwise() *
scales_arr;
} else if (ctx.Attr<bool>("enable_scale_and_shift") &&
slot_dim <= 0) {
const auto *scale_w = ctx.Input<Tensor>("scale_w");
const auto *bias = ctx.Input<Tensor>("bias");
ConstEigenVectorArrayMap<T> scale_w_arr(scale_w->data<T>(), C);
ConstEigenVectorArrayMap<T> bias_arr(bias->data<T>(), C);
Eigen::Array<T, Eigen::Dynamic, 1> new_scale =
scales_arr * scale_w_arr;
Eigen::Array<T, Eigen::Dynamic, 1> new_bias =
bias_arr - means_arr * scales_arr * scale_w_arr;
EigenArrayMap<T>(y_data, C, N) =
(ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() *
new_scale)
.colwise() +
new_bias;
} else {
const int item_size = x->numel() / N;
const auto *scale_w = ctx.Input<Tensor>("scale_w");
const auto *bias = ctx.Input<Tensor>("bias");
const T *scale_w_data = scale_w->data<T>();
const T *bias_data = bias->data<T>();
// location of show number in one embedding
int offset = 0;
for (int k = 0; k < N; ++k) {
for (int i = 0; i < item_size; i += slot_dim) {
if (x_data[offset + i] > -min_precision &&
x_data[offset + i] < min_precision) {
// show = 0
memset(y_data + offset + i, 0, sizeof(T) * slot_dim);
} else {
for (int j = i; j < i + slot_dim; ++j) {
y_data[offset + j] = ((x_data[offset + j] - means_data[j]) *
scales_data[j]) *
scale_w_data[j] +
bias_data[j];
}
}
} // end for i
offset += item_size;
} // end for k
}
}
break;
}
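For reference, the enable_scale_and_shift branch above folds scale_w and bias into the existing normalization via new_scale and new_bias. A minimal NumPy sketch of the same computation (function and variable names are illustrative, not part of the op):

import numpy as np

def data_norm_forward_scale_shift(x, means, scales, scale_w, bias):
    # x: (N, C); means, scales, scale_w, bias: 1-D arrays of length C.
    # Mirrors the new_scale / new_bias folding in the kernel above.
    new_scale = scales * scale_w
    new_bias = bias - means * scales * scale_w
    # Equivalent to (x - means) * scales * scale_w + bias.
    return x * new_scale + new_bias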
......@@ -274,7 +395,8 @@ class DataNormGradOp : public framework::OperatorWithKernel {
"Output(BatchSquareSum) of DataNormGradOp should not be null."));
PADDLE_ENFORCE(ctx->HasInput("Means"), "");
PADDLE_ENFORCE(ctx->HasInput("Scales"), "");
bool enable_scale_and_shift =
ctx->Attrs().Get<bool>("enable_scale_and_shift");
// check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), "");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), "");
......@@ -294,6 +416,22 @@ class DataNormGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("BatchSize"), {C});
ctx->SetOutputDim(framework::GradVarName("BatchSum"), {C});
ctx->SetOutputDim(framework::GradVarName("BatchSquareSum"), {C});
if (enable_scale_and_shift) {
const bool has_scale_grad =
ctx->HasOutput(framework::GradVarName("scale_w"));
const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("bias"));
PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
platform::errors::InvalidArgument(
"Output(Scale@GRAD) and Output(Bias@GRAD)"
"must be null or not be null at same time. "
"But now, has Scale@Grad=[%d], has Bias@GRAD=[%d]",
has_scale_grad, has_bias_grad));
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("scale_w"), {C});
ctx->SetOutputDim(framework::GradVarName("bias"), {C});
}
}
}
protected:
......@@ -353,18 +491,23 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
// init output
Tensor *d_x = nullptr;
if (ctx.HasOutput(framework::GradVarName("X"))) {
d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
}
auto *d_batch_size =
ctx.Output<Tensor>(framework::GradVarName("BatchSize"));
auto *d_batch_sum = ctx.Output<Tensor>(framework::GradVarName("BatchSum"));
auto *d_batch_square_sum =
ctx.Output<Tensor>(framework::GradVarName("BatchSquareSum"));
const T *mean_data = means->data<T>();
const T *inv_var_data = scales->data<T>();
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
T *d_batch_size_data = d_batch_size->mutable_data<T>(ctx.GetPlace());
T *d_batch_sum_data = d_batch_sum->mutable_data<T>(ctx.GetPlace());
T *d_batch_square_sum_data =
......@@ -372,7 +515,6 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
EigenVectorArrayMap<T> d_batch_size_arr(d_batch_size_data, C);
EigenVectorArrayMap<T> d_batch_sum_arr(d_batch_sum_data, C);
EigenVectorArrayMap<T> d_batch_square_sum_arr(d_batch_square_sum_data, C);
d_batch_size_arr.setZero();
d_batch_sum_arr.setZero();
d_batch_square_sum_arr.setZero();
......@@ -392,8 +534,86 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
if (d_x != nullptr) {
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C, N);
d_x_arr.setZero();
for (int nc = 0; nc < N; ++nc) {
d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr;
if (!ctx.Attr<bool>("enable_scale_and_shift")) {
for (int nc = 0; nc < N; ++nc) {
d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr;
}
} else {
const auto *scale_w = ctx.Input<Tensor>("scale_w");
auto *d_scale =
ctx.Output<Tensor>(framework::GradVarName("scale_w"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("bias"));
ConstEigenVectorArrayMap<T> scale_arr(scale_w->data<T>(), C);
T *d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
T *d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> d_bias_arr(d_bias_data, C);
EigenVectorArrayMap<T> d_scale_arr(d_scale_data, C);
Tensor dy_sum;
dy_sum.Resize({C});
dy_sum.mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> dy_sum_arr(
dy_sum.mutable_data<T>(ctx.GetPlace()), C);
Tensor dy_mul_x_sub_mean_mul_invstd_sum;
dy_mul_x_sub_mean_mul_invstd_sum.Resize({C});
dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> dy_mul_x_sub_mean_mul_invstd_sum_arr(
dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(
ctx.GetPlace()),
C);
dy_sum_arr.setZero();
dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero();
if (slot_dim <= 0) {
for (int n = 0; n < N; ++n) {
dy_sum_arr += d_y_arr.col(n);
dy_mul_x_sub_mean_mul_invstd_sum_arr +=
((x_arr.col(n) - mean_arr) * inv_var_arr * d_y_arr.col(n));
}
if (d_scale && d_bias) {
d_bias_arr = dy_sum_arr;
d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
}
for (int nc = 0; nc < N; ++nc) {
d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr * scale_arr;
}
} else {
int offset = 0;
const int item_size = x->numel() / N;
T *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
T *d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
T *d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
const T *dy_data = d_y->data<T>();
const T *scales_data = scales->data<T>();
const T *scale_w_data = scale_w->data<T>();
const T *x_data = x->data<T>();
for (int i = 0; i < item_size; i++) {
d_bias_data[i] = 0;
d_scale_data[i] = 0;
}
for (int k = 0; k < N; ++k) {
for (int i = 0; i < item_size; i += slot_dim) {
if (!(x_data[offset + i] > -min_precision &&
x_data[offset + i] < min_precision)) {
// show != 0
for (int j = i; j < i + slot_dim; ++j) {
d_x_data[offset + j] = dy_data[offset + j] *
scales_data[j] * scale_w_data[j];
d_bias_data[j] += dy_data[offset + j];
d_scale_data[j] += (x_data[offset + j] - mean_data[j]) *
inv_var_data[j] * dy_data[offset + j];
}
}
}
offset += item_size;
}
}
}
}
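For reference, a minimal NumPy sketch of the gradients computed in the slot_dim <= 0 branch above when enable_scale_and_shift is on (illustrative only; the per-slot show filtering used when slot_dim > 0 is omitted):

import numpy as np

def data_norm_backward_scale_shift(d_y, x, means, scales, scale_w):
    # d_y, x: (N, C); means, scales (used as inv_std above), scale_w: length C.
    d_bias = d_y.sum(axis=0)                              # dy_sum
    d_scale_w = ((x - means) * scales * d_y).sum(axis=0)  # dy * (x - mean) * inv_std
    d_x = d_y * scales * scale_w
    return d_x, d_scale_w, d_bias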
......@@ -466,6 +686,8 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetInput("scale_w", this->Input("scale_w"));
op->SetInput("bias", this->Input("bias"));
op->SetOutput("BatchSize", this->Input("BatchSize"));
op->SetOutput("BatchSum", this->Input("BatchSum"));
op->SetOutput("BatchSquareSum", this->Input("BatchSquareSum"));
......@@ -481,6 +703,9 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("BatchSum"));
op->SetOutput(framework::GradVarName("BatchSquareSum"),
this->InputGrad("BatchSquareSum"));
op->SetOutput(framework::GradVarName("scale_w"),
this->InputGrad("scale_w"));
op->SetOutput(framework::GradVarName("bias"), this->InputGrad("bias"));
}
};
......
......@@ -3157,7 +3157,8 @@ def data_norm(input,
do_model_average_for_mean_and_var=True,
slot_dim=-1,
sync_stats=False,
summary_decay_rate=0.9999999):
summary_decay_rate=0.9999999,
enable_scale_and_shift=False):
"""
**Data Normalization Layer**
......@@ -3206,6 +3207,7 @@ def data_norm(input,
sync_stats(bool, Default False): When running with multiple GPU cards, using allreduce to sync the
summary messages.
summary_decay_rate(float, Default 0.9999999): The decay rate when updating summary.
enable_scale_and_shift(bool, Default False): do scale&shift after normalization.
Returns:
Variable: A tensor variable which is the result after applying data normalization on the input.
......@@ -3236,12 +3238,35 @@ def data_norm(input,
batch_size_default = 1e4
batch_sum_default = 0.0
batch_square_sum_default = 1e4
scale_w_default = 1.0
bias_default = 0.0
if param_attr and isinstance(param_attr, dict):
batch_size_default = param_attr.get("batch_size", 1e4)
batch_sum_default = param_attr.get("batch_sum", 0.0)
batch_square_sum_default = param_attr.get("batch_square", 1e4)
if enable_scale_and_shift:
scale_w_default = param_attr.get("scale_w", 1.0)
bias_default = param_attr.get("bias", 0.0)
# create scale and shift(bias) when enable_scale_and_shift is True
if name is None:
name = "dn"
if enable_scale_and_shift:
scale_w = helper.create_parameter(
attr=ParamAttr(
name=name + '.scale_w',
initializer=Constant(value=float(scale_w_default)),
trainable=True),
shape=param_shape,
dtype=input.dtype)
bias = helper.create_parameter(
attr=ParamAttr(
name=name + '.bias',
initializer=Constant(value=float(bias_default)),
trainable=True),
shape=param_shape,
dtype=input.dtype)
# create parameter
batch_size = helper.create_parameter(
attr=ParamAttr(
......@@ -3272,14 +3297,18 @@ def data_norm(input,
data_norm_out = input if in_place else helper.create_variable(dtype=dtype)
inputs = {
"X": input,
"BatchSize": batch_size,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum
}
if enable_scale_and_shift:
inputs["scale_w"] = scale_w
inputs["bias"] = bias
helper.append_op(
type="data_norm",
inputs={
"X": input,
"BatchSize": batch_size,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum
},
inputs=inputs,
outputs={
"Y": data_norm_out,
"Means": means,
......@@ -3292,7 +3321,8 @@ def data_norm(input,
"epsilon": epsilon,
"slot_dim": slot_dim,
"sync_stats": sync_stats,
"summary_decay_rate": summary_decay_rate
"summary_decay_rate": summary_decay_rate,
"enable_scale_and_shift": enable_scale_and_shift
})
return helper.append_activation(data_norm_out)
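A minimal usage sketch of the updated layer (variable names are illustrative; based on the signature and parameter creation shown above):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[3], dtype='float32')
# With enable_scale_and_shift=True the layer also creates the 1-D parameters
# "<name>.scale_w" and "<name>.bias" of size C; initial values may be supplied
# via the param_attr dict keys "scale_w" and "bias".
y = fluid.layers.data_norm(
    input=x, slot_dim=-1, enable_scale_and_shift=True)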
......
......@@ -24,6 +24,7 @@ import paddle.fluid.layers as layers
import os
from op_test import OpTest
from paddle.fluid.framework import grad_var_name
from paddle.fluid import Program, program_guard
def _reference_testing(x, batch_size, batch_sum, batch_square_sum, slot_dim=-1):
......@@ -72,7 +73,13 @@ class TestDataNormOpInference(unittest.TestCase):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def check_with_place(self, place, data_layout, dtype, shape, slot_dim=-1):
def check_with_place(self,
place,
data_layout,
dtype,
shape,
slot_dim=-1,
enable_scale_and_shift=False):
"""
do forward and check
......@@ -82,7 +89,7 @@ class TestDataNormOpInference(unittest.TestCase):
dtype(dtype): np.float32
shape(list): input shape
slot_dim(int): dimension of one slot. Refer to data_norm api.
enable_scale_and_shift(bool): whether to enable scale and shift after normalization.
"""
epsilon = 0.00001
......@@ -127,21 +134,49 @@ class TestDataNormOpInference(unittest.TestCase):
mean_tensor = create_or_get_tensor(scope, "mean", None, place)
scales_tensor = create_or_get_tensor(scope, "scales", None, place)
data_norm_op = Operator(
"data_norm",
# inputs
X="x_val",
BatchSize="batch_size",
BatchSum="batch_sum",
BatchSquareSum="batch_square_sum",
# outputs
Y="y_out",
Means="mean",
Scales="scales",
# attrs
epsilon=epsilon,
use_mkldnn=self.use_mkldnn,
slot_dim=slot_dim)
if not enable_scale_and_shift:
data_norm_op = Operator(
"data_norm",
# inputs
X="x_val",
BatchSize="batch_size",
BatchSum="batch_sum",
BatchSquareSum="batch_square_sum",
# outputs
Y="y_out",
Means="mean",
Scales="scales",
# attrs
epsilon=epsilon,
use_mkldnn=self.use_mkldnn,
slot_dim=slot_dim,
enable_scale_and_shift=False)
else:
scale_w = np.ones(scale_shape).astype(np.float32)
bias = np.zeros(scale_shape).astype(np.float32)
scale_w_tensor = create_or_get_tensor(
scope, "scale_w",
OpTest.np_dtype_to_fluid_dtype(scale_w), place)
bias_tensor = create_or_get_tensor(
scope, "bias", OpTest.np_dtype_to_fluid_dtype(bias), place)
data_norm_op = Operator(
"data_norm",
# inputs
X="x_val",
BatchSize="batch_size",
BatchSum="batch_sum",
BatchSquareSum="batch_square_sum",
scale_w="scale_w",
bias="bias",
# outputs
Y="y_out",
Means="mean",
Scales="scales",
# attrs
epsilon=epsilon,
use_mkldnn=self.use_mkldnn,
slot_dim=slot_dim,
enable_scale_and_shift=True)
data_norm_op.run(scope, place)
......@@ -162,11 +197,13 @@ class TestDataNormOpInference(unittest.TestCase):
for place in places:
for data_format in ["NCHW", "NHWC"]:
for slot_dim in [-1, 1]:
self.check_with_place(
place,
data_format,
self.dtype, [2, 3],
slot_dim=slot_dim)
for enable_scale_and_shift in [False, True]:
self.check_with_place(
place,
data_format,
self.dtype, [2, 3],
slot_dim=slot_dim,
enable_scale_and_shift=enable_scale_and_shift)
class TestDataNormOp(OpTest):
......@@ -220,6 +257,130 @@ class TestDataNormOp(OpTest):
self.check_grad(['X'], 'Y', no_grad_set=set([]))
class TestDataNormOpWithEnableScaleAndShift(OpTest):
"""
test class for data norm op
test forward and backward
"""
def setUp(self):
"""
init data norm op test env
"""
self.op_type = 'data_norm'
self.use_mkldnn = False
epsilon = 0.00001
slot_dim = -1
enable_scale_and_shift = True
x_shape = [2, 50]
scale_shape = [50]
tp = np.float32
x_val = np.random.uniform(-1, 1, x_shape).astype(tp)
batch_size = np.ones(scale_shape).astype(tp)
batch_size *= 1e4
batch_sum = np.zeros(scale_shape).astype(tp)
batch_square_sum = np.ones(scale_shape).astype(tp)
batch_square_sum *= 1e4
scale_w = np.ones(scale_shape).astype(tp)
bias = np.zeros(scale_shape).astype(tp)
y = np.array(x_val)
mean = np.zeros(x_shape).astype(tp)
scale = np.ones(x_shape).astype(tp)
self.inputs = {
"X": x_val,
"BatchSize": batch_size,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum,
"scale_w": scale_w,
"bias": bias
}
self.outputs = {"Y": y, "Means": mean, "Scales": scale}
self.attrs = {
"epsilon": epsilon,
"use_mkldnn": self.use_mkldnn,
"slot_dim": slot_dim,
"enable_scale_and_shift": True
}
def test_check_output(self):
"""
test check forward, check output
"""
self.check_output()
def test_check_grad(self):
"""
test check backward, check grad
"""
self.check_grad(['X'], 'Y', no_grad_set=set([]))
class TestDataNormOpWithEnableScaleAndShift_1(OpTest):
"""
test class for data norm op
test forward and backward
"""
def setUp(self):
"""
init data norm op test env
"""
self.op_type = 'data_norm'
self.use_mkldnn = False
epsilon = 0.00001
slot_dim = 1
enable_scale_and_shift = True
x_shape = [2, 50]
scale_shape = [50]
tp = np.float32
x_val = np.random.uniform(-1, 1, x_shape).astype(tp)
batch_size = np.ones(scale_shape).astype(tp)
batch_size *= 1e4
batch_sum = np.zeros(scale_shape).astype(tp)
batch_square_sum = np.ones(scale_shape).astype(tp)
batch_square_sum *= 1e4
scale_w = np.ones(scale_shape).astype(tp)
bias = np.zeros(scale_shape).astype(tp)
y = np.array(x_val)
mean = np.zeros(x_shape).astype(tp)
scale = np.ones(x_shape).astype(tp)
self.inputs = {
"X": x_val,
"BatchSize": batch_size,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum,
"scale_w": scale_w,
"bias": bias
}
self.outputs = {"Y": y, "Means": mean, "Scales": scale}
self.attrs = {
"epsilon": epsilon,
"use_mkldnn": self.use_mkldnn,
"slot_dim": slot_dim,
"enable_scale_and_shift": True
}
def test_check_output(self):
"""
test check forward, check output
"""
self.check_output()
def test_check_grad(self):
"""
test check backward, check grad
"""
self.check_grad(['X'], 'Y', no_grad_set=set([]))
class TestDataNormOpWithSlotDim(OpTest):
"""
test class for data norm op
......@@ -399,5 +560,14 @@ class TestDataNormOpWithSyncStats(unittest.TestCase):
os.remove(f)
class TestDataNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x2 = fluid.layers.data(name='x2', shape=[3, 4], dtype="int32")
#self.assertRaises(TypeError, fluid.data_norm, x2)
fluid.layers.data_norm(
input=x2, param_attr={}, enable_scale_and_shift=True)
if __name__ == '__main__':
unittest.main()