diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc
index dc116ec69a2ebd956b6828a8e97a3d960700bb95..76fe6f53299a6b21c83ec72a8d5382c851914fd6 100644
--- a/paddle/fluid/operators/data_norm_op.cc
+++ b/paddle/fluid/operators/data_norm_op.cc
@@ -125,6 +125,10 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
           PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
                          "'epsilon' should be between 0.0 and 0.001.");
         });
+    AddAttr<int>("slot_dim",
+                 "(int, default -1) Dimension of one slot if set, "
+                 "when the input is concatenated from slot-wise embeddings")
+        .SetDefault(-1);
     AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
     AddAttr<bool>("use_mkldnn",
                   "(bool, default false) Only used in mkldnn kernel")
@@ -182,7 +186,7 @@ class DataNormKernel
     auto *scales = ctx.Output<Tensor>("Scales");
 
     // alloc memory
-    y->mutable_data<T>(ctx.GetPlace());
+    T *y_data = y->mutable_data<T>(ctx.GetPlace());
 
     Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
     ConstEigenVectorArrayMap<T> b_size_arr(
@@ -198,14 +202,42 @@ class DataNormKernel
     means_arr = b_sum_arr / b_size_arr;
     scales_arr = (b_size_arr / b_square_sum_arr).sqrt();
 
+    const T *means_data = mean_out->data<T>();
+    const T *x_data = x->data<T>();
+    const T *scales_data = scales->data<T>();
+    const int slot_dim = ctx.Attr<int>("slot_dim");
+    T min_precision = 1e-7f;
     switch (data_layout) {
-      case DataLayout::kNCHW:  // because it's two dimensions, so make no
-                               // difference
+      case DataLayout::kNCHW:  // it's two dimensions, so it makes no difference
      case DataLayout::kNHWC: {
-        EigenArrayMap<T>(y->mutable_data<T>(ctx.GetPlace()), C, N) =
-            (ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() - means_arr)
-                .colwise() *
-            scales_arr;
+        // If slot_dim is set and the batch size is larger than zero, check
+        // whether the show number is zero; if so, skip normalization.
+        if (slot_dim > 0 && N > 0) {
+          const int item_size = x->numel() / N;
+          // location of the show number in one embedding
+          int offset = 0;
+          for (int k = 0; k < N; ++k) {
+            for (int i = 0; i < item_size; i += slot_dim) {
+              if (x_data[offset + i] > -min_precision &&
+                  x_data[offset + i] < min_precision) {
+                // show == 0
+                memset(y_data + offset + i, 0, sizeof(T) * slot_dim);
+              } else {
+                for (int j = i; j < i + slot_dim; ++j) {
+                  y_data[offset + j] =
+                      (x_data[offset + j] - means_data[j]) * scales_data[j];
+                }
+              }
+            }
+
+            offset += item_size;
+          }
+        } else {
+          EigenArrayMap<T>(y_data, C, N) =
+              (ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() - means_arr)
+                  .colwise() *
+              scales_arr;
+        }
        break;
      }
      default:
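
To make the slot-skipping behaviour of the forward kernel concrete, here is a small worked example in NumPy. It is illustrative only: the 2x6 input, the slot_dim of 3, and the statistics are made up, and data_norm_forward_ref is a hypothetical helper that mirrors the kernel above, not part of the patch.

import numpy as np

def data_norm_forward_ref(x, batch_size, batch_sum, batch_square_sum, slot_dim):
    # Mirrors the CPU kernel above: slots whose show number (the first element
    # of each slot_dim-wide chunk) is ~0 produce zeros instead of normalized
    # values.
    means = batch_sum / batch_size
    scales = np.sqrt(batch_size / batch_square_sum)
    y = np.zeros_like(x)
    for row in range(x.shape[0]):
        for start in range(0, x.shape[1], slot_dim):
            if abs(x[row, start]) >= 1e-7:  # show != 0
                sl = slice(start, start + slot_dim)
                y[row, sl] = (x[row, sl] - means[sl]) * scales[sl]
    return y

# Two rows, two slots of width 3; the second slot of row 0 has show == 0.
x = np.array([[1.0, 0.2, -0.3, 0.0, 0.5, 0.7],
              [2.0, -0.1, 0.4, 3.0, 0.1, -0.2]], dtype=np.float32)
stats = np.ones(6, dtype=np.float32) * 1e4   # BatchSize and BatchSquareSum
y = data_norm_forward_ref(x, stats, np.zeros(6, np.float32), stats, slot_dim=3)
# y[0, 3:6] is all zeros; every other slot is (x - mean) * scale.
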
@@ -321,20 +353,24 @@ class DataNormGradKernel
     auto *d_batch_square_sum =
         ctx.Output<Tensor>(framework::GradVarName("BatchSquareSum"));
 
-    EigenVectorArrayMap<T> d_batch_size_arr(
-        d_batch_size->mutable_data<T>(ctx.GetPlace()), C);
-    EigenVectorArrayMap<T> d_batch_sum_arr(
-        d_batch_sum->mutable_data<T>(ctx.GetPlace()), C);
-    EigenVectorArrayMap<T> d_batch_square_sum_arr(
-        d_batch_square_sum->mutable_data<T>(ctx.GetPlace()), C);
+    T *d_batch_size_data = d_batch_size->mutable_data<T>(ctx.GetPlace());
+    T *d_batch_sum_data = d_batch_sum->mutable_data<T>(ctx.GetPlace());
+    T *d_batch_square_sum_data =
+        d_batch_square_sum->mutable_data<T>(ctx.GetPlace());
+    EigenVectorArrayMap<T> d_batch_size_arr(d_batch_size_data, C);
+    EigenVectorArrayMap<T> d_batch_sum_arr(d_batch_sum_data, C);
+    EigenVectorArrayMap<T> d_batch_square_sum_arr(d_batch_square_sum_data, C);
     d_batch_size_arr.setZero();
     d_batch_sum_arr.setZero();
     d_batch_square_sum_arr.setZero();
 
+    const T *x_data = x->data<T>();
+    const T *means_data = means->data<T>();
     const float epsilon = ctx.Attr<float>("epsilon");
-    switch (
-        data_layout) {  // because it's two dimensions, so make no difference
+    T min_precision = 1e-7f;
+    const int slot_dim = ctx.Attr<int>("slot_dim");
+    switch (data_layout) {  // it's two dimensions, so it makes no difference
      case DataLayout::kNCHW:
      case DataLayout::kNHWC: {
        ConstEigenVectorArrayMap<T> scales_arr(scales->data<T>(), C);
@@ -349,24 +385,60 @@ class DataNormGradKernel
          }
        }
 
-        // calculate data sum and squre sum
-        ConstEigenVectorArrayMap<T> batch_size_arr(batch_size->data<T>(), C);
-        ConstEigenVectorArrayMap<T> batch_sum_arr(batch_sum->data<T>(), C);
-        ConstEigenVectorArrayMap<T> batch_square_sum_arr(
-            batch_square_sum->data<T>(), C);
-        Eigen::Array<T, Eigen::Dynamic, 1> sample_sum(C);
-        Eigen::Array<T, Eigen::Dynamic, 1> sample_square_sum(C);
-        // calculate data sample sum and square sum
-        sample_sum.setZero();
-        sample_square_sum.setZero();
-        for (int nc = 0; nc < N; ++nc) {
-          sample_sum += x_arr.col(nc);
-          sample_square_sum += (x_arr.col(nc) - means_arr).square();
+        if (slot_dim > 0 && N > 0) {
+          // If slot_dim is set and the batch size is larger than zero, check
+          // whether the show number is zero; if so, skip the update.
+          int offset = 0;
+          const int item_size = x->numel() / N;
+          for (int k = 0; k < N; ++k) {
+            for (int i = 0; i < item_size; i += slot_dim) {
+              if (!(x_data[offset + i] > -min_precision &&
+                    x_data[offset + i] < min_precision)) {
+                // show != 0
+                for (int j = i; j < i + slot_dim; ++j) {
+                  d_batch_size_data[j] += 1;
+                  d_batch_sum_data[j] += x_data[offset + j];
+                  d_batch_square_sum_data[j] +=
+                      (x_data[offset + j] - means_data[j]) *
+                      (x_data[offset + j] - means_data[j]);
+                }
+              }
+            }
+            offset += item_size;
+          }
+
+          for (int i = 0; i < item_size; i += slot_dim) {
+            for (int j = i; j < i + slot_dim; ++j) {
+              if (d_batch_size_data[j] >= 1) {
+                d_batch_sum_data[j] /= d_batch_size_data[j];
+                d_batch_square_sum_data[j] =
+                    d_batch_square_sum_data[j] / d_batch_size_data[j] +
+                    d_batch_size_data[j] * epsilon;
+                d_batch_size_data[j] = 1;
+              }
+            }
+          }
+        } else {
+          // calculate data sum and square sum
+          ConstEigenVectorArrayMap<T> batch_size_arr(batch_size->data<T>(), C);
+          ConstEigenVectorArrayMap<T> batch_sum_arr(batch_sum->data<T>(), C);
+          ConstEigenVectorArrayMap<T> batch_square_sum_arr(
+              batch_square_sum->data<T>(), C);
+          Eigen::Array<T, Eigen::Dynamic, 1> sample_sum(C);
+          Eigen::Array<T, Eigen::Dynamic, 1> sample_square_sum(C);
+          // calculate data sample sum and square sum
+          sample_sum.setZero();
+          sample_square_sum.setZero();
+          for (int nc = 0; nc < N; ++nc) {
+            sample_sum += x_arr.col(nc);
+            sample_square_sum += (x_arr.col(nc) - means_arr).square();
+          }
+          // calculate gradient
+          d_batch_size_arr.setConstant(N);
+          d_batch_sum_arr = sample_sum;
+          d_batch_square_sum_arr =
+              sample_square_sum + d_batch_size_arr * epsilon;
        }
-        // calculate gradient
-        d_batch_size_arr.setConstant(N);
-        d_batch_sum_arr = sample_sum;
-        d_batch_square_sum_arr = sample_square_sum + d_batch_size_arr * epsilon;
        break;
      }
      default:
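
The per-slot statistics update in the gradient kernel can be restated in NumPy as follows. This is a sketch under the same assumptions as the forward example above; update_stats_ref is a hypothetical name, not part of the patch.

import numpy as np

def update_stats_ref(x, means, slot_dim, epsilon):
    # Only rows whose show number (the first element of each slot) is non-zero
    # contribute to BatchSize/BatchSum/BatchSquareSum; slots that received at
    # least one contribution are then rescaled so that BatchSize becomes 1.
    n, item_size = x.shape
    d_size = np.zeros(item_size, dtype=x.dtype)
    d_sum = np.zeros(item_size, dtype=x.dtype)
    d_square_sum = np.zeros(item_size, dtype=x.dtype)
    for row in range(n):
        for start in range(0, item_size, slot_dim):
            if abs(x[row, start]) < 1e-7:  # show == 0: slot is skipped
                continue
            sl = slice(start, start + slot_dim)
            d_size[sl] += 1
            d_sum[sl] += x[row, sl]
            d_square_sum[sl] += (x[row, sl] - means[sl]) ** 2
    seen = d_size >= 1
    d_sum[seen] /= d_size[seen]
    d_square_sum[seen] = (d_square_sum[seen] / d_size[seen] +
                          d_size[seen] * epsilon)
    d_size[seen] = 1
    return d_size, d_sum, d_square_sum
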
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 5757a9af113a3ba391be15a3588420900bdeafcb..2e47c6b2651a322cdee1801787311887ee8e65a9 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -2703,7 +2703,8 @@ def data_norm(input,
               name=None,
               moving_mean_name=None,
               moving_variance_name=None,
-              do_model_average_for_mean_and_var=True):
+              do_model_average_for_mean_and_var=True,
+              slot_dim=-1):
     """
     **Data Normalization Layer**
 
@@ -2742,6 +2743,13 @@ def data_norm(input,
         moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance.
         do_model_average_for_mean_and_var(bool, Default True): Whether parameter mean and variance
             should do model average when model average is enabled.
+        slot_dim(int): The embedding dimension of one slot. A slot is a set of ids of one specific feature. In pslib
+            mode, we distinguish feature ids by slot and pull their embeddings from the parameter server (pslib). The
+            first place of each embedding is the historical show number (the number of occurrences of this feature id
+            with label 0). If the input of this op is concatenated from slot-wise embeddings, the show number is zero
+            when a slot is new or empty, and the normalization result may be impractical. To avoid this, we add
+            slot_dim to locate the show number and check whether it is zero. If so, normalization is skipped for this
+            embedding.
 
     Returns:
         Variable: A tensor variable which is the result after applying data normalization on the input.
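
A minimal usage sketch of the new slot_dim argument (illustrative only; the three-slot, 9-dimensional input layout below is hypothetical):

import paddle.fluid as fluid

slot_dim = 9
# Input built by concatenating three slot-wise embeddings; the first element
# of each slot_dim-wide chunk is assumed to hold the historical show number.
x = fluid.layers.data(name='x', shape=[3 * slot_dim], dtype='float32')
# Slots whose show number is zero come out as zeros instead of being
# normalized with the global statistics.
y = fluid.layers.data_norm(input=x, slot_dim=slot_dim)
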
@@ -2819,7 +2827,8 @@ def data_norm(input,
         outputs={"Y": data_norm_out,
                  "Means": means,
                  "Scales": scales},
-        attrs={"epsilon": epsilon})
+        attrs={"epsilon": epsilon,
+               "slot_dim": slot_dim})
 
     return helper.append_activation(data_norm_out)
 
diff --git a/python/paddle/fluid/tests/unittests/test_data_norm_op.py b/python/paddle/fluid/tests/unittests/test_data_norm_op.py
index 0273664d5d793c67ffee709ed3ab5d265879c997..b11da680b9c9f5720c38f1066078c29c2be71821 100644
--- a/python/paddle/fluid/tests/unittests/test_data_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_data_norm_op.py
@@ -24,14 +24,24 @@ from op_test import OpTest
 from paddle.fluid.framework import grad_var_name
 
 
-def _reference_testing(x, batch_size, batch_sum, batch_square_sum):
+def _reference_testing(x, batch_size, batch_sum, batch_square_sum, slot_dim=-1):
     x_shape = x.shape
     means_arr = batch_sum / batch_size
     scales_arr = np.sqrt(batch_size / batch_square_sum)
-    for i in range(x_shape[0]):
-        x[i] -= means_arr
-        x[i] *= scales_arr
-    y = np.array(x)
+    min_precision = 1e-7
+    if slot_dim <= 0:
+        for i in range(x_shape[0]):
+            x[i] -= means_arr
+            x[i] *= scales_arr
+        y = np.array(x)
+    else:
+        y = np.zeros(x_shape).astype(np.float32)
+        for i in range(x_shape[0]):
+            for j in range(0, x_shape[1], slot_dim):
+                if x[i][j] <= -min_precision or x[i][j] >= min_precision:
+                    for k in range(0, slot_dim):
+                        y[i][j + k] = (
+                            x[i][j + k] - means_arr[j + k]) * scales_arr[j + k]
     return y
 
 
@@ -60,7 +70,7 @@ class TestDataNormOpInference(unittest.TestCase):
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def check_with_place(self, place, data_layout, dtype, shape):
+    def check_with_place(self, place, data_layout, dtype, shape, slot_dim=-1):
         """
         do forward and check
 
@@ -69,6 +79,8 @@ class TestDataNormOpInference(unittest.TestCase):
             data_layout(str): NCHW or NWHC
             dtype(dtype): np.float32
             shape(list): input shape
+            slot_dim(int): dimension of one slot. Refer to data_norm api.
+ """ epsilon = 0.00001 @@ -81,6 +93,8 @@ class TestDataNormOpInference(unittest.TestCase): x_val = np.random.random_sample(x_shape).astype(dtype) x_val = x_val - 0.5 + x_val[0][1] = 0.0 + x_val[1][1] = 0.0 batch_size = np.ones(scale_shape).astype(np.float32) batch_size *= 1e4 batch_sum = np.zeros(scale_shape).astype(np.float32) @@ -88,7 +102,7 @@ class TestDataNormOpInference(unittest.TestCase): batch_square_sum *= 1e4 y_out = _reference_testing(x_val, batch_size, batch_sum, - batch_square_sum).astype(dtype) + batch_square_sum, slot_dim).astype(dtype) scope = core.Scope() @@ -124,7 +138,8 @@ class TestDataNormOpInference(unittest.TestCase): Scales="scales", # attrs epsilon=epsilon, - use_mkldnn=self.use_mkldnn) + use_mkldnn=self.use_mkldnn, + slot_dim=slot_dim) data_norm_op.run(scope, place) @@ -144,7 +159,12 @@ class TestDataNormOpInference(unittest.TestCase): places = [core.CPUPlace()] for place in places: for data_format in ["NCHW", "NHWC"]: - self.check_with_place(place, data_format, self.dtype, [2, 3]) + for slot_dim in [-1, 1]: + self.check_with_place( + place, + data_format, + self.dtype, [2, 3], + slot_dim=slot_dim) class TestDataNormOp(OpTest): @@ -199,5 +219,62 @@ class TestDataNormOp(OpTest): self.check_grad(['X'], 'Y', no_grad_set=set([])) +class TestDataNormOpWithSlotDim(OpTest): + """ + test class for data norm op + test forward and backward + """ + + def setUp(self): + """ + init data norm op test env + """ + self.op_type = 'data_norm' + self.use_mkldnn = False + epsilon = 0.00001 + slot_dim = 1 + x_shape = [2, 3] + scale_shape = [3] + tp = np.float32 + + x_val = np.array([[-0.35702616, 0.0, -0.08306625], + [0.41199666, 0.0, -0.10180971]]).astype(tp) + batch_size = np.ones(scale_shape).astype(tp) + batch_size *= 1e4 + batch_sum = np.zeros(scale_shape).astype(tp) + batch_square_sum = np.ones(scale_shape).astype(tp) + batch_square_sum *= 1e4 + + y = np.array(x_val) + + mean = np.array([[0, 0, 0], [0, 0, 0]]).astype(tp) + scale = np.array([[1, 1, 1], [1, 1, 1]]).astype(tp) + + self.inputs = { + "X": x_val, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum + } + self.outputs = {"Y": y, "Means": mean, "Scales": scale} + self.attrs = { + "epsilon": epsilon, + "use_mkldnn": self.use_mkldnn, + "slot_dim": slot_dim + } + + def test_check_output(self): + """ + test check forward, check output + """ + self.check_output() + + def test_check_grad(self): + """ + test check backward, check grad + """ + self.check_grad(['X'], 'Y', no_grad_set=set([])) + + if __name__ == '__main__': unittest.main()