Unverified · Commit b5d8ba83 · authored by yaoxuefeng, committed by GitHub

fix data_norm op to avoid impractical normalization result test=develop (#21152)

* fix auc drop first commit test=develop

* update datanorm op

* update datanorm with enforce test=develop

* update test=develop

* update format test=develop

* update format

* update format test=develop

* add unit test test=develop

* update unit test test=develop

* update format test=develop

* update format test=develop

* update API description test=develop

* update API description test=develop

* update format test=develop

* fix codes as comments test=develop

* fix description as comments test=develop

* fix description as comments test=develop

* update codes.. test=develop
Parent 67e88424
@@ -125,6 +125,10 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
"'epsilon' should be between 0.0 and 0.001.");
});
AddAttr<int>("slot_dim",
"(int, default -1) Dimension of one slot if set, "
"when the input is concated by slot-wise embeddings")
.SetDefault(-1);
AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
@@ -182,7 +186,7 @@ class DataNormKernel<platform::CPUDeviceContext, T>
auto *scales = ctx.Output<Tensor>("Scales");
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
T *y_data = y->mutable_data<T>(ctx.GetPlace());
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
ConstEigenVectorArrayMap<T> b_size_arr(
@@ -198,14 +202,42 @@ class DataNormKernel<platform::CPUDeviceContext, T>
means_arr = b_sum_arr / b_size_arr;
scales_arr = (b_size_arr / b_square_sum_arr).sqrt();
const T *means_data = mean_out->data<T>();
const T *x_data = x->data<T>();
const T *scales_data = scales->data<T>();
const int slot_dim = ctx.Attr<int>("slot_dim");
T min_precision = 1e-7f;
switch (data_layout) {
case DataLayout::kNCHW: // because it's two dimensions, so make no
// difference
case DataLayout::kNCHW: // It's two dimensions, so make no difference
case DataLayout::kNHWC: {
EigenArrayMap<T>(y->mutable_data<T>(ctx.GetPlace()), C, N) =
(ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() - means_arr)
.colwise() *
scales_arr;
// If slot_dim is set and the batch size is larger than zero, check whether
// the show number is zero; if so, skip normalization.
if (slot_dim > 0 && N > 0) {
const int item_size = x->numel() / N;
// location of show number in one embedding
int offset = 0;
for (int k = 0; k < N; ++k) {
for (int i = 0; i < item_size; i += slot_dim) {
if (x_data[offset + i] > -min_precision &&
x_data[offset + i] < min_precision) {
// show = 0
memset(y_data + offset + i, 0, sizeof(T) * slot_dim);
} else {
for (int j = i; j < i + slot_dim; ++j) {
y_data[offset + j] =
(x_data[offset + j] - means_data[j]) * scales_data[j];
}
}
}
offset += item_size;
}
} else {
EigenArrayMap<T>(y_data, C, N) =
(ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() - means_arr)
.colwise() *
scales_arr;
}
break;
}
default:
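For context, here is a minimal NumPy illustration (the statistics below are made up, not taken from this patch) of the problem the slot_dim branch above fixes: a slot whose show number is zero is typically an all-zero embedding, yet the unconditional path still maps it to (0 - mean) * scale, which yields a non-zero, impractical result; the new branch keeps the output at zero.

```python
import numpy as np

# Hypothetical running statistics for a single 3-wide slot whose first
# element is the show number (values are illustrative only).
means = np.array([5.0, 0.2, -0.1])   # batch_sum / batch_size
scales = np.array([0.5, 2.0, 2.0])   # sqrt(batch_size / batch_square_sum)

empty_slot = np.zeros(3)             # show number == 0: slot never seen yet

old_path = (empty_slot - means) * scales  # unconditional normalization
new_path = np.zeros(3)                    # slot_dim branch: skip and emit zeros

print(old_path)  # [-2.5 -0.4  0.2] -- non-zero "normalized" output for an empty slot
print(new_path)  # [ 0.   0.   0. ]
```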
@@ -321,20 +353,24 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
auto *d_batch_square_sum =
ctx.Output<Tensor>(framework::GradVarName("BatchSquareSum"));
EigenVectorArrayMap<T> d_batch_size_arr(
d_batch_size->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> d_batch_sum_arr(
d_batch_sum->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> d_batch_square_sum_arr(
d_batch_square_sum->mutable_data<T>(ctx.GetPlace()), C);
T *d_batch_size_data = d_batch_size->mutable_data<T>(ctx.GetPlace());
T *d_batch_sum_data = d_batch_sum->mutable_data<T>(ctx.GetPlace());
T *d_batch_square_sum_data =
d_batch_square_sum->mutable_data<T>(ctx.GetPlace());
EigenVectorArrayMap<T> d_batch_size_arr(d_batch_size_data, C);
EigenVectorArrayMap<T> d_batch_sum_arr(d_batch_sum_data, C);
EigenVectorArrayMap<T> d_batch_square_sum_arr(d_batch_square_sum_data, C);
d_batch_size_arr.setZero();
d_batch_sum_arr.setZero();
d_batch_square_sum_arr.setZero();
const T *x_data = x->data<T>();
const T *means_data = means->data<T>();
const float epsilon = ctx.Attr<float>("epsilon");
switch (
data_layout) { // because it's two dimensions, so make no difference
T min_precision = 1e-7f;
const int slot_dim = ctx.Attr<int>("slot_dim");
switch (data_layout) { // it's two dimensions, make no difference
case DataLayout::kNCHW:
case DataLayout::kNHWC: {
ConstEigenVectorArrayMap<T> scales_arr(scales->data<T>(), C);
@@ -349,24 +385,60 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
}
}
// calculate data sum and square sum
ConstEigenVectorArrayMap<T> batch_size_arr(batch_size->data<T>(), C);
ConstEigenVectorArrayMap<T> batch_sum_arr(batch_sum->data<T>(), C);
ConstEigenVectorArrayMap<T> batch_square_sum_arr(
batch_square_sum->data<T>(), C);
Eigen::Array<T, Eigen::Dynamic, 1> sample_sum(C);
Eigen::Array<T, Eigen::Dynamic, 1> sample_square_sum(C);
// calculate data sample sum and square sum
sample_sum.setZero();
sample_square_sum.setZero();
for (int nc = 0; nc < N; ++nc) {
sample_sum += x_arr.col(nc);
sample_square_sum += (x_arr.col(nc) - means_arr).square();
if (slot_dim > 0 && N > 0) {
// If slot_dim is set and the batch size is larger than zero, check whether
// the show number is zero; if so, skip updating the statistics.
int offset = 0;
const int item_size = x->numel() / N;
for (int k = 0; k < N; ++k) {
for (int i = 0; i < item_size; i += slot_dim) {
if (!(x_data[offset + i] > -min_precision &&
x_data[offset + i] < min_precision)) {
// show != 0
for (int j = i; j < i + slot_dim; ++j) {
d_batch_size_data[j] += 1;
d_batch_sum_data[j] += x_data[offset + j];
d_batch_square_sum_data[j] +=
(x_data[offset + j] - means_data[j]) *
(x_data[offset + j] - means_data[j]);
}
}
}
offset += item_size;
}
for (int i = 0; i < item_size; i += slot_dim) {
for (int j = i; j < i + slot_dim; ++j) {
if (d_batch_size_data[j] >= 1) {
d_batch_sum_data[j] /= d_batch_size_data[j];
d_batch_square_sum_data[j] =
d_batch_square_sum_data[j] / d_batch_size_data[j] +
d_batch_size_data[j] * epsilon;
d_batch_size_data[j] = 1;
}
}
}
} else {
// calculate data sum and square sum
ConstEigenVectorArrayMap<T> batch_size_arr(batch_size->data<T>(), C);
ConstEigenVectorArrayMap<T> batch_sum_arr(batch_sum->data<T>(), C);
ConstEigenVectorArrayMap<T> batch_square_sum_arr(
batch_square_sum->data<T>(), C);
Eigen::Array<T, Eigen::Dynamic, 1> sample_sum(C);
Eigen::Array<T, Eigen::Dynamic, 1> sample_square_sum(C);
// calculate data sample sum and square sum
sample_sum.setZero();
sample_square_sum.setZero();
for (int nc = 0; nc < N; ++nc) {
sample_sum += x_arr.col(nc);
sample_square_sum += (x_arr.col(nc) - means_arr).square();
}
// calculate gradient
d_batch_size_arr.setConstant(N);
d_batch_sum_arr = sample_sum;
d_batch_square_sum_arr =
sample_square_sum + d_batch_size_arr * epsilon;
}
// calculate gradient
d_batch_size_arr.setConstant(N);
d_batch_sum_arr = sample_sum;
d_batch_square_sum_arr = sample_square_sum + d_batch_size_arr * epsilon;
break;
}
default:
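To make the gradient-side logic above easier to follow, here is a hedged NumPy sketch (not the Paddle kernel itself; the function name is illustrative) of how the batch statistics are accumulated when slot_dim > 0: slots whose show number is zero are excluded, and each counted feature is then renormalized so the subsequent update treats every non-empty slot as contributing one averaged sample.

```python
import numpy as np

def batch_stat_grads_with_slot_dim(x, means, slot_dim,
                                   epsilon=1e-5, min_precision=1e-7):
    """Mimics the slot_dim > 0 branch of DataNormGradKernel (illustrative only)."""
    n, item_size = x.shape
    d_size = np.zeros(item_size, dtype=x.dtype)
    d_sum = np.zeros(item_size, dtype=x.dtype)
    d_square_sum = np.zeros(item_size, dtype=x.dtype)
    for k in range(n):
        for i in range(0, item_size, slot_dim):
            if abs(x[k, i]) >= min_precision:        # show != 0: count this slot
                sl = slice(i, i + slot_dim)
                d_size[sl] += 1
                d_sum[sl] += x[k, sl]
                d_square_sum[sl] += (x[k, sl] - means[sl]) ** 2
    counted = d_size >= 1                            # features seen at least once
    d_sum[counted] /= d_size[counted]
    d_square_sum[counted] = (d_square_sum[counted] / d_size[counted]
                             + d_size[counted] * epsilon)
    d_size[counted] = 1
    return d_size, d_sum, d_square_sum
```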
@@ -2703,7 +2703,8 @@ def data_norm(input,
name=None,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=True):
do_model_average_for_mean_and_var=True,
slot_dim=-1):
"""
**Data Normalization Layer**
@@ -2742,6 +2743,13 @@
moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance.
do_model_average_for_mean_and_var(bool, Default True): Whether parameter mean and variance
should do model average when model average is enabled.
slot_dim(int, Default -1): The embedding dimension of one slot. A slot is a set of one specific feature. In pslib
mode, we distinguish feature ids by slot and pull their embeddings from the parameter server (pslib). The first
element of each embedding is the historical show number (the occurrence count of this feature id with label 0).
If the input of this op is concatenated from slot-wise embeddings and the show number is zero because the slot
is new or empty, the normalization result may be impractical. To avoid this, slot_dim is used to locate the show
number and check whether it is zero; if so, normalization is skipped for that embedding.
Returns:
Variable: A tensor variable which is the result after applying data normalization on the input.
@@ -2819,7 +2827,8 @@ def data_norm(input,
outputs={"Y": data_norm_out,
"Means": means,
"Scales": scales},
attrs={"epsilon": epsilon})
attrs={"epsilon": epsilon,
"slot_dim": slot_dim})
return helper.append_activation(data_norm_out)
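A minimal usage sketch of the new argument, assuming the fluid 1.x API shown in this diff; the slot width of 11, the slot count, and the layer name are purely illustrative.

```python
import paddle.fluid as fluid

# Input is a concatenation of 3 slot-wise embeddings, each 11 wide; by
# convention the first element of each slot embedding is the show number.
emb = fluid.layers.data(name="slot_embeddings", shape=[3 * 11], dtype="float32")

# With slot_dim set, slots whose show number is zero are left un-normalized
# (their output stays zero) instead of producing an impractical result.
out = fluid.layers.data_norm(input=emb, slot_dim=11)
```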
@@ -24,14 +24,24 @@ from op_test import OpTest
from paddle.fluid.framework import grad_var_name
def _reference_testing(x, batch_size, batch_sum, batch_square_sum):
def _reference_testing(x, batch_size, batch_sum, batch_square_sum, slot_dim=-1):
x_shape = x.shape
means_arr = batch_sum / batch_size
scales_arr = np.sqrt(batch_size / batch_square_sum)
for i in range(x_shape[0]):
x[i] -= means_arr
x[i] *= scales_arr
y = np.array(x)
min_precision = 1e-7
if slot_dim <= 0:
for i in range(x_shape[0]):
x[i] -= means_arr
x[i] *= scales_arr
y = np.array(x)
else:
y = np.zeros(x_shape).astype(np.float32)
for i in range(x_shape[0]):
for j in range(0, x_shape[1], slot_dim):
if x[i][j] <= -min_precision or x[i][j] >= min_precision:
for k in range(0, slot_dim):
y[i][j + k] = (
x[i][j + k] - means_arr[j + k]) * scales_arr[j + k]
return y
@@ -60,7 +70,7 @@ class TestDataNormOpInference(unittest.TestCase):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def check_with_place(self, place, data_layout, dtype, shape):
def check_with_place(self, place, data_layout, dtype, shape, slot_dim=-1):
"""
do forward and check
@@ -69,6 +79,8 @@ class TestDataNormOpInference(unittest.TestCase):
data_layout(str): NCHW or NHWC
dtype(dtype): np.float32
shape(list): input shape
slot_dim(int): dimension of one slot. Refer to the data_norm API.
"""
epsilon = 0.00001
@@ -81,6 +93,8 @@ class TestDataNormOpInference(unittest.TestCase):
x_val = np.random.random_sample(x_shape).astype(dtype)
x_val = x_val - 0.5
x_val[0][1] = 0.0
x_val[1][1] = 0.0
batch_size = np.ones(scale_shape).astype(np.float32)
batch_size *= 1e4
batch_sum = np.zeros(scale_shape).astype(np.float32)
@@ -88,7 +102,7 @@ class TestDataNormOpInference(unittest.TestCase):
batch_square_sum *= 1e4
y_out = _reference_testing(x_val, batch_size, batch_sum,
batch_square_sum).astype(dtype)
batch_square_sum, slot_dim).astype(dtype)
scope = core.Scope()
@@ -124,7 +138,8 @@ class TestDataNormOpInference(unittest.TestCase):
Scales="scales",
# attrs
epsilon=epsilon,
use_mkldnn=self.use_mkldnn)
use_mkldnn=self.use_mkldnn,
slot_dim=slot_dim)
data_norm_op.run(scope, place)
@@ -144,7 +159,12 @@ class TestDataNormOpInference(unittest.TestCase):
places = [core.CPUPlace()]
for place in places:
for data_format in ["NCHW", "NHWC"]:
self.check_with_place(place, data_format, self.dtype, [2, 3])
for slot_dim in [-1, 1]:
self.check_with_place(
place,
data_format,
self.dtype, [2, 3],
slot_dim=slot_dim)
class TestDataNormOp(OpTest):
@@ -199,5 +219,62 @@ class TestDataNormOp(OpTest):
self.check_grad(['X'], 'Y', no_grad_set=set([]))
class TestDataNormOpWithSlotDim(OpTest):
"""
test class for data norm op
test forward and backward
"""
def setUp(self):
"""
init data norm op test env
"""
self.op_type = 'data_norm'
self.use_mkldnn = False
epsilon = 0.00001
slot_dim = 1
x_shape = [2, 3]
scale_shape = [3]
tp = np.float32
x_val = np.array([[-0.35702616, 0.0, -0.08306625],
[0.41199666, 0.0, -0.10180971]]).astype(tp)
batch_size = np.ones(scale_shape).astype(tp)
batch_size *= 1e4
batch_sum = np.zeros(scale_shape).astype(tp)
batch_square_sum = np.ones(scale_shape).astype(tp)
batch_square_sum *= 1e4
y = np.array(x_val)
mean = np.array([[0, 0, 0], [0, 0, 0]]).astype(tp)
scale = np.array([[1, 1, 1], [1, 1, 1]]).astype(tp)
self.inputs = {
"X": x_val,
"BatchSize": batch_size,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum
}
self.outputs = {"Y": y, "Means": mean, "Scales": scale}
self.attrs = {
"epsilon": epsilon,
"use_mkldnn": self.use_mkldnn,
"slot_dim": slot_dim
}
def test_check_output(self):
"""
test check forward, check output
"""
self.check_output()
def test_check_grad(self):
"""
test check backward, check grad
"""
self.check_grad(['X'], 'Y', no_grad_set=set([]))
if __name__ == '__main__':
unittest.main()