Unverified · Commit 47a82e38 authored by hutuxian, committed by GitHub

Support data_norm GPU kernel (#21325)

* support running data_norm_op on CUDA
* add two parameters sync_stats & summary_decay_rate
* add UT
Parent: d5ff79e5
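
A minimal usage sketch of the two new attributes from the Python side (illustrative only; the variable names and the surrounding program/executor setup are assumptions, not part of this commit):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[32], dtype='float32')
# sync_stats enables an allreduce of the batch statistics across GPU cards;
# summary_decay_rate controls how fast the running summary is decayed.
out = fluid.layers.data_norm(
    input=x,
    epsilon=1e-4,
    sync_stats=True,
    summary_decay_rate=0.9999999)
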
......@@ -129,7 +129,13 @@ class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
"(int, default -1) Dimension of one slot if set, "
"when the input is concated by slot-wise embeddings")
.SetDefault(-1);
AddAttr<float>(
"summary_decay_rate",
"(float, default 0.9999999) The decay rate when update the summary")
.SetDefault(0.9999999);
AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
AddAttr<bool>("sync_stats", "(bool, default false) only used in multi-GPU")
.SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
......@@ -254,9 +260,18 @@ class DataNormGradOp : public framework::OperatorWithKernel {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
PADDLE_ENFORCE(ctx->HasInput("BatchSize"), "");
PADDLE_ENFORCE(ctx->HasInput("BatchSum"), "");
PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"), "");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("BatchSize"), true,
platform::errors::NotFound(
"Output(BatchSize) of DataNormGradOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("BatchSum"), true,
platform::errors::NotFound(
"Output(BatchSum) of DataNormGradOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("BatchSquareSum"), true,
platform::errors::NotFound(
"Output(BatchSquareSum) of DataNormGradOp should not be null."));
PADDLE_ENFORCE(ctx->HasInput("Means"), "");
PADDLE_ENFORCE(ctx->HasInput("Scales"), "");
......@@ -323,9 +338,6 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *batch_size = ctx.Input<Tensor>("BatchSize");
const auto *batch_sum = ctx.Input<Tensor>("BatchSum");
const auto *batch_square_sum = ctx.Input<Tensor>("BatchSquareSum");
const auto *scales = ctx.Input<Tensor>("Scales");
const auto *means = ctx.Input<Tensor>("Means");
......@@ -420,10 +432,6 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
}
} else {
// calculate data sum and square sum
ConstEigenVectorArrayMap<T> batch_size_arr(batch_size->data<T>(), C);
ConstEigenVectorArrayMap<T> batch_sum_arr(batch_sum->data<T>(), C);
ConstEigenVectorArrayMap<T> batch_square_sum_arr(
batch_square_sum->data<T>(), C);
Eigen::Array<T, Eigen::Dynamic, 1> sample_sum(C);
Eigen::Array<T, Eigen::Dynamic, 1> sample_square_sum(C);
// calculate data sample sum and square sum
......@@ -459,9 +467,9 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetInput("BatchSize", this->Input("BatchSize"));
op->SetInput("BatchSum", this->Input("BatchSum"));
op->SetInput("BatchSquareSum", this->Input("BatchSquareSum"));
op->SetOutput("BatchSize", this->Input("BatchSize"));
op->SetOutput("BatchSum", this->Input("BatchSum"));
op->SetOutput("BatchSquareSum", this->Input("BatchSquareSum"));
op->SetInput("Scales", this->Output("Scales"));
op->SetInput("Means", this->Output("Means"));
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/data_norm_op.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
using platform::PADDLE_CUDA_NUM_THREADS;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
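// Number of blocks needed to cover N elements with PADDLE_CUDA_NUM_THREADS
// threads per block (ceiling division).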
inline int GET_BLOCKS(const int N) {
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
}
template <typename T>
__global__ void KernelDataNormFF(int N, int C, const T *x, T *y, const T *mean,
const T *scale) {
CUDA_KERNEL_LOOP(i, N * C) {
int col = i % C;
y[i] = (x[i] - mean[col]) * scale[col];
}
}
template <typename T>
__global__ void KernelMeanScale(int C, const T *batch_size, const T *batch_sum,
const T *batch_square_sum, T *mean, T *scale) {
CUDA_KERNEL_LOOP(i, C) {
mean[i] = batch_sum[i] / batch_size[i];
scale[i] = sqrt(batch_size[i] / batch_square_sum[i]);
}
}
template <typename T>
__global__ void KernelDataNormBP(int N, int C, const T *y_grad, const T *scale,
T *x_grad) {
CUDA_KERNEL_LOOP(i, N * C) { x_grad[i] = y_grad[i] * scale[i % C]; }
}
template <typename T>
__global__ void KernelDataNormBPStat(int N, int C, const T *x_val,
const T *means,
const float squared_sum_epsilon,
T *batch_size, T *batch_sum,
T *batch_square_sum) {
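  // Per-column batch statistics, normalized by the batch size N: each worker
  // reports a count of 1, the column mean of x, and the mean squared
  // deviation from the running mean plus epsilon. With sync_stats these
  // per-worker values are later summed across ranks via ncclAllReduce.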
CUDA_KERNEL_LOOP(i, C) {
T val_sum = 0;
T square_sum = 0;
for (int j = 0; j < N; j++) {
val_sum += x_val[j * C + i];
square_sum +=
(x_val[j * C + i] - means[i]) * (x_val[j * C + i] - means[i]);
}
batch_size[i] = 1;
batch_sum[i] = val_sum / N;
batch_square_sum[i] = square_sum / N + squared_sum_epsilon;
}
}
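// Decays the running summary by decay_rate and accumulates the freshly
// computed (and possibly all-reduced) batch statistics.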
template <typename T>
__global__ void KernelUpdateParam(int C, const T *d_batch_size,
const T *d_batch_sum,
const T *d_batch_square_sum, T *batch_size,
T *batch_sum, T *batch_square_sum,
const float decay_rate) {
CUDA_KERNEL_LOOP(i, C) {
batch_size[i] = batch_size[i] * decay_rate + d_batch_size[i];
batch_sum[i] = batch_sum[i] * decay_rate + d_batch_sum[i];
batch_square_sum[i] =
batch_square_sum[i] * decay_rate + d_batch_square_sum[i];
}
}
template <typename T>
class DataNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
// Align with CPU version, but should we add this restriction?
PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::PreconditionNotMet(
"The Input dim size should be 2"));
const int N = x_dims[0];
const int C = x_dims[1];
const T *batch_size_in = ctx.Input<Tensor>("BatchSize")->data<T>();
const T *batch_sum_in = ctx.Input<Tensor>("BatchSum")->data<T>();
const T *batch_square_sum_in =
ctx.Input<Tensor>("BatchSquareSum")->data<T>();
auto *x_data = x->data<T>();
// alloc memory
T *y_data = ctx.Output<Tensor>("Y")->mutable_data<T>(ctx.GetPlace());
T *mean_out_data =
ctx.Output<Tensor>("Means")->mutable_data<T>(ctx.GetPlace());
T *scale_out_data =
ctx.Output<Tensor>("Scales")->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
KernelMeanScale<<<GET_BLOCKS(C), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
C, batch_size_in, batch_sum_in, batch_square_sum_in, mean_out_data,
scale_out_data);
KernelDataNormFF<<<GET_BLOCKS(C * N), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
N, C, x_data, y_data, mean_out_data, scale_out_data);
}
};
template <typename T>
class DataNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scales = ctx.Input<Tensor>("Scales");
const auto *means = ctx.Input<Tensor>("Means");
const float epsilon = ctx.Attr<float>("epsilon");
const float dr = ctx.Attr<float>("summary_decay_rate");
const bool need_sync_stats = ctx.Attr<bool>("sync_stats");
const auto &x_dims = x->dims();
// Align with CPU version, but should we add this restriction?
PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::PreconditionNotMet(
"The Input dim size should be 2"));
const int N = x_dims[0];
const int C = x_dims[1];
// init output
Tensor *d_x = nullptr;
if (ctx.HasOutput(framework::GradVarName("X"))) {
d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
}
T *d_batch_size = ctx.Output<Tensor>(framework::GradVarName("BatchSize"))
->mutable_data<T>(ctx.GetPlace());
T *d_batch_sum = ctx.Output<Tensor>(framework::GradVarName("BatchSum"))
->mutable_data<T>(ctx.GetPlace());
T *d_batch_square_sum =
ctx.Output<Tensor>(framework::GradVarName("BatchSquareSum"))
->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
if (d_x != nullptr) {
KernelDataNormBP<<<GET_BLOCKS(C * N), PADDLE_CUDA_NUM_THREADS, 0,
stream>>>(N, C, d_y->data<T>(), scales->data<T>(),
d_x->mutable_data<T>(ctx.GetPlace()));
}
KernelDataNormBPStat<<<GET_BLOCKS(C), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
N, C, x->data<T>(), means->data<T>(), epsilon, d_batch_size,
d_batch_sum, d_batch_square_sum);
if (need_sync_stats) {
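      // Sum the per-card batch statistics across all ranks so that every
      // card applies the same summary update below.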
auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_size),
reinterpret_cast<void *>(d_batch_size), C,
platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_sum),
reinterpret_cast<void *>(d_batch_sum), C,
platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_square_sum),
reinterpret_cast<void *>(d_batch_square_sum), C,
platform::ToNCCLDataType(x->type()), ncclSum, comm->comm(), stream));
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != cudaSuccess) {
LOG(FATAL) << "Failed to sync nccl stream: "
<< cudaGetErrorString(e_sync);
}
}
T *batch_size_data =
ctx.Output<Tensor>("BatchSize")->mutable_data<T>(ctx.GetPlace());
T *batch_sum_data =
ctx.Output<Tensor>("BatchSum")->mutable_data<T>(ctx.GetPlace());
T *batch_square_sum_data =
ctx.Output<Tensor>("BatchSquareSum")->mutable_data<T>(ctx.GetPlace());
KernelUpdateParam<<<GET_BLOCKS(C), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
C, d_batch_size, d_batch_sum, d_batch_square_sum, batch_size_data,
batch_sum_data, batch_square_sum_data, dr);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
data_norm, ops::DataNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::DataNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
data_norm_grad,
ops::DataNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::DataNormGradKernel<paddle::platform::CUDADeviceContext, double>);
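
For reference, the arithmetic implemented by the CUDA kernels above can be summarized with a small NumPy sketch (illustrative only; the function and variable names are mine, not part of the operator's API, and the allreduce step is noted rather than performed):

import numpy as np

def data_norm_forward(x, batch_size, batch_sum, batch_square_sum):
    # per-column mean and scale, as in KernelMeanScale
    mean = batch_sum / batch_size
    scale = np.sqrt(batch_size / batch_square_sum)
    # normalized output, as in KernelDataNormFF
    return (x - mean) * scale, mean, scale

def data_norm_backward(x, d_y, mean, scale, epsilon, summaries, decay_rate):
    # gradient w.r.t. the input, as in KernelDataNormBP
    d_x = d_y * scale
    # per-worker batch statistics, as in KernelDataNormBPStat
    d_batch_size = np.ones(x.shape[1])
    d_batch_sum = x.mean(axis=0)
    d_batch_square_sum = ((x - mean) ** 2).mean(axis=0) + epsilon
    # (with sync_stats, the three d_batch_* arrays are all-reduced here)
    # decayed summary update, as in KernelUpdateParam
    batch_size, batch_sum, batch_square_sum = summaries
    batch_size = batch_size * decay_rate + d_batch_size
    batch_sum = batch_sum * decay_rate + d_batch_sum
    batch_square_sum = batch_square_sum * decay_rate + d_batch_square_sum
    return d_x, (batch_size, batch_sum, batch_square_sum)
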
......@@ -2704,7 +2704,9 @@ def data_norm(input,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=True,
slot_dim=-1):
slot_dim=-1,
sync_stats=False,
summary_decay_rate=0.9999999):
"""
**Data Normalization Layer**
......@@ -2750,6 +2752,9 @@ def data_norm(input,
is new or empty, the normalization result may be impractical. To avoid this, we add slot_dim to locate
the show number and judge if the show number is zero. If so, we choose to skip normalization on this
embedding.
sync_stats(bool, Default False): When running with multiple GPU cards, use allreduce to synchronize the
summary statistics across cards.
summary_decay_rate(float, Default 0.9999999): The decay rate used when updating the summary.
Returns:
Variable: A tensor variable which is the result after applying data normalization on the input.
......@@ -2824,11 +2829,20 @@ def data_norm(input,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum
},
outputs={"Y": data_norm_out,
"Means": means,
"Scales": scales},
attrs={"epsilon": epsilon,
"slot_dim": slot_dim})
outputs={
"Y": data_norm_out,
"Means": means,
"Scales": scales,
"BatchSize": batch_size,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum
},
attrs={
"epsilon": epsilon,
"slot_dim": slot_dim,
"sync_stats": sync_stats,
"summary_decay_rate": summary_decay_rate
})
return helper.append_activation(data_norm_out)
......
......@@ -147,6 +147,7 @@ endfunction()
list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_data_norm_op)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf_auto_growth)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
......@@ -279,6 +280,7 @@ py_test_modules(test_layers MODULES test_layers ENVS FLAGS_cudnn_deterministic=1
py_test_modules(test_parallel_executor_seresnext_base_cpu MODULES test_parallel_executor_seresnext_base_cpu)
py_test_modules(test_parallel_executor_seresnext_with_reduce_cpu MODULES test_parallel_executor_seresnext_with_reduce_cpu)
py_test_modules(test_parallel_executor_seresnext_with_fuse_all_reduce_cpu MODULES test_parallel_executor_seresnext_with_fuse_all_reduce_cpu)
py_test_modules(test_data_norm_op MODULES test_data_norm_op)
if(NOT WIN32)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer)
......@@ -309,4 +311,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_parallel_executor_crf test_sync_batch_norm_op
test_parallel_executor_feed_persistable_var
test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_data_norm_op
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
......@@ -20,6 +20,8 @@ import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import os
from op_test import OpTest
from paddle.fluid.framework import grad_var_name
......@@ -216,7 +218,7 @@ class TestDataNormOp(OpTest):
"""
test check backward, check grad
"""
self.check_grad(['X'], 'Y', no_grad_set=set([]))
self.check_grad(['X'], 'Y', no_grad_set=set([]), check_dygraph=False)
class TestDataNormOpWithSlotDim(OpTest):
......@@ -273,7 +275,125 @@ class TestDataNormOpWithSlotDim(OpTest):
"""
test check backward, check grad
"""
self.check_grad(['X'], 'Y', no_grad_set=set([]))
self.check_grad(['X'], 'Y', no_grad_set=set([]), check_dygraph=False)
class TestDataNormOpWithSyncStats(OpTest):
"""
test class for data norm op with sync_stats
test that the summary statistics are synchronized and updated correctly
"""
def test_sync_stats(self):
if not core.is_compiled_with_cuda():
return
x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
emb = layers.embedding(
input=x,
param_attr=fluid.ParamAttr(name="embx"),
size=[10, 2],
is_sparse=False)
dn = layers.data_norm(
input=emb,
name="hehe",
epsilon=1e-4,
param_attr={
"batch_size": 1e4,
"batch_sum": 1e5,
"batch_square": 1e4
},
summary_decay_rate=1,
sync_stats=True) #[-1,3]
loss = layers.mean(dn)
optimizer = fluid.optimizer.SGD(learning_rate=0.5)
optimizer = fluid.optimizer.PipelineOptimizer(
optimizer,
cut_list=[[emb], [loss]],
place_list=[
fluid.CUDAPlace(0), fluid.CUDAPlace(0), fluid.CPUPlace()
],
concurrency_list=[1, 1, 1],
queue_size=1,
sync_steps=10000000, )
all_p = fluid.default_main_program().global_block().all_parameters()
parameter_without_datanorm = []
for e in all_p:
if e.name.find("batch_size") != -1 or e.name.find(
"batch_sq") != -1 or e.name.find("batch_sum") != -1:
continue
parameter_without_datanorm.append(e.name)
optimizer.minimize(loss, parameter_list=parameter_without_datanorm)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
#prepare data
batch_size = 1
def binary_print(slot, fout):
num = np.int16(len(slot) + 1)
num.tofile(fout)
a = np.int64(batch_size)
a.tofile(fout)
slot.tofile(fout)
#batch1 = np.array([[0,1], [1,2], [2,3]]).astype("int64").reshape(batch_size,2,1)
#batch2 = np.array([[1,2], [2,3], [3,4]]).astype("int64").reshape(batch_size,2,1)
batch1 = np.ones(
(batch_size, 1)).astype("int64").reshape(batch_size, 1, 1)
batch2 = np.ones(
(batch_size, 1)).astype("int64").reshape(batch_size, 1, 1)
data = [batch1, batch2]
data = [batch1]
filelist = []
for i in range(2):
filelist.append("test_pipeline_input_" + str(i))
for f in filelist:
with open(f, "wb") as fout:
for batch_data in data:
for ins in batch_data:
for slot in ins:
binary_print(slot, fout)
dataset = fluid.DatasetFactory().create_dataset("FileInstantDataset")
dataset.set_use_var([x])
dataset.set_batch_size(batch_size)
dataset.set_filelist(filelist)
block = fluid.default_startup_program().global_block()
block.append_op(
type='c_comm_init_all', attrs={'ring_id': 0,
'devices': [0, 1]})
with open("main_program", "w") as fout:
fout.write(str(fluid.default_main_program()))
with open("startup_program", "w") as fout:
fout.write(str(fluid.default_startup_program()))
exe.run(fluid.default_startup_program())
emb_t = fluid.global_scope().find_var("embx").get_tensor()
para = np.ones((10, 2)).astype("float32")
emb_t.set(para, place)
for epoch in range(1):
exe.train_from_dataset(
fluid.default_main_program(),
dataset,
thread=2,
debug=False,
fetch_list=[],
fetch_info=[],
print_period=1)
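        # Expected values: summary_decay_rate is 1, every embedding value is 1,
        # the initial summary is batch_size=1e4, batch_sum=1e5,
        # batch_square_sum=1e4 (so the forward mean is 10), and the two
        # workers' statistics are summed through the sync_stats allreduce:
        #   batch_size:       1e4 + 2 * 1                     = 10002
        #   batch_sum:        1e5 + 2 * 1                     = 100002
        #   batch_square_sum: 1e4 + 2 * ((1 - 10)^2 + 1e-4)  ~= 10162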
batch_size = np.array(fluid.global_scope().find_var("hehe.batch_size")
.get_tensor())
self.assertEqual(batch_size[0], 10002)
b = np.array(fluid.global_scope().find_var("hehe.batch_sum").get_tensor(
))
self.assertEqual(b[0], 100002)
c = np.array(fluid.global_scope().find_var("hehe.batch_square_sum")
.get_tensor())
self.assertEqual(c[0], 10162)
for f in filelist:
os.remove(f)
if __name__ == '__main__':
......