Unverified commit 21d95be0, authored by Kaipeng Deng, committed by GitHub

Add inplace abn op (#22806)

* add inplace_abn_op. test=develop
Parent 821534ef
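A minimal usage sketch of the `fluid.layers.inplace_abn` API added by this commit (the layer name and the `act`/`act_alpha` arguments come from the diff below; the input shape and the `fluid.data`/`fluid.layers.conv2d` calls are illustrative assumptions, not part of the change):

    import paddle.fluid as fluid

    # hypothetical network fragment feeding the new fused layer
    data = fluid.data(name='x', shape=[-1, 3, 32, 32], dtype='float32')
    conv = fluid.layers.conv2d(input=data, num_filters=16, filter_size=3)
    # batch normalization fused with leaky_relu, computed in place on the
    # conv output; act_alpha is the slope of the fused leaky_relu
    out = fluid.layers.inplace_abn(conv, act='leaky_relu', act_alpha=0.01)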
......@@ -24,16 +24,24 @@ namespace ir {
class SyncBatchNormPass : public Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
VLOG(3) << "Use synchronous batch norm";
VLOG(3) << "Use synchronize batch norm";
for (const Node *n : graph->Nodes()) {
if (n->IsOp() && n->Op()) {
auto *op = n->Op();
// convert batch_norm into its synchronized version
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
}
if (op->Type() == "batch_norm_grad") {
op->SetType("sync_batch_norm_grad");
}
// enable synchronized statistics in inplace_abn
if (op->Type() == "inplace_abn") {
op->SetAttr("use_sync_bn", true);
}
if (op->Type() == "inplace_abn_grad") {
op->SetAttr("use_sync_bn", true);
}
}
}
}
......
......@@ -41,6 +41,8 @@ const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
"batch_norm_grad", // 0
"sync_batch_norm", // 0
"sync_batch_norm_grad", // 0
"inplace_abn", // 0
"inplace_abn_grad", // 0
"dgc_momentum", // 0
"fake_quantize_range_abs_max", // 0
"rmsprop", // 0
......
......@@ -59,7 +59,7 @@ if(WITH_COVERAGE OR NOT WITH_AVX OR WIN32)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op
sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
......
......@@ -82,16 +82,18 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
"ShapeError: the dimension of input X must greater than or equal to 2."
"But received: the shape of input X = [%s], the dimension of input X ="
"[%d]",
x_dims, x_dims.size());
platform::errors::InvalidArgument(
"ShapeError: the dimension of input "
"X must greater than or equal to 2. But received: the shape of input "
"X = [%s], the dimension of input X =[%d]",
x_dims, x_dims.size()));
PADDLE_ENFORCE_LE(
x_dims.size(), 5,
"ShapeError: the dimension of input X must smaller than or equal to 5."
"But received: the shape of input X = [%s], the dimension of input X ="
"[%d]",
x_dims, x_dims.size());
platform::errors::InvalidArgument(
"ShapeError: the dimension of input X "
"must smaller than or equal to 5. But received: the shape of input X "
"= [%s], the dimension of input X = [%d]",
x_dims, x_dims.size()));
const int64_t C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
......@@ -146,14 +148,18 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64;
}
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
"Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
"Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
"Mean input should be of float type");
PADDLE_ENFORCE_EQ(
bn_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument("Scale input should be of float type"));
PADDLE_ENFORCE_EQ(
bn_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument("Bias input should be of float type"));
PADDLE_ENFORCE_EQ(
bn_param_type, ctx.Input<Tensor>("Mean")->type(),
platform::errors::InvalidArgument("Mean input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
"Variance input should be of float type");
platform::errors::InvalidArgument(
"Variance input should be of float type"));
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
......@@ -204,8 +210,13 @@ void BatchNormOpMaker::Make() {
AddAttr<float>("epsilon", "")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
"'epsilon' should be between 0.0 and 0.001.");
PADDLE_ENFORCE_GE(
epsilon, 0.0f,
platform::errors::InvalidArgument(
"'epsilon' should be greater or equal than 0.0."));
PADDLE_ENFORCE_LE(epsilon, 0.001f,
platform::errors::InvalidArgument(
"'epsilon' should be less or equal than 0.001."));
});
AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
AddInput("X", "The input tensor");
......@@ -259,6 +270,7 @@ void BatchNormOpMaker::Make() {
"global mean and variance are also used during train time, "
"the BN acts as scaling and shiffting.")
.SetDefault(false);
AddComment(R"DOC(
Batch Normalization.
......@@ -290,8 +302,12 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input X dim size should be larger than 1."));
PADDLE_ENFORCE_LE(x_dims.size(), 5,
platform::errors::InvalidArgument(
"The Input X dim size should be less than 6."));
const int N = x_dims[0];
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
......@@ -299,6 +315,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
const int sample_size = x->numel() / N / C;
auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
......@@ -432,14 +449,18 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(scale) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("SavedMean"),
"Input(SavedMean) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("SavedVariance"),
"Input(SavedVariance) should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput("Scale"), true,
platform::errors::InvalidArgument("Input(scale) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Y")), true,
platform::errors::InvalidArgument("Input(Y@GRAD) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedMean"), true,
platform::errors::InvalidArgument(
"Input(SavedMean) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("SavedVariance"), true,
platform::errors::InvalidArgument(
"Input(SavedVariance) should not be null"));
// check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
......@@ -456,25 +477,37 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),
"Using global stats during training is not supported "
"in gradient op kernel of batch_norm_mkldnn_op now.");
PADDLE_ENFORCE_EQ(
!ctx->Attrs().Get<bool>("use_mkldnn"), true,
platform::errors::InvalidArgument(
"Using global stats during training is not supported "
"in gradient op kernel of batch_norm_mkldnn_op now."));
}
const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
// batch_norm_grad with inplace takes Y as input; without inplace it
// takes X as input. HasInput would throw an exception at compile time,
// so the shape is only inferred at run time here.
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->HasInput("X") || ctx->HasInput("Y"), true,
platform::errors::InvalidArgument(
"Input(X) and Input(Y) should not be all null."));
auto input_name = "Y";
if (ctx->HasInput("X")) input_name = "X";
const auto x_dims = ctx->GetInputDim(input_name);
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
}
}
......@@ -482,7 +515,8 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
PADDLE_THROW(
platform::errors::InvalidArgument("can't find gradient variable of Y"));
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
......@@ -491,7 +525,8 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
......@@ -541,9 +576,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
// SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
......@@ -554,6 +589,30 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
// batch_norm with inplace=false takes X as the gradient input, which
// is the same as the cuDNN batch_norm backward calculation; batch_norm
// with inplace=true only takes Y as input, and X has to be reconstructed
// by inverting the batch_norm transform on Y.
const Tensor *x;
bool is_inplace;
if (ctx.HasInput("Y")) {
x = ctx.Input<Tensor>("Y");
is_inplace = true;
PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplace in inplace mode"));
} else {
x = ctx.Input<Tensor>("X");
is_inplace = false;
PADDLE_ENFORCE_NE(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"));
}
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
......@@ -564,8 +623,12 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
platform::errors::InvalidArgument(
"The Input X dim size should be larger than 1."));
PADDLE_ENFORCE_LE(x_dims.size(), 5,
platform::errors::InvalidArgument(
"The Input X dim size should be less than 6."));
const int N = x_dims[0];
const int C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
......@@ -573,10 +636,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const int sample_size = x->numel() / N / C;
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
const T *mean_data = saved_mean->data<T>();
......@@ -596,6 +655,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
ConstEigenVectorArrayMap<T> bias_arr(bias->data<T>(), C);
ConstEigenVectorArrayMap<T> mean_arr(mean_data, C);
ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, C);
......@@ -643,13 +703,30 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
dy_sum_arr.setZero();
dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero();
// inplace calculation
// Y: (x - est_mean) * inv_var * scale + bias
//    formula transform ====>
//    (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
// X: (y - bias) / scale / inv_var + est_mean
//    formula transform ====>
//    (y - bias) / (scale * inv_var) + est_mean
switch (data_layout) {
case DataLayout::kNCHW: {
if (is_inplace) {
auto px = *x;
EigenArrayMap<T> x_data(px.mutable_data<T>(ctx.GetPlace()),
sample_size, N * C);
ConstEigenArrayMap<T> y_data(x->data<T>(), sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) {
x_data.col(nc) = (y_data.col(nc) - bias_arr(nc % C)) /
scale_inv_var_nhw(nc % C) / scale_coefff +
mean_arr(nc % C);
}
}
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
sample_size, N * C);
d_x_arr.setZero();
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
......@@ -667,7 +744,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
if (!use_global_stats) {
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_x_arr.col(nc) +=
d_x_arr.col(nc) =
scale_inv_var_nhw(c) *
(d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
(x_arr.col(nc) - mean_arr[c]) *
......@@ -676,17 +753,27 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
} else {
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_x_arr.col(nc) += scale_inv_var_nhw(c) * d_y_arr.col(nc);
d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc);
}
}
break;
}
case DataLayout::kNHWC: {
if (is_inplace) {
auto px = *x;
EigenArrayMap<T> x_data(px.mutable_data<T>(ctx.GetPlace()), C,
N * sample_size);
ConstEigenArrayMap<T> y_data(x->data<T>(), C, N * sample_size);
for (int nhw = 0; nhw < N * sample_size; nhw++) {
x_data.col(nhw) = (y_data.col(nhw) - bias_arr) / scale_inv_var_nhw /
scale_coefff +
mean_arr;
}
}
ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size);
d_x_arr.setZero();
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
dy_sum_arr += d_y_arr.col(nhw);
......@@ -701,7 +788,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
if (!use_global_stats) {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) +=
d_x_arr.col(nhw) =
scale_inv_var_nhw *
(d_y_arr.col(nhw) * N * sample_size - dy_sum_arr -
(x_arr.col(nhw) - mean_arr) *
......@@ -709,7 +796,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
} else {
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_x_arr.col(nhw) += scale_inv_var_nhw * d_y_arr.col(nhw);
d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw);
}
}
break;
......
......@@ -40,8 +40,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -355,6 +356,41 @@ static __global__ void KeBNBackwardData(const T *dy,
}
}
template <typename T>
static __global__ void KeBNRestoreData(const framework::DataLayout layout, T *x,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance,
double epsilon, int C, int M,
const int num, const T *y) {
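// Recover the pre-normalization input x from the batch norm output y by
// inverting the affine transform: x = (y - bias) / scale / variance + mean.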
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C;
auto y_i = static_cast<BatchNormParamType<T>>(y[i]);
auto x_i = (y_i - bias[c]) / scale[c] / variance[c] + mean[c];
x[i] = static_cast<T>(x_i);
}
}
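// Launches KeBNRestoreData to rebuild the input tensor in place from the
// output before the backward pass; x and y must share the same buffer.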
template <typename T>
class InplaceHelper {
public:
void operator()(const framework::DataLayout layout, T *x,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance, double epsilon, int C,
int M, const int num, const T *y, int grid2, const int block,
const cudaStream_t &stream) {
PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
"X and Y should be inplaced in inplace mode"));
KeBNRestoreData<<<grid2, block, 0, stream>>>(
layout, x, scale, bias, mean, variance, epsilon, C, M, num, y);
}
};
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void BNBackwardData(const T *dy,
const BatchNormParamType<T> *scale,
......@@ -417,17 +453,43 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
// batch_norm with inplace=false takes X as the gradient input, which
// is the same as the cuDNN batch_norm backward calculation; batch_norm
// with inplace=true only takes Y as input, and X has to be reconstructed
// by inverting the batch_norm transform on Y.
const Tensor *x;
bool is_inplace;
if (ctx.HasInput("Y")) {
x = ctx.Input<Tensor>("Y");
is_inplace = true;
PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplace in inplace mode"));
} else {
x = ctx.Input<Tensor>("X");
is_inplace = false;
PADDLE_ENFORCE_NE(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"));
}
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
......@@ -444,11 +506,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
if (d_scale && d_bias) {
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
......@@ -505,6 +564,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
int grid2 = std::min(C, max_blocks);
auto stream = dev_ctx.stream();
InplaceHelper<T> inplace_functor;
if (!use_global_stats) {
if ((N * H * W * D) == 1) {
......@@ -555,6 +616,14 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const auto *saved_var_data =
saved_var->template data<BatchNormParamType<T>>();
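// In inplace mode X and Y share memory, so restore the forward input from
// Y (using the saved statistics) before the gradient kernels read it.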
if (is_inplace) {
inplace_functor(compute_format, transformed_x.data<T>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(),
saved_mean_data, saved_var_data, epsilon, C, H * W * D,
num, transformed_x.data<T>(), grid2, block, stream);
}
if (d_scale && d_bias) {
bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
......@@ -680,30 +749,41 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const auto *running_var_data =
running_var->template data<BatchNormParamType<T>>();
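// Global-stats path: rebuild the forward input from Y with the running
// mean/variance before computing the data/scale/bias gradients.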
if (is_inplace) {
auto px = *x;
inplace_functor(data_layout, px.mutable_data<T>(ctx.GetPlace()),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(),
running_mean_data, running_var_data, epsilon, C,
H * W * D, num, x->data<T>(), grid2, block, stream);
}
if (compute_format == DataLayout::kNCHW) {
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNCHW><<<
grid1, block, 0, dev_ctx.stream()>>>(
KeBNBackwardData<
T, framework::DataLayout::kNCHW><<<grid1, block, 0, stream>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
running_var_data, epsilon, C, H * W, num, d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
grid2, block, 0, dev_ctx.stream()>>>(
KeBNBackwardScaleBias<
T, block,
framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else {
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNHWC><<<
grid1, block, 0, dev_ctx.stream()>>>(
KeBNBackwardData<
T, framework::DataLayout::kNHWC><<<grid1, block, 0, stream>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
running_var_data, epsilon, C, H * W, num, d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNHWC><<<
grid2, block, 0, dev_ctx.stream()>>>(
KeBNBackwardScaleBias<
T, block,
framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/inplace_abn_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/batch_norm_op.h"
namespace paddle {
namespace operators {
class InplaceABNOp : public paddle::operators::BatchNormOp {
public:
using paddle::operators::BatchNormOp::BatchNormOp;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should be float (for float or float16 input tensors)
// or double (for double input tensors).
auto bn_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
bn_param_type = framework::proto::VarType::FP64;
}
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
platform::errors::InvalidArgument(
"Mean input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
platform::errors::InvalidArgument(
"Variance input should be of float type"));
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
};
class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
public:
using paddle::operators::BatchNormGradOp::BatchNormGradOp;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const auto* var = ctx.InputVar(framework::GradVarName("Y"));
auto input_data_type = ctx.Input<Tensor>("Y")->type();
if (var == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can't find gradient variable of Y"));
}
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
};
class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker {
public:
void Make() override {
BatchNormOpMaker::Make();
AddAttr<std::string>(
"activation",
"(enum string, default identity, can be identity|elu|leaky-relu) "
"The activation type used for output candidate {h}_t.")
.SetDefault("");
AddAttr<float>("alpha",
"(float, default 1.0) Only used in inplace-abn kernel,"
"the activation type(identity|elu|leakyrelu) would be fused "
"with batch_norm, "
"this is the alpha value for elu|leakyrelu.")
.SetDefault(0.1f);
AddAttr<bool>("use_sync_bn",
"(bool, default false) Whether use synchronize batch "
"normalization.")
.SetDefault(false);
}
};
template <typename T>
class InplaceABNOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
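// The inplace grad op consumes the forward output Y rather than X, since X
// has been overwritten in place during the forward pass.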
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("Bias", this->Input("Bias"));
op->SetInput("SavedMean", this->Output("SavedMean"));
op->SetInput("SavedVariance", this->Output("SavedVariance"));
// used when setting use_global_stats True during training
if (boost::get<bool>(this->GetAttr("use_global_stats"))) {
op->SetInput("Mean", this->Output("MeanOut"));
op->SetInput("Variance", this->Output("VarianceOut"));
}
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
};
template <typename DeviceContext, typename T>
class InplaceABNKernel
: public paddle::operators::BatchNormKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Y");
PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
"X and Y not inplaced in inplace mode"));
auto activation =
GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
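// Run the regular batch_norm forward first, then apply the fused activation
// to Y in place.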
BatchNormKernel<DeviceContext, T>::Compute(ctx);
auto cur_y = EigenVector<T>::Flatten(*y);
InplaceABNActivation<DeviceContext, T> functor;
functor.Compute(ctx, activation, place, cur_y, cur_y);
}
};
template <typename DeviceContext, typename T>
class InplaceABNGradKernel
: public paddle::operators::BatchNormGradKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* y = ctx.Input<Tensor>("Y");
auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplaced in inplace mode"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto activation =
GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
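// Undo the fused activation on Y and dY in place, then fall through to the
// regular batch_norm backward.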
auto py = *y;
auto pd_y = *d_y;
auto cur_y = EigenVector<T>::Flatten(py);
auto cur_dy = EigenVector<T>::Flatten(pd_y);
InplaceABNActivation<DeviceContext, T> functor;
functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy);
BatchNormGradKernel<DeviceContext, T>::Compute(ctx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, ops::InplaceABNOpMaker,
ops::BatchNormOpInferVarType,
ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
ops::InplaceABNOpGradMaker<paddle::imperative::OpBase>)
REGISTER_OPERATOR(inplace_abn_grad, ops::InplaceABNGradOp)
REGISTER_OP_CPU_KERNEL(
inplace_abn,
ops::InplaceABNKernel<paddle::platform::CPUDeviceContext, float>,
ops::InplaceABNKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
inplace_abn_grad,
ops::InplaceABNGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::InplaceABNGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/inplace_abn_op.h"
#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class InplaceABNKernel
: public paddle::operators::SyncBatchNormKernel<DeviceContext, T>,
public paddle::operators::BatchNormKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* y = ctx.Output<Tensor>("Y");
auto* x = ctx.Input<Tensor>("X");
PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
"X and Y not inplaced in inplace mode"));
auto activation =
GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
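// Dispatch to sync_batch_norm or batch_norm, then apply the fused
// activation to Y in place.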
if (ctx.Attr<bool>("use_sync_bn")) {
SyncBatchNormKernel<DeviceContext, T>::Compute(ctx);
} else {
BatchNormKernel<DeviceContext, T>::Compute(ctx);
}
auto cur_y = EigenVector<T>::Flatten(*y);
InplaceABNActivation<DeviceContext, T> functor;
functor.Compute(ctx, activation, place, cur_y, cur_y);
}
};
// Deriving the Gradient for the Backward Pass of Batch Normalization
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
template <typename DeviceContext, typename T>
class InplaceABNGradKernel
: public paddle::operators::SyncBatchNormGradKernel<DeviceContext, T>,
public paddle::operators::BatchNormGradKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* y = ctx.Input<Tensor>("Y");
auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(d_x, d_y,
platform::errors::InvalidArgument(
"X@GRAD and Y@GRAD not inplaced in inplace mode"));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto activation =
GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
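// Undo the fused activation on Y and dY in place before dispatching to the
// sync_batch_norm or batch_norm backward kernel.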
auto py = *y;
auto pd_y = *d_y;
auto cur_y = EigenVector<T>::Flatten(py);
auto cur_dy = EigenVector<T>::Flatten(pd_y);
InplaceABNActivation<DeviceContext, T> functor;
functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy);
if (ctx.Attr<bool>("use_sync_bn")) {
SyncBatchNormGradKernel<DeviceContext, T>::Compute(ctx);
} else {
BatchNormGradKernel<DeviceContext, T>::Compute(ctx);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(inplace_abn,
ops::InplaceABNKernel<plat::CUDADeviceContext, float>,
ops::InplaceABNKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
inplace_abn_grad, ops::InplaceABNGradKernel<plat::CUDADeviceContext, float>,
ops::InplaceABNGradKernel<plat::CUDADeviceContext, double>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
enum InplaceABNActivationType { identity = 0, leakyrelu = 1, elu = 2 };
inline InplaceABNActivationType GetInplaceABNActivationType(
const std::string& type) {
if (type == "leaky_relu") {
return InplaceABNActivationType::leakyrelu;
} else if (type == "elu") {
return InplaceABNActivationType::elu;
} else if (type == "identity" || type == "") {
return InplaceABNActivationType::identity;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported activation type %s for Op(inplace_abn)", type));
}
}
template <typename DeviceContext, typename T>
class InplaceABNActivation {
private:
template <typename Functor>
void setAttrs(const framework::ExecutionContext& ctx, Functor* functor) {
auto attrs = functor->GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
}
template <typename Functor, typename... Args>
void compute(const framework::ExecutionContext& ctx, Functor* functor,
Args... args) {
setAttrs(ctx, functor);
(*functor)(args...);
}
public:
template <typename Device, typename X, typename Y>
void Compute(const framework::ExecutionContext& ctx, const int act_type,
const Device& d, X x, Y y) {
if (act_type == InplaceABNActivationType::identity) {
y.device(d) = x;
} else if (act_type == InplaceABNActivationType::leakyrelu) {
LeakyReluFunctor<T> functor;
compute(ctx, &functor, d, x, y);
} else if (act_type == InplaceABNActivationType::elu) {
ELUFunctor<T> functor;
compute(ctx, &functor, d, x, y);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("unsupported activation type"));
}
}
template <typename Device, typename X, typename Y, typename DX, typename DY>
void GradCompute(const framework::ExecutionContext& ctx, const int act_type,
const Device& d, X x, Y y, DX dx, DY dy) {
const float alpha = ctx.Attr<float>("alpha");
if (act_type == InplaceABNActivationType::identity) {
x.device(d) = y;
dx.device(d) = dy;
} else if (act_type == InplaceABNActivationType::leakyrelu) {
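// Invert leaky_relu to recover the pre-activation value:
// x = y / alpha where y < 0, and x = y where y >= 0.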
auto temp1 = (y < static_cast<T>(0)).template cast<T>().eval() /
static_cast<T>(alpha);
auto temp2 = (y >= static_cast<T>(0)).template cast<T>().eval();
x.device(d) = y * (temp1 + temp2).template cast<T>();
LeakyReluGradFunctor<T> functor;
compute(ctx, &functor, d, x, y, dy, dx);
} else if (act_type == InplaceABNActivationType::elu) {
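// Invert elu to recover the pre-activation value:
// x = y where y >= 0, and x = log(y / alpha + 1) where y < 0.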
auto temp1 = (y >= static_cast<T>(0)).template cast<T>().eval();
auto temp = (y < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (y * temp / static_cast<T>(alpha) + static_cast<T>(1)).log();
x.device(d) = (y * temp1 + temp2).template cast<T>();
ELUGradFunctor<T> functor;
compute(ctx, &functor, d, x, y, dy, dx);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("unsupported activation type"));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,113 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// clang-format off
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void KeLocalStats(const T *x, int N, int M, int C,
BatchNormParamType<T> *mean_var) {
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int k = blockIdx.x; k < C; k += gridDim.x) {
BatchNormParamType<T> x_sum = 0.;
BatchNormParamType<T> x2_sum = 0.;
for (int i = threadIdx.x; i < N * M; i += BlockDim) {
int id = layout == framework::DataLayout::kNCHW
? (i / M) * C * M + k * M + i % M
: i * C + k;
auto x_in = static_cast<BatchNormParamType<T>>(x[id]);
x_sum += x_in;
x2_sum += x_in * x_in;
}
__syncthreads();
auto out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
mean_var[k] = out / (N * M);
}
out = BlockReduce(temp_storage).Reduce(x2_sum, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
mean_var[k + C] = out / (N * M);
}
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
mean_var[2 * C] = static_cast<BatchNormParamType<T>>(1.0);
}
}
template <typename T>
__global__ void KeSyncAndMovingStats(
BatchNormParamType<T> *means, BatchNormParamType<T> *variances,
BatchNormParamType<T> *num_dev, const int C,
const BatchNormParamType<T> momentum, const double epsilon,
BatchNormParamType<T> *sv_mean_data, BatchNormParamType<T> *sv_inv_var_data,
BatchNormParamType<T> *moving_means,
BatchNormParamType<T> *moving_variances) {
// sync stats across multi-devices
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < C; i += stride) {
auto mean = means[i] / (*num_dev);
auto var = variances[i] / (*num_dev);
var = var - mean * mean;
// sync stats
sv_mean_data[i] = mean;
sv_inv_var_data[i] = 1.0 / sqrt(var + epsilon);
variances[i] = var;
// moving stats
moving_means[i] = moving_means[i] * momentum + mean * (1. - momentum);
moving_variances[i] =
moving_variances[i] * momentum + var * (1. - momentum);
}
}
template <typename T, framework::DataLayout layout>
static __global__ void KeNormAffine(const T *x,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance,
const double epsilon, const int C,
const int M, const int num, T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C;
auto x_i = static_cast<BatchNormParamType<T>>(x[i]);
auto y_i =
(x_i - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c];
y[i] = static_cast<T>(y_i);
}
}
template <typename DeviceContext, typename T>
class SyncBatchNormKernel : public framework::OpKernel<T> {
class SyncBatchNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
......@@ -127,331 +28,59 @@ class SyncBatchNormKernel : public framework::OpKernel<T> {
const std::string layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout layout = framework::StringToDataLayout(layout_str);
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
PADDLE_ENFORCE(
!use_global_stats,
"sync_batch_norm doesn't support to set use_global_stats True. ",
"Please use batch_norm in this case.");
PADDLE_ENFORCE_EQ(use_global_stats, false,
platform::errors::InvalidArgument(
"sync_batch_norm doesn't support "
"to set use_global_stats True. Please use batch_norm "
"in this case."));
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
int x_numel = x->numel();
const T *x_d = x->data<T>();
const auto *s_d = ctx.Input<Tensor>("Scale")->data<BatchNormParamType<T>>();
const auto *b_d = ctx.Input<Tensor>("Bias")->data<BatchNormParamType<T>>();
auto *y = ctx.Output<Tensor>("Y");
T *y_d = y->mutable_data<T>(ctx.GetPlace());
const BatchNormParamType<T> *mean_data = nullptr;
const BatchNormParamType<T> *var_data = nullptr;
auto &dev_ctx = ctx.cuda_device_context();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.nccl_comm();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
paddle::memory::AllocationPtr alloc_ptr{nullptr};
if (is_test) {
const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance");
mean_data = est_mean->data<BatchNormParamType<T>>();
var_data = est_var->data<BatchNormParamType<T>>();
} else {
// x, x^2, 1, here 1 is used to calc device num
// device num also can be got from platform::DeviceContextPool
const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
alloc_ptr = memory::Alloc(dev_ctx, bytes);
auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());
const int threads = 256;
int grid = std::min(C, (max_threads + threads - 1) / threads);
if (layout == framework::DataLayout::kNCHW) {
KeLocalStats<T, threads, framework::DataLayout::kNCHW>
<<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
} else {
KeLocalStats<T, threads, framework::DataLayout::kNHWC>
<<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
}
// moving mean/variance
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *est_mean_data =
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto *est_var_data =
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
auto *sv_mean_data =
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
auto *sv_inv_var_data =
saved_inv_variance->mutable_data<BatchNormParamType<T>>(
ctx.GetPlace());
Tensor c_g_st;
auto *c_g_st_d = c_g_st.mutable_data<BatchNormParamType<T>>(
{2 * C + 1}, platform::CPUPlace());
auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
int dtype = platform::ToNCCLDataType(mean_out->type());
// In-place operation
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance");
// Note, Input('Mean')/Input('Variance') share variable with
// Output('MeanOut')/Output('VarianceOut')
KeSyncAndMovingStats<T><<<(C + block - 1) / block, block, 0, stream>>>(
stats, stats + C, stats + 2 * C, C, momentum, epsilon, sv_mean_data,
sv_inv_var_data, est_mean_data, est_var_data);
// moving mean/variance
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
mean_data = sv_mean_data;
var_data = stats + C;
}
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
if (layout == framework::DataLayout::kNCHW) {
KeNormAffine<T, framework::DataLayout::kNCHW>
<<<grid2, block, 0, stream>>>(x_d, s_d, b_d, mean_data, var_data,
epsilon, C, H * W * D, x_numel, y_d);
} else {
KeNormAffine<T, framework::DataLayout::kNHWC>
<<<grid2, block, 0, stream>>>(x_d, s_d, b_d, mean_data, var_data,
epsilon, C, H * W * D, x_numel, y_d);
}
SyncBatchNormFunctor<platform::CUDADeviceContext, T>(
ctx, layout, x, y, est_mean, est_var, mean_out, variance_out,
saved_mean, saved_inv_variance, epsilon, momentum, is_test,
use_global_stats);
}
};
template <typename T, const int BlockDim, framework::DataLayout layout>
__global__ void KeBackwardLocalStats(const T *dy, const T *x,
const BatchNormParamType<T> *means, int N,
int M, int C,
BatchNormParamType<T> *sum_dy_prod) {
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int k = blockIdx.x; k < C; k += gridDim.x) {
BatchNormParamType<T> sum1 = 0.;
BatchNormParamType<T> sum2 = 0.;
auto mean = means[k];
for (int i = threadIdx.x; i < N * M; i += blockDim.x) {
int id = layout == framework::DataLayout::kNCHW
? (i / M) * C * M + k * M + i % M
: i * C + k;
auto g = static_cast<BatchNormParamType<T>>(dy[id]);
sum1 += g;
auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
sum2 += g * (x_i - mean);
}
__syncthreads();
auto out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
sum_dy_prod[k] = out;
}
out = BlockReduce(temp_storage).Reduce(sum2, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
sum_dy_prod[k + C] = out;
}
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
sum_dy_prod[2 * C] = 1.0;
}
}
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
const T *dy, const T *x, const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *inv_variance, const double epsilon,
const int N, const int C, const int HxW, BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
BatchNormParamType<T> ds_sum = 0.;
BatchNormParamType<T> db_sum = 0.;
auto inv_var_i = inv_variance[i];
auto mean_i = mean[i];
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int id = layout == framework::DataLayout::kNCHW
? ((j / HxW) * C + i) * HxW + (j % HxW)
: j * outer_size + i;
auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
auto dy_i = static_cast<BatchNormParamType<T>>(dy[id]);
ds_sum += dy_i * (x_i - mean_i);
db_sum += dy_i;
}
__syncthreads();
auto os = BlockReduce(temp_storage).Reduce(ds_sum, cub::Sum());
__syncthreads();
auto ob = BlockReduce(temp_storage).Reduce(db_sum, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
dscale[i] = os * inv_var_i;
dbias[i] = ob;
}
__syncthreads();
}
}
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(
const T *dy, const T *x, const BatchNormParamType<T> *gamma,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *inv_variance,
const BatchNormParamType<T> *g_sum_dy,
const BatchNormParamType<T> *g_sum_dy_prod,
const BatchNormParamType<T> *num_dev, const double epsilon, const int C,
const int HxW, const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
auto scale = static_cast<BatchNormParamType<T>>(C) / num;
auto dev_num = num_dev[0];
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
auto inv_var = inv_variance[c];
auto s_d = gamma[c];
auto gvar =
-((g_sum_dy_prod[c] / dev_num) * s_d * inv_var * (inv_var * inv_var));
auto gmean = -((g_sum_dy[c] / dev_num) * s_d * inv_var);
auto x_i = static_cast<BatchNormParamType<T>>(x[i]);
auto dy_i = static_cast<BatchNormParamType<T>>(dy[i]);
auto dx_i =
dy_i * s_d * inv_var + gmean * scale + gvar * scale * (x_i - mean[c]);
dx[i] = static_cast<T>(dx_i);
}
}
// Deriving the Gradient for the Backward Pass of Batch Normalization
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
template <typename DeviceContext, typename T>
class SyncBatchNormGradKernel : public framework::OpKernel<T> {
template <typename T>
class SyncBatchNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout layout = framework::StringToDataLayout(layout_str);
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
const auto *bias = ctx.Input<Tensor>("Bias");
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
if (d_scale && d_bias) {
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
}
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
std::vector<int> dims;
std::vector<int> strides;
if (layout == DataLayout::kNCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C};
}
const T *x_d = x->data<T>();
const T *dy_d = d_y->data<T>();
auto &dev_ctx = ctx.cuda_device_context();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.nccl_comm();
const auto *saved_mean =
ctx.Input<Tensor>("SavedMean")->data<BatchNormParamType<T>>();
const auto *saved_inv_var =
ctx.Input<Tensor>("SavedVariance")->data<BatchNormParamType<T>>();
const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
auto alloc_ptr = memory::Alloc(dev_ctx, bytes);
auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());
const int threads = 256;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
int grid = std::min(C, (max_threads + threads - 1) / threads);
int x_numel = x->numel();
int fsize = H * W * D;
if (layout == framework::DataLayout::kNCHW) {
KeBackwardLocalStats<T, threads, framework::DataLayout::kNCHW>
<<<grid, threads, 0, stream>>>(dy_d, x_d, saved_mean, N, fsize, C,
stats);
} else {
KeBackwardLocalStats<T, threads, framework::DataLayout::kNHWC>
<<<grid, threads, 0, stream>>>(dy_d, x_d, saved_mean, N, fsize, C,
stats);
}
int dtype = platform::ToNCCLDataType(scale->type());
// In-place operation
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_inv_var = ctx.Input<Tensor>("SavedVariance");
const int block = 512;
int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
if (layout == framework::DataLayout::kNCHW) {
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, threads, framework::DataLayout::kNCHW>
<<<grid, threads, 0, stream>>>(
dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNCHW>
<<<grid2, block, 0, stream>>>(
dy_d, x_d, scale->data<BatchNormParamType<T>>(), saved_mean,
saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C,
fsize, x->numel(), d_x->data<T>());
}
} else {
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, threads, framework::DataLayout::kNHWC>
<<<grid, threads, 0, stream>>>(
dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNHWC>
<<<grid2, block, 0, stream>>>(
dy_d, x_d, scale->data<BatchNormParamType<T>>(), saved_mean,
saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C,
fsize, x->numel(), d_x->data<T>());
}
}
SyncBatchNormGradFunctor<platform::CUDADeviceContext, T>(
ctx, layout, scale, bias, d_x, d_y, d_scale, d_bias, saved_mean,
saved_inv_var, epsilon);
}
};
......
This diff has been collapsed.
......@@ -50,6 +50,7 @@ __all__ = [
'adaptive_pool2d',
'adaptive_pool3d',
'batch_norm',
'inplace_abn',
'instance_norm',
'data_norm',
'conv2d_transpose',
......@@ -2638,9 +2639,9 @@ def batch_norm(input,
If the Initializer of the bias_attr is not set, the bias is initialized zero.
Default: None.
data_layout (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
in_place(bool, Default False): Make the input and output of batch norm reuse memory.
name(str|None): For detailed information, please refer to :ref:`api_guide_Name`.
Usually name is no need to set and None by default.
......@@ -2657,7 +2658,6 @@ def batch_norm(input,
or is_test to true, and the behavior is equivalent.
In train mode, when setting use_global_stats True, the global mean
and variance are also used during train period.
Returns:
A Variable holding Tensor which is the result after applying batch normalization on the input,
has same shape and data type with input.
......@@ -2770,8 +2770,8 @@ def batch_norm(input,
reserve_space = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
dtype)
batch_norm_out = input if in_place else \
helper.create_variable_for_type_inference(dtype)
inputs = {
"X": input,
......@@ -2809,6 +2809,209 @@ def batch_norm(input,
return helper.append_activation(batch_norm_out)
def inplace_abn(input,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
name=None,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=True,
use_global_stats=False,
act_alpha=1.0):
"""
**In-place Activated Batch Normalization Layer**
This layer computes batch normalization and activation with in-place memory reuse.
For the batch normalization calculation, see `fluid.layers.batch_norm`.
For in-place activated batch normalization, see `In-Place Activated BatchNorm for
Memory-Optimized Training of DNNs <https://arxiv.org/abs/1712.02616>`_
`inplace_abn` currently only supports the activation types `None`, `identity`, `leaky_relu`
and `elu`.
`inplace_abn` currently only supports the data types `float32` and `float64`.
Note:
if build_strategy.sync_batch_norm=True, the batch_norm and inplace_abn ops in the network
will use synchronized batch normalization automatically.
`is_test = True` can only be used in test and inference programs; `is_test` CANNOT be set to True in a train program. If you want to use the global statistics from a pre-trained model in a train program, please set `use_global_stats = True`.
Args:
input(Variable): The rank of the input variable can be 2, 3, 4 or 5. The data type
is float32 or float64.
act(string, Default None): Activation type. Only `None`, `identity`, `leaky_relu` and `elu` are supported currently.
is_test (bool, Default False): A flag indicating whether it is in
test phase or not.
momentum(float|Variable, Default 0.9): The value used for the moving_mean and
moving_var computation. This should be a float number or a Variable with
shape [1] and data type float32. The update formula is:
:math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
:math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
Default is 0.9.
epsilon(float, Default 1e-05): A value added to the denominator for
numerical stability. Default is 1e-5.
param_attr(ParamAttr|None): The parameter attribute for Parameter `scale`
of inplace_abn. If it is set to None or one attribute of ParamAttr, inplace_abn
will create ParamAttr as param_attr, the name of scale can be set in ParamAttr.
If the Initializer of the param_attr is not set, the parameter is initialized
with Xavier. Default: None.
bias_attr(ParamAttr|None): The parameter attribute for the bias of inplace_abn.
If it is set to None or one attribute of ParamAttr, inplace_abn
will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
If the Initializer of the bias_attr is not set, the bias is initialized zero.
Default: None.
data_layout (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
name(str|None): For detailed information, please refer to :ref:`api_guide_Name`.
Usually there is no need to set it, and it is None by default.
moving_mean_name(str, Default None): The name of the moving_mean which stores the global mean. If it
is set to None, inplace_abn will save the global mean with a random name; otherwise, inplace_abn
will save the global mean with the given string.
moving_variance_name(str, Default None): The name of the moving_variance which stores the global variance.
If it is set to None, inplace_abn will save the global variance with a random name; otherwise, inplace_abn
will save the global variance with the given string.
do_model_average_for_mean_and_var(bool, Default True): Whether the mean and variance parameters should
participate in model averaging when model averaging is enabled.
use_global_stats(bool, Default False): Whether to use the global mean and
variance. In inference or test mode, setting use_global_stats to True
or is_test to True is equivalent.
In train mode, when use_global_stats is True, the global mean
and variance are also used during training.
act_alpha(float, Default 1.0): When the activation is one of ['elu', 'identity', 'leaky_relu'],
in-place activated batch normalization is used, and the alpha parameter of the
activation can be given by this parameter.
Returns:
A Variable holding a Tensor which is the result of applying batch normalization and activation on the input,
with the same shape and data type as the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden2 = fluid.layers.inplace_abn(input=hidden1)
hidden3 = fluid.layers.inplace_abn(input=hidden2, act='leaky_relu', act_alpha=0.2)
"""
    assert act in [None, 'identity', 'leaky_relu', 'elu'], \
        "inplace_abn only supports act as None, 'identity', " \
        "'leaky_relu' or 'elu' currently"
assert bias_attr is not False, "bias_attr should not be False in inplace_abn."
helper = LayerHelper('inplace_abn', **locals())
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'inplace_abn')
dtype = helper.input_dtype()
has_reserve_space = False
if data_layout == 'NHWC':
flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
if flag is not None and flag.lower() in ['true', '1']:
has_reserve_space = True
input_shape = input.shape
    if data_layout == 'NCHW':
        channel_num = input_shape[1]
    elif data_layout == 'NHWC':
        channel_num = input_shape[-1]
    else:
        raise ValueError("unsupported data layout:" + data_layout)
param_shape = [channel_num]
# create parameter
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_parameter(
attr=ParamAttr(
name=moving_mean_name,
initializer=Constant(0.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=dtype)
mean.stop_gradient = True
variance = helper.create_parameter(
attr=ParamAttr(
name=moving_variance_name,
initializer=Constant(1.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=dtype)
variance.stop_gradient = True
# create output
# mean and mean_out share the same memory
mean_out = mean
# variance and variance out share the same memory
variance_out = variance
saved_mean = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
saved_variance = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
reserve_space = None
if has_reserve_space:
reserve_space = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
batch_norm_out = input
inputs = {
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
}
attrs = {
"epsilon": epsilon,
"is_test": is_test,
"data_layout": data_layout,
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": use_global_stats,
"activation": act,
"alpha": act_alpha,
}
if isinstance(momentum, Variable):
inputs['MomemtumTensor'] = momentum
else:
attrs['momentum'] = momentum
outputs = {
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
}
if reserve_space is not None:
outputs["ReserveSpace"] = reserve_space
helper.append_op(
type="inplace_abn", inputs=inputs, outputs=outputs, attrs=attrs)
return batch_norm_out
def instance_norm(input,
epsilon=1e-05,
param_attr=None,
......
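For reference, a minimal usage sketch of the new `inplace_abn` layer with a tensor-valued momentum (mirroring `make_inplace_abn_momentum_variable` in the test changes below); the variable names `data` and `momentum` are illustrative. When `momentum` is a Variable it is fed through the op's `MomemtumTensor` input instead of the `momentum` attribute:

import paddle.fluid as fluid

# Sketch: inplace_abn with a fused ELU activation and momentum given as a Variable.
data = fluid.data(name='data', shape=[4, 5, 7, 9], dtype='float32')
momentum = fluid.data(name='momentum', shape=[1], dtype='float32')
out = fluid.layers.inplace_abn(
    data, momentum=momentum, act='elu', act_alpha=2.0)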
......@@ -234,7 +234,7 @@ def img_conv_group(input,
use_cudnn=use_cudnn)
if conv_with_batchnorm[i]:
tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True)
tmp = layers.batch_norm(input=tmp, act=conv_act)
drop_rate = conv_batchnorm_drop_rate[i]
if abs(drop_rate) > 1e-5:
tmp = layers.dropout(x=tmp, dropout_prob=drop_rate)
......
......@@ -360,7 +360,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_fetch_unmerged
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op
set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op test_inplace_abn_op
test_parallel_executor_seresnext_base_gpu
test_parallel_executor_seresnext_with_reduce_gpu
test_parallel_executor_seresnext_with_fuse_all_reduce_gpu
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import os
import six
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import compiler
import paddle.fluid.unique_name as unique_name
class TestInplaceANBOpTraining(unittest.TestCase):
def setUp(self):
self.dtype = np.float64
self.N = 4
self.C = 5
self.H = 7
self.W = 9
self.dshape = [self.N, self.C, self.H, self.W]
def build_program(self,
place,
layout,
seed,
only_forward=False,
activation="identity",
alpha=1.0,
use_cuda=False,
inplace=False):
main = fluid.Program()
startup = fluid.Program()
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
data = fluid.layers.data(
name='input',
shape=self.dshape,
dtype=self.dtype,
append_batch_size=False,
stop_gradient=False)
if inplace:
bn = fluid.layers.inplace_abn(
data,
act=activation,
param_attr=fluid.ParamAttr(name='bn_scale'),
bias_attr=fluid.ParamAttr(name='bn_bias'),
moving_mean_name='bn_moving_mean',
moving_variance_name='bn_moving_variance',
data_layout=layout,
is_test=only_forward,
act_alpha=alpha)
else:
bn = fluid.layers.batch_norm(
data,
param_attr=fluid.ParamAttr(name='bn_scale'),
bias_attr=fluid.ParamAttr(name='bn_bias'),
moving_mean_name='bn_moving_mean',
moving_variance_name='bn_moving_variance',
data_layout=layout,
is_test=only_forward,
in_place=inplace)
if activation == 'leaky_relu':
bn = fluid.layers.leaky_relu(bn, alpha)
if activation == 'elu':
bn = fluid.layers.elu(bn, alpha)
                # NOTE: in inplace mode the input and output of bn
                # may have the same name; multiplying by 1. generates
                # a new Variable for fetch
bn = bn * 1.
sigmoid = fluid.layers.sigmoid(bn)
out = fluid.layers.reduce_sum(sigmoid)
if not only_forward:
sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
sgd_opt.backward(out)
return main, startup, [out, bn]
def compare(self, place, layout, only_forward, activation, alpha, use_cuda):
seed = 10
os.environ['FLAGS_cudnn_deterministic'] = "1"
data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2
fetch_outs = []
fetch_names = []
for inplace in [False, True]:
main, startup, outs = self.build_program(
place,
layout,
seed,
only_forward,
activation,
alpha,
inplace=inplace)
exe = fluid.Executor(place)
exe.run(startup)
fetch_name = [v.name for v in outs] + [
'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
]
if not only_forward:
others = [
'inplace_abn_0.tmp_0' if inplace else 'batch_norm_0.tmp_0',
'inplace_abn_0.tmp_1' if inplace else 'batch_norm_0.tmp_1',
'bn_scale@GRAD',
'bn_bias@GRAD',
'input@GRAD',
]
fetch_name += others
for nm in fetch_name:
fv = fluid.framework._get_var(str(nm), program=main)
fv.persistable = True
build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = use_cuda and \
fluid.core.get_cuda_device_count() > 1
build_strategy.enable_inplace = inplace
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1 if os.name == 'nt' else 0
comp_prog1 = compiler.CompiledProgram(main).with_data_parallel(
outs[0].name if not only_forward else None,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
bn_fetches = exe.run(program=comp_prog1,
feed={'input': data},
fetch_list=fetch_name)
fetch_outs.append(bn_fetches)
fetch_names.append(fetch_name)
for bn_val, inplace_abn_val, name1, name2 in zip(*(fetch_outs +
fetch_names)):
self.assertTrue(
np.allclose(
bn_val, inplace_abn_val, atol=1e-2),
"Output (" + name1 + ":" + name2 +
") has diff on {} with {} layout and {} activation. \n".format(
place, layout, activation) + "\nBN " + str(bn_val) +
"\n" + "Inplace ABN " + str(inplace_abn_val))
def test_op(self):
use_cudas = [False, True] if core.is_compiled_with_cuda() else [False]
for use_cuda in use_cudas:
place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
layouts = ["NCHW", "NHWC"]
for layout in layouts:
for activation, alpha in zip([None, 'elu', 'leaky_relu'],
[0., 1., 0.02]):
for infer_only in [True, False]:
self.compare(place, layout, infer_only, activation,
alpha, use_cuda)
def test_all_branches(self):
seed = 10
os.environ['FLAGS_cudnn_deterministic'] = "1"
data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2
use_cudas = [False, True] if core.is_compiled_with_cuda() else [False]
alpha = 0.1
layouts = ["NCHW", "NHWC"]
for use_cuda in use_cudas:
place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
for layout in layouts:
for activation in ['identity', 'leaky_relu']:
main, startup, outs = self.build_program(
place, layout, seed, False, activation, alpha, use_cuda,
True)
exe = fluid.Executor(place)
exe.run(startup)
exe.run(program=main, feed={'input': data})
if __name__ == '__main__':
unittest.main()
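For completeness, a minimal sketch of enabling synchronized statistics for a program that uses `inplace_abn`, mirroring the `compare` method above; `main_program` and `loss` are illustrative placeholders for a built program and its loss Variable:

import paddle.fluid as fluid
from paddle.fluid import compiler

# Sketch: with build_strategy.sync_batch_norm enabled, the batch_norm and
# inplace_abn ops use synchronized batch normalization automatically,
# as noted in the layer docstring above.
build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = fluid.core.get_cuda_device_count() > 1
compiled_prog = compiler.CompiledProgram(main_program).with_data_parallel(
    loss_name=loss.name, build_strategy=build_strategy)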
......@@ -2684,6 +2684,28 @@ class TestBook(LayerTest):
out = layers.batch_norm(data, momentum=momentum)
return (out)
def make_inplace_abn(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
data = self._get_data(
name='data', shape=[32, 128, 128], dtype="float32")
out = layers.inplace_abn(data, act='leaky_relu', act_alpha=0.2)
return (out)
def make_inplace_abn_momentum_variable(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
data = self._get_data(
name='data', shape=[32, 128, 128], dtype="float32")
momentum = self._get_data(
name='momentum',
shape=[1],
dtype='float32',
append_batch_size=False)
out = layers.inplace_abn(
data, momentum=momentum, act='elu', act_alpha=2.0)
return (out)
def make_range(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......