diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 76fe6f53299a6b21c83ec72a8d5382c851914fd6..095f9270273232f155f45f56d69912741b0f99b9 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -129,7 +129,13 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker { "(int, default -1) Dimension of one slot if set, " "when the input is concated by slot-wise embeddings") .SetDefault(-1); + AddAttr( + "summary_decay_rate", + "(float, default 0.9999999) The decay rate when update the summary") + .SetDefault(0.9999999); AddAttr("data_layout", "").SetDefault("NCHW"); + AddAttr("sync_stats", "(bool, default false) only used in multi-GPU") + .SetDefault(false); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); @@ -254,9 +260,18 @@ class DataNormGradOp : public framework::OperatorWithKernel { // check input PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), ""); - PADDLE_ENFORCE(ctx->HasInput("BatchSize"), ""); - PADDLE_ENFORCE(ctx->HasInput("BatchSum"), ""); - PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"), ""); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("BatchSize"), true, + platform::errors::NotFound( + "Output(BatchSize) of DataNormGradOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("BatchSum"), true, + platform::errors::NotFound( + "Output(BatchSum) of DataNormGradOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("BatchSquareSum"), true, + platform::errors::NotFound( + "Output(BatchSquareSum) of DataNormGradOp should not be null.")); PADDLE_ENFORCE(ctx->HasInput("Means"), ""); PADDLE_ENFORCE(ctx->HasInput("Scales"), ""); @@ -323,9 +338,6 @@ class DataNormGradKernel void Compute(const framework::ExecutionContext &ctx) const override { const auto *x = ctx.Input("X"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); - const auto *batch_size = ctx.Input("BatchSize"); - const auto *batch_sum = ctx.Input("BatchSum"); - const auto *batch_square_sum = ctx.Input("BatchSquareSum"); const auto *scales = ctx.Input("Scales"); const auto *means = ctx.Input("Means"); @@ -420,10 +432,6 @@ class DataNormGradKernel } } else { // calculate data sum and squre sum - ConstEigenVectorArrayMap batch_size_arr(batch_size->data(), C); - ConstEigenVectorArrayMap batch_sum_arr(batch_sum->data(), C); - ConstEigenVectorArrayMap batch_square_sum_arr( - batch_square_sum->data(), C); Eigen::Array sample_sum(C); Eigen::Array sample_square_sum(C); // calculate data sample sum and square sum @@ -459,9 +467,9 @@ class DataNormGradMaker : public framework::SingleGradOpMaker { op->SetInput("X", this->Input("X")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); - op->SetInput("BatchSize", this->Input("BatchSize")); - op->SetInput("BatchSum", this->Input("BatchSum")); - op->SetInput("BatchSquareSum", this->Input("BatchSquareSum")); + op->SetOutput("BatchSize", this->Input("BatchSize")); + op->SetOutput("BatchSum", this->Input("BatchSum")); + op->SetOutput("BatchSquareSum", this->Input("BatchSquareSum")); op->SetInput("Scales", this->Output("Scales")); op->SetInput("Means", this->Output("Means")); diff --git a/paddle/fluid/operators/data_norm_op.cu b/paddle/fluid/operators/data_norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8d433b99ca60b7a52b0abc919329a2ac93978b8e --- /dev/null +++ b/paddle/fluid/operators/data_norm_op.cu @@ -0,0 +1,219 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/operators/data_norm_op.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DataLayout = framework::DataLayout; +using platform::PADDLE_CUDA_NUM_THREADS; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +inline int GET_BLOCKS(const int N) { + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; +} + +template +__global__ void KernelDataNormFF(int N, int C, const T *x, T *y, const T *mean, + const T *scale) { + CUDA_KERNEL_LOOP(i, N * C) { + int col = i % C; + y[i] = (x[i] - mean[col]) * scale[col]; + } +} + +template +__global__ void KernelMeanScale(int C, const T *batch_size, const T *batch_sum, + const T *batch_square_sum, T *mean, T *scale) { + CUDA_KERNEL_LOOP(i, C) { + mean[i] = batch_sum[i] / batch_size[i]; + scale[i] = sqrt(batch_size[i] / batch_square_sum[i]); + } +} + +template +__global__ void KernelDataNormBP(int N, int C, const T *y_grad, const T *scale, + T *x_grad) { + CUDA_KERNEL_LOOP(i, N * C) { x_grad[i] = y_grad[i] * scale[i % C]; } +} + +template +__global__ void KernelDataNormBPStat(int N, int C, const T *x_val, + const T *means, + const float squared_sum_epsilon, + T *batch_size, T *batch_sum, + T *batch_square_sum) { + CUDA_KERNEL_LOOP(i, C) { + T val_sum = 0; + T square_sum = 0; + for (int j = 0; j < N; j++) { + val_sum += x_val[j * C + i]; + square_sum += + (x_val[j * C + i] - means[i]) * (x_val[j * C + i] - means[i]); + } + batch_size[i] = 1; + batch_sum[i] = val_sum / N; + batch_square_sum[i] = square_sum / N + squared_sum_epsilon; + } +} + +template +__global__ void KernelUpdateParam(int C, const T *d_batch_size, + const T *d_batch_sum, + const T *d_batch_square_sum, T *batch_size, + T *batch_sum, T *batch_square_sum, + const float decay_rate) { + CUDA_KERNEL_LOOP(i, C) { + batch_size[i] = batch_size[i] * decay_rate + d_batch_size[i]; + batch_sum[i] = batch_sum[i] * decay_rate + d_batch_sum[i]; + batch_square_sum[i] = + batch_square_sum[i] * decay_rate + d_batch_square_sum[i]; + } +} + +template +class DataNormKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + const auto &x_dims = x->dims(); + // Align with CPU version, but should we add this restriction? + PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::PreconditionNotMet( + "The Input dim size should be 2")); + const int N = x_dims[0]; + const int C = x_dims[1]; + const T *batch_size_in = ctx.Input("BatchSize")->data(); + const T *batch_sum_in = ctx.Input("BatchSum")->data(); + const T *batch_square_sum_in = + ctx.Input("BatchSquareSum")->data(); + auto *x_data = x->data(); + + // alloc memory + T *y_data = ctx.Output("Y")->mutable_data(ctx.GetPlace()); + T *mean_out_data = + ctx.Output("Means")->mutable_data(ctx.GetPlace()); + T *scale_out_data = + ctx.Output("Scales")->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context().stream(); + + KernelMeanScale<<>>( + C, batch_size_in, batch_sum_in, batch_square_sum_in, mean_out_data, + scale_out_data); + KernelDataNormFF<<>>( + N, C, x_data, y_data, mean_out_data, scale_out_data); + } +}; + +template +class DataNormGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scales = ctx.Input("Scales"); + const auto *means = ctx.Input("Means"); + const float epsilon = ctx.Attr("epsilon"); + const float dr = ctx.Attr("summary_decay_rate"); + const bool need_sync_stats = ctx.Attr("sync_stats"); + + const auto &x_dims = x->dims(); + // Align with CPU version, but should we add this restriction? + PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::PreconditionNotMet( + "The Input dim size should be 2")); + const int N = x_dims[0]; + const int C = x_dims[1]; + + // init output + Tensor *d_x = nullptr; + if (ctx.HasOutput(framework::GradVarName("X"))) { + d_x = ctx.Output(framework::GradVarName("X")); + } + T *d_batch_size = ctx.Output(framework::GradVarName("BatchSize")) + ->mutable_data(ctx.GetPlace()); + T *d_batch_sum = ctx.Output(framework::GradVarName("BatchSum")) + ->mutable_data(ctx.GetPlace()); + T *d_batch_square_sum = + ctx.Output(framework::GradVarName("BatchSquareSum")) + ->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context().stream(); + if (d_x != nullptr) { + KernelDataNormBP<<>>(N, C, d_y->data(), scales->data(), + d_x->mutable_data(ctx.GetPlace())); + } + + KernelDataNormBPStat<<>>( + N, C, x->data(), means->data(), epsilon, d_batch_size, + d_batch_sum, d_batch_square_sum); + + if (need_sync_stats) { + auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_size), + reinterpret_cast(d_batch_size), C, + platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_sum), + reinterpret_cast(d_batch_sum), C, + platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_square_sum), + reinterpret_cast(d_batch_square_sum), C, + platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream)); + cudaError_t e_sync = cudaStreamSynchronize(stream); + if (e_sync != 0) { + LOG(FATAL) << "Fail to sync nccl stream: " + << cudaGetErrorString(e_sync); + } + } + T *batch_size_data = + ctx.Output("BatchSize")->mutable_data(ctx.GetPlace()); + T *batch_sum_data = + ctx.Output("BatchSum")->mutable_data(ctx.GetPlace()); + T *batch_square_sum_data = + ctx.Output("BatchSquareSum")->mutable_data(ctx.GetPlace()); + KernelUpdateParam<<>>( + C, d_batch_size, d_batch_sum, d_batch_square_sum, batch_size_data, + batch_sum_data, batch_square_sum_data, dr); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + data_norm, ops::DataNormKernel, + ops::DataNormKernel); +REGISTER_OP_CUDA_KERNEL( + data_norm_grad, + ops::DataNormGradKernel, + ops::DataNormGradKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 55176448c058e4d2db770187ec445f409af83247..93ca24f7c318de891aad9a1ae02ca6a681f84683 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2704,7 +2704,9 @@ def data_norm(input, moving_mean_name=None, moving_variance_name=None, do_model_average_for_mean_and_var=True, - slot_dim=-1): + slot_dim=-1, + sync_stats=False, + summary_decay_rate=0.9999999): """ **Data Normalization Layer** @@ -2750,6 +2752,9 @@ def data_norm(input, is new or empty, the normalization result may be impractical. To avoid this, we add slot_dim to locate the show number and judge if the show number is zero. If so, we choose to skip normalization on this embedding. + sync_stats(bool, Default False): When running with multiple GPU cards, using allreduce to sync the + summary messages. + summary_decay_rate(float, Default 0.9999999): The decay rate when updating summary. Returns: Variable: A tensor variable which is the result after applying data normalization on the input. @@ -2824,11 +2829,20 @@ def data_norm(input, "BatchSum": batch_sum, "BatchSquareSum": batch_square_sum }, - outputs={"Y": data_norm_out, - "Means": means, - "Scales": scales}, - attrs={"epsilon": epsilon, - "slot_dim": slot_dim}) + outputs={ + "Y": data_norm_out, + "Means": means, + "Scales": scales, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum + }, + attrs={ + "epsilon": epsilon, + "slot_dim": slot_dim, + "sync_stats": sync_stats, + "summary_decay_rate": summary_decay_rate + }) return helper.append_activation(data_norm_out) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 97dc5eb08cf446656c6d57f691d62fc465e0dffc..9c80c4fc2b9f0afd3c923572e3058a57d569066d 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -147,6 +147,7 @@ endfunction() list(REMOVE_ITEM TEST_OPS test_warpctc_op) list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf) +list(REMOVE_ITEM TEST_OPS test_data_norm_op) list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf_auto_growth) list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed) list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer) @@ -279,6 +280,7 @@ py_test_modules(test_layers MODULES test_layers ENVS FLAGS_cudnn_deterministic=1 py_test_modules(test_parallel_executor_seresnext_base_cpu MODULES test_parallel_executor_seresnext_base_cpu) py_test_modules(test_parallel_executor_seresnext_with_reduce_cpu MODULES test_parallel_executor_seresnext_with_reduce_cpu) py_test_modules(test_parallel_executor_seresnext_with_fuse_all_reduce_cpu MODULES test_parallel_executor_seresnext_with_fuse_all_reduce_cpu) +py_test_modules(test_data_norm_op MODULES test_data_norm_op) if(NOT WIN32) py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer) @@ -309,4 +311,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu test_parallel_executor_crf test_sync_batch_norm_op test_parallel_executor_feed_persistable_var test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass + test_data_norm_op test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST") diff --git a/python/paddle/fluid/tests/unittests/test_data_norm_op.py b/python/paddle/fluid/tests/unittests/test_data_norm_op.py index b11da680b9c9f5720c38f1066078c29c2be71821..a3f1ece4f43b3a3eb0c32b3d2d5967f6c717574a 100644 --- a/python/paddle/fluid/tests/unittests/test_data_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_data_norm_op.py @@ -20,6 +20,8 @@ import numpy as np import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid +import paddle.fluid.layers as layers +import os from op_test import OpTest from paddle.fluid.framework import grad_var_name @@ -216,7 +218,7 @@ class TestDataNormOp(OpTest): """ test check backward, check grad """ - self.check_grad(['X'], 'Y', no_grad_set=set([])) + self.check_grad(['X'], 'Y', no_grad_set=set([]), check_dygraph=False) class TestDataNormOpWithSlotDim(OpTest): @@ -273,7 +275,125 @@ class TestDataNormOpWithSlotDim(OpTest): """ test check backward, check grad """ - self.check_grad(['X'], 'Y', no_grad_set=set([])) + self.check_grad(['X'], 'Y', no_grad_set=set([]), check_dygraph=False) + + +class TestDataNormOpWithSyncStats(OpTest): + """ + test class for data norm op + test forward and backward + """ + + def test_sync_stats(self): + if not core.is_compiled_with_cuda(): + return + x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0) + emb = layers.embedding( + input=x, + param_attr=fluid.ParamAttr(name="embx"), + size=[10, 2], + is_sparse=False) + + dn = layers.data_norm( + input=emb, + name="hehe", + epsilon=1e-4, + param_attr={ + "batch_size": 1e4, + "batch_sum": 1e5, + "batch_square": 1e4 + }, + summary_decay_rate=1, + sync_stats=True) #[-1,3] + loss = layers.mean(dn) + + optimizer = fluid.optimizer.SGD(learning_rate=0.5) + optimizer = fluid.optimizer.PipelineOptimizer( + optimizer, + cut_list=[[emb], [loss]], + place_list=[ + fluid.CUDAPlace(0), fluid.CUDAPlace(0), fluid.CPUPlace() + ], + concurrency_list=[1, 1, 1], + queue_size=1, + sync_steps=10000000, ) + + all_p = fluid.default_main_program().global_block().all_parameters() + parameter_without_datanorm = [] + for e in all_p: + if e.name.find("batch_size") != -1 or e.name.find( + "batch_sq") != -1 or e.name.find("batch_sum") != -1: + continue + parameter_without_datanorm.append(e.name) + optimizer.minimize(loss, parameter_list=parameter_without_datanorm) + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + #prepare data + batch_size = 1 + + def binary_print(slot, fout): + num = np.int16(len(slot) + 1) + num.tofile(fout) + a = np.int64(batch_size) + a.tofile(fout) + slot.tofile(fout) + + #batch1 = np.array([[0,1], [1,2], [2,3]]).astype("int64").reshape(batch_size,2,1) + #batch2 = np.array([[1,2], [2,3], [3,4]]).astype("int64").reshape(batch_size,2,1) + batch1 = np.ones( + (batch_size, 1)).astype("int64").reshape(batch_size, 1, 1) + batch2 = np.ones( + (batch_size, 1)).astype("int64").reshape(batch_size, 1, 1) + data = [batch1, batch2] + data = [batch1] + filelist = [] + for i in range(2): + filelist.append("test_pipeline_input_" + str(i)) + for f in filelist: + with open(f, "wb") as fout: + for batch_data in data: + for ins in batch_data: + for slot in ins: + binary_print(slot, fout) + + dataset = fluid.DatasetFactory().create_dataset("FileInstantDataset") + dataset.set_use_var([x]) + dataset.set_batch_size(batch_size) + dataset.set_filelist(filelist) + + block = fluid.default_startup_program().global_block() + block.append_op( + type='c_comm_init_all', attrs={'ring_id': 0, + 'devices': [0, 1]}) + with open("main_program", "w") as fout: + fout.write(str(fluid.default_main_program())) + with open("startup_program", "w") as fout: + fout.write(str(fluid.default_startup_program())) + exe.run(fluid.default_startup_program()) + emb_t = fluid.global_scope().find_var("embx").get_tensor() + para = np.ones((10, 2)).astype("float32") + emb_t.set(para, place) + for epoch in range(1): + exe.train_from_dataset( + fluid.default_main_program(), + dataset, + thread=2, + debug=False, + fetch_list=[], + fetch_info=[], + print_period=1) + batch_size = np.array(fluid.global_scope().find_var("hehe.batch_size") + .get_tensor()) + self.assertEqual(batch_size[0], 10002) + b = np.array(fluid.global_scope().find_var("hehe.batch_sum").get_tensor( + )) + self.assertEqual(b[0], 100002) + c = np.array(fluid.global_scope().find_var("hehe.batch_square_sum") + .get_tensor()) + self.assertEqual(c[0], 10162) + + for f in filelist: + os.remove(f) if __name__ == '__main__':